diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2d6f6ac89d..5695d5aff8 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -53,6 +53,7 @@ import synapse
import logging
import os
import re
+import resource
import subprocess
import sqlite3
import syweb
@@ -146,8 +147,8 @@ class SynapseHomeServer(HomeServer):
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
- for (full_path, resource) in desired_tree:
- logger.info("Attaching %s to path %s", resource, full_path)
+ for full_path, res in desired_tree:
+ logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
@@ -178,12 +179,12 @@ class SynapseHomeServer(HomeServer):
child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
- resource.putChild(child_name, child_resource)
+ res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
- last_resource.putChild(last_path_seg, resource)
+ last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg)
- resource_mappings[res_id] = resource
+ resource_mappings[res_id] = res
return self.root_resource
@@ -275,6 +276,20 @@ def get_version_string():
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
+def change_resource_limit(soft_file_no):
+ try:
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+
+ if not soft_file_no:
+ soft_file_no = hard
+
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+
+ logger.info("Set file limit to: %d", soft_file_no)
+ except (ValueError, resource.error) as e:
+ logger.warn("Failed to set file limit: %s", e)
+
+
def setup():
config = HomeServerConfig.load_config(
"Synapse Homeserver",
@@ -351,10 +366,11 @@ def setup():
if config.daemonize:
print config.pid_file
+
daemon = Daemonize(
app="synapse-homeserver",
pid=config.pid_file,
- action=run,
+ action=lambda: run(config),
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -362,11 +378,13 @@ def setup():
daemon.start()
else:
- reactor.run()
+ run(config)
-def run():
+def run(config):
with LoggingContext("run"):
+ change_resource_limit(config.soft_file_limit)
+
reactor.run()
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 381b4cfc4a..a268a6bcc4 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -46,22 +46,34 @@ class ApplicationService(object):
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
# {
- # users: ["regex",...],
- # aliases: ["regex",...],
- # rooms: ["regex",...],
+ # users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
+ # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
+ # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
return None
for ns in ApplicationService.NS_LIST:
+ if ns not in namespaces:
+ namespaces[ns] = []
+ continue
+
if type(namespaces[ns]) != list:
- raise ValueError("Bad namespace value for '%s'", ns)
- for regex in namespaces[ns]:
- if not isinstance(regex, basestring):
- raise ValueError("Expected string regex for ns '%s'", ns)
+ raise ValueError("Bad namespace value for '%s'" % ns)
+ for regex_obj in namespaces[ns]:
+ if not isinstance(regex_obj, dict):
+ raise ValueError("Expected dict regex for ns '%s'" % ns)
+ if not isinstance(regex_obj.get("exclusive"), bool):
+ raise ValueError(
+ "Expected bool for 'exclusive' in ns '%s'" % ns
+ )
+ if not isinstance(regex_obj.get("regex"), basestring):
+ raise ValueError(
+ "Expected string for 'regex' in ns '%s'" % ns
+ )
return namespaces
- def _matches_regex(self, test_string, namespace_key):
+ def _matches_regex(self, test_string, namespace_key, return_obj=False):
if not isinstance(test_string, basestring):
logger.error(
"Expected a string to test regex against, but got %s",
@@ -69,11 +81,19 @@ class ApplicationService(object):
)
return False
- for regex in self.namespaces[namespace_key]:
- if re.match(regex, test_string):
+ for regex_obj in self.namespaces[namespace_key]:
+ if re.match(regex_obj["regex"], test_string):
+ if return_obj:
+ return regex_obj
return True
return False
+ def _is_exclusive(self, ns_key, test_string):
+ regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+ if regex_obj:
+ return regex_obj["exclusive"]
+ return False
+
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
@@ -143,5 +163,14 @@ class ApplicationService(object):
def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+ def is_exclusive_user(self, user_id):
+ return self._is_exclusive(ApplicationService.NS_USERS, user_id)
+
+ def is_exclusive_alias(self, alias):
+ return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
+
+ def is_exclusive_room(self, room_id):
+ return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
+
def __str__(self):
return "ApplicationService: %s" % (self.__dict__,)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 17c7e64ce7..862c07ef8c 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -22,6 +22,12 @@ class RatelimitConfig(Config):
self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count
+ self.federation_rc_window_size = args.federation_rc_window_size
+ self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
+ self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
+ self.federation_rc_reject_limit = args.federation_rc_reject_limit
+ self.federation_rc_concurrent = args.federation_rc_concurrent
+
@classmethod
def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser)
@@ -34,3 +40,33 @@ class RatelimitConfig(Config):
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)
+
+ rc_group.add_argument(
+ "--federation-rc-window-size", type=int, default=10000,
+ help="The federation window size in milliseconds",
+ )
+
+ rc_group.add_argument(
+ "--federation-rc-sleep-limit", type=int, default=10,
+ help="The number of federation requests from a single server"
+ " in a window before the server will delay processing the"
+ " request.",
+ )
+
+ rc_group.add_argument(
+ "--federation-rc-sleep-delay", type=int, default=500,
+ help="The duration in milliseconds to delay processing events from"
+ " remote servers by if they go over the sleep limit.",
+ )
+
+ rc_group.add_argument(
+ "--federation-rc-reject-limit", type=int, default=50,
+ help="The maximum number of concurrent federation requests allowed"
+ " from a single server",
+ )
+
+ rc_group.add_argument(
+ "--federation-rc-concurrent", type=int, default=3,
+ help="The number of federation requests to concurrently process"
+ " from a single server",
+ )
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 31e44cc857..4e4892d40b 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -31,6 +31,7 @@ class ServerConfig(Config):
self.webclient = True
self.manhole = args.manhole
self.no_tls = args.no_tls
+ self.soft_file_limit = args.soft_file_limit
if not args.content_addr:
host = args.server_name
@@ -77,6 +78,12 @@ class ServerConfig(Config):
"content repository")
server_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
+ server_group.add_argument("--soft-file-limit", type=int, default=0,
+ help="Set the soft limit on the number of "
+ "file descriptors synapse can use. "
+ "Zero is used to indicate synapse "
+ "should set the soft limit to the hard"
+ "limit.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index cd3c962d50..ca89a0787c 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -439,6 +439,25 @@ class FederationClient(FederationBase):
defer.returnValue(ret)
+ @defer.inlineCallbacks
+ def get_missing_events(self, destination, room_id, earliest_events,
+ latest_events, limit, min_depth):
+ content = yield self.transport_layer.get_missing_events(
+ destination, room_id, earliest_events, latest_events, limit,
+ min_depth,
+ )
+
+ events = [
+ self.event_from_pdu_json(e)
+ for e in content.get("events", [])
+ ]
+
+ signed_events = yield self._check_sigs_and_hash_and_fetch(
+ destination, events, outlier=True
+ )
+
+ defer.returnValue(signed_events)
+
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 22b9663831..4264d857be 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -112,17 +112,20 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
- dl = []
+ results = []
+
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
- def handle_failure(failure):
- failure.trap(FederationError)
- self.send_failure(failure.value, transaction.origin)
-
- d.addErrback(handle_failure)
-
- dl.append(d)
+ try:
+ yield d
+ results.append({})
+ except FederationError as e:
+ self.send_failure(e, transaction.origin)
+ results.append({"error": str(e)})
+ except Exception as e:
+ results.append({"error": str(e)})
+ logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
@@ -135,21 +138,11 @@ class FederationServer(FederationBase):
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
- results = yield defer.DeferredList(dl, consumeErrors=True)
-
- ret = []
- for r in results:
- if r[0]:
- ret.append({})
- else:
- logger.exception(r[1])
- ret.append({"error": str(r[1].value)})
-
- logger.debug("Returning: %s", str(ret))
+ logger.debug("Returning: %s", str(results))
response = {
"pdus": dict(zip(
- (p.event_id for p in pdu_list), ret
+ (p.event_id for p in pdu_list), results
)),
}
@@ -305,6 +298,20 @@ class FederationServer(FederationBase):
(200, send_content)
)
+ @defer.inlineCallbacks
+ @log_function
+ def on_get_missing_events(self, origin, room_id, earliest_events,
+ latest_events, limit, min_depth):
+ missing_events = yield self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit, min_depth
+ )
+
+ time_now = self._clock.time_msec()
+
+ defer.returnValue({
+ "events": [ev.get_pdu_json(time_now) for ev in missing_events],
+ })
+
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
@@ -331,7 +338,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, origin, pdu, max_recursion=10):
+ def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
@@ -383,48 +390,50 @@ class FederationServer(FederationBase):
pdu.room_id, min_depth
)
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = set(have_seen.keys())
+
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
- elif min_depth and pdu.depth > min_depth and max_recursion > 0:
- for event_id, hashes in pdu.prev_events:
- if event_id not in have_seen:
- logger.debug(
- "_handle_new_pdu requesting pdu %s",
- event_id
+ elif min_depth and pdu.depth > min_depth:
+ if get_missing and prevs - seen:
+ latest_tuples = yield self.store.get_latest_events_in_room(
+ pdu.room_id
+ )
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(e_id for e_id, _, _ in latest_tuples)
+ latest |= seen
+
+ missing_events = yield self.get_missing_events(
+ origin,
+ pdu.room_id,
+ earliest_events=list(latest),
+ latest_events=[pdu.event_id],
+ limit=10,
+ min_depth=min_depth,
+ )
+
+ for e in missing_events:
+ yield self._handle_new_pdu(
+ origin,
+ e,
+ get_missing=False
)
- try:
- new_pdu = yield self.federation_client.get_pdu(
- [origin, pdu.origin],
- event_id=event_id,
- )
-
- if new_pdu:
- yield self._handle_new_pdu(
- origin,
- new_pdu,
- max_recursion=max_recursion-1
- )
-
- logger.debug("Processed pdu %s", event_id)
- else:
- logger.warn("Failed to get PDU %s", event_id)
- fetch_state = True
- except:
- # TODO(erikj): Do some more intelligent retries.
- logger.exception("Failed to get PDU")
- fetch_state = True
- else:
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
- if prevs - seen:
- fetch_state = True
- else:
- fetch_state = True
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
+
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = set(have_seen.keys())
+ if prevs - seen:
+ fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 7d30c924d1..741a4e7a1a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -224,6 +224,8 @@ class TransactionQueue(object):
]
try:
+ self.pending_transactions[destination] = 1
+
limiter = yield get_retry_limiter(
destination,
self._clock,
@@ -239,8 +241,6 @@ class TransactionQueue(object):
len(pending_failures)
)
- self.pending_transactions[destination] = 1
-
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
@@ -287,7 +287,7 @@ class TransactionQueue(object):
code = 200
if response:
- for e_id, r in getattr(response, "pdus", {}).items():
+ for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 6800ac46c5..2a671b9aec 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
from .server import TransportLayerServer
from .client import TransportLayerClient
+from synapse.util.ratelimitutils import FederationRateLimiter
+
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
@@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests
"""
self.keyring = homeserver.get_keyring()
+ self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
+
+ self.ratelimiter = FederationRateLimiter(
+ self.clock,
+ window_size=homeserver.config.federation_rc_window_size,
+ sleep_limit=homeserver.config.federation_rc_sleep_limit,
+ sleep_msec=homeserver.config.federation_rc_sleep_delay,
+ reject_limit=homeserver.config.federation_rc_reject_limit,
+ concurrent_requests=homeserver.config.federation_rc_concurrent,
+ )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8b137e7128..80d03012b7 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -219,3 +219,22 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_missing_events(self, destination, room_id, earliest_events,
+ latest_events, limit, min_depth):
+ path = PREFIX + "/get_missing_events/%s" % (room_id,)
+
+ content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data={
+ "limit": int(limit),
+ "min_depth": int(min_depth),
+ "earliest_events": earliest_events,
+ "latest_events": latest_events,
+ }
+ )
+
+ defer.returnValue(content)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2ffb37aa18..ece6dbcf62 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -98,15 +98,23 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs):
try:
(origin, content) = yield self._authenticate_request(request)
- response = yield handler(
- origin, content, request.args, *args, **kwargs
- )
+ with self.ratelimiter.ratelimit(origin) as d:
+ yield d
+ response = yield handler(
+ origin, content, request.args, *args, **kwargs
+ )
except:
logger.exception("_authenticate_request failed")
raise
defer.returnValue(response)
return new_handler
+ def rate_limit_origin(self, handler):
+ def new_handler(origin, *args, **kwargs):
+ response = yield handler(origin, *args, **kwargs)
+ defer.returnValue(response)
+ return new_handler()
+
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@@ -234,6 +242,7 @@ class TransportLayerServer(object):
)
)
)
+
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
@@ -245,6 +254,17 @@ class TransportLayerServer(object):
)
)
+ self.server.register_path(
+ "POST",
+ re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
+ self._with_authentication(
+ lambda origin, content, query, room_id:
+ self._get_missing_events(
+ origin, content, room_id,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -344,3 +364,22 @@ class TransportLayerServer(object):
)
defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _get_missing_events(self, origin, content, room_id):
+ limit = int(content.get("limit", 10))
+ min_depth = int(content.get("min_depth", 0))
+ earliest_events = content.get("earliest_events", [])
+ latest_events = content.get("latest_events", [])
+
+ content = yield self.request_handler.on_get_missing_events(
+ origin,
+ room_id=room_id,
+ earliest_events=earliest_events,
+ latest_events=latest_events,
+ min_depth=min_depth,
+ limit=limit,
+ )
+
+ defer.returnValue((200, content))
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 11d20a5d2d..f76febee8f 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -232,13 +232,23 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None):
+ # Any application service "interested" in an alias they are regexing on
+ # can modify the alias.
+ # Users can only modify the alias if ALL the interested services have
+ # non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string())
]
+
for service in interested_services:
if user_id == service.sender:
- # this user IS the app service
+ # this user IS the app service so they can do whatever they like
defer.returnValue(True)
return
- defer.returnValue(len(interested_services) == 0)
+ elif service.is_exclusive_alias(alias.to_string()):
+ # another service has an exclusive lock on this alias.
+ defer.returnValue(False)
+ return
+ # either no interested services, or no service with an exclusive lock
+ defer.returnValue(True)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 025e7e7e62..8d5f5c8499 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -69,9 +69,6 @@ class EventStreamHandler(BaseHandler):
)
self._streams_per_user[auth_user] += 1
- if pagin_config.from_token is None:
- pagin_config.from_token = None
-
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7deed16f9c..ae4e9b316d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -581,12 +581,13 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
- def get_state_for_pdu(self, origin, room_id, event_id):
+ def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor()
- in_room = yield self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
+ if do_auth:
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
[event_id]
@@ -789,6 +790,29 @@ class FederationHandler(BaseHandler):
defer.returnValue(ret)
@defer.inlineCallbacks
+ def on_get_missing_events(self, origin, room_id, earliest_events,
+ latest_events, limit, min_depth):
+ in_room = yield self.auth.check_host_in_room(
+ room_id,
+ origin
+ )
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
+ limit = min(limit, 20)
+ min_depth = max(min_depth, 0)
+
+ missing_events = yield self.store.get_missing_events(
+ room_id=room_id,
+ earliest_events=earliest_events,
+ latest_events=latest_events,
+ limit=limit,
+ min_depth=min_depth,
+ )
+
+ defer.returnValue(missing_events)
+
+ @defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 516a936cee..cda4a8502a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -201,11 +201,12 @@ class RegistrationHandler(BaseHandler):
interested_services = [
s for s in services if s.is_interested_in_user(user_id)
]
- if len(interested_services) > 0:
- raise SynapseError(
- 400, "This user ID is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
- )
+ for service in interested_services:
+ if service.is_exclusive_user(user_id):
+ raise SynapseError(
+ 400, "This user ID is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE
+ )
def _generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 914742d913..80f7ee3f12 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -510,9 +510,16 @@ class RoomMemberHandler(BaseHandler):
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
"""Returns a list of roomids that the user has any of the given
membership states in."""
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user.to_string(), membership_list=membership_list
+
+ app_service = yield self.store.get_app_service_by_user_id(
+ user.to_string()
)
+ if app_service:
+ rooms = yield self.store.get_app_service_rooms(app_service)
+ else:
+ rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=user.to_string(), membership_list=membership_list
+ )
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
@@ -559,13 +566,24 @@ class RoomEventSource(object):
to_key = yield self.get_current_key()
- events, end_key = yield self.store.get_room_events_stream(
- user_id=user.to_string(),
- from_key=from_key,
- to_key=to_key,
- room_id=None,
- limit=limit,
+ app_service = yield self.store.get_app_service_by_user_id(
+ user.to_string()
)
+ if app_service:
+ events, end_key = yield self.store.get_appservice_room_stream(
+ service=app_service,
+ from_key=from_key,
+ to_key=to_key,
+ limit=limit,
+ )
+ else:
+ events, end_key = yield self.store.get_room_events_stream(
+ user_id=user.to_string(),
+ from_key=from_key,
+ to_key=to_key,
+ room_id=None,
+ limit=limit,
+ )
defer.returnValue((events, end_key))
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 22b7145ac1..b53a07aa2d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -143,7 +143,7 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
- json_str = json.dumps(json_body)
+ json_str = encode_canonical_json(json_body)
response = yield self.agent.request(
"PUT",
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 2475f3ffbe..09d23e79b8 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -36,8 +36,10 @@ class _NotificationListener(object):
so that it can remove itself from the indexes in the Notifier class.
"""
- def __init__(self, user, rooms, from_token, limit, timeout, deferred):
+ def __init__(self, user, rooms, from_token, limit, timeout, deferred,
+ appservice=None):
self.user = user
+ self.appservice = appservice
self.from_token = from_token
self.limit = limit
self.timeout = timeout
@@ -65,6 +67,10 @@ class _NotificationListener(object):
lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self)
+ if self.appservice:
+ notifier.appservice_to_listeners.get(
+ self.appservice, set()
+ ).discard(self)
class Notifier(object):
@@ -79,6 +85,7 @@ class Notifier(object):
self.rooms_to_listeners = {}
self.user_to_listeners = {}
+ self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources()
@@ -114,6 +121,17 @@ class Notifier(object):
for user in extra_users:
listeners |= self.user_to_listeners.get(user, set()).copy()
+ for appservice in self.appservice_to_listeners:
+ # TODO (kegan): Redundant appservice listener checks?
+ # App services will already be in the rooms_to_listeners set, but
+ # that isn't enough. They need to be checked here in order to
+ # receive *invites* for users they are interested in. Does this
+ # make the rooms_to_listeners check somewhat obselete?
+ if appservice.is_interested(event):
+ listeners |= self.appservice_to_listeners.get(
+ appservice, set()
+ ).copy()
+
logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the
@@ -280,6 +298,10 @@ class Notifier(object):
if not from_token:
from_token = yield self.event_sources.get_current_token()
+ appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
+ user.to_string()
+ )
+
listener = _NotificationListener(
user,
rooms,
@@ -287,6 +309,7 @@ class Notifier(object):
limit,
timeout,
deferred,
+ appservice=appservice
)
def _timeout_listener():
@@ -319,6 +342,11 @@ class Notifier(object):
self.user_to_listeners.setdefault(listener.user, set()).add(listener)
+ if listener.appservice:
+ self.appservice_to_listeners.setdefault(
+ listener.appservice, set()
+ ).add(listener)
+
@defer.inlineCallbacks
@log_function
def _check_for_updates(self, listener):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c71df75404..5fe8a825e3 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"],
- "matrix_angular_sdk>=0.6.3": ["syweb>=0.6.3"],
+ "matrix_angular_sdk>=0.6.4": ["syweb>=0.6.4"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -36,8 +36,8 @@ DEPENDENCY_LINKS = [
),
github_link(
project="matrix-org/matrix-angular-sdk",
- version="v0.6.3",
- egg="matrix_angular_sdk-0.6.3",
+ version="v0.6.4",
+ egg="matrix_angular_sdk-0.6.4",
),
]
diff --git a/synapse/rest/appservice/v1/register.py b/synapse/rest/appservice/v1/register.py
index 3bd0c1220c..a4f6159773 100644
--- a/synapse/rest/appservice/v1/register.py
+++ b/synapse/rest/appservice/v1/register.py
@@ -48,18 +48,12 @@ class RegisterRestServlet(AppServiceRestServlet):
400, "Missed required keys: as_token(str) / url(str)."
)
- namespaces = {
- "users": [],
- "rooms": [],
- "aliases": []
- }
-
- if "namespaces" in params:
- self._parse_namespace(namespaces, params["namespaces"], "users")
- self._parse_namespace(namespaces, params["namespaces"], "rooms")
- self._parse_namespace(namespaces, params["namespaces"], "aliases")
-
- app_service = ApplicationService(as_token, as_url, namespaces)
+ try:
+ app_service = ApplicationService(
+ as_token, as_url, params["namespaces"]
+ )
+ except ValueError as e:
+ raise SynapseError(400, e.message)
app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token
@@ -68,23 +62,6 @@ class RegisterRestServlet(AppServiceRestServlet):
"hs_token": hs_token
}))
- def _parse_namespace(self, target_ns, origin_ns, ns):
- if ns not in target_ns or ns not in origin_ns:
- return # nothing to parse / map through to.
-
- possible_regex_list = origin_ns[ns]
- if not type(possible_regex_list) == list:
- raise SynapseError(400, "Namespace %s isn't an array." % ns)
-
- for regex in possible_regex_list:
- if not isinstance(regex, basestring):
- raise SynapseError(
- 400, "Regex '%s' isn't a string in namespace %s" %
- (regex, ns)
- )
-
- target_ns[ns] = origin_ns[ns]
-
class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d16e7b8fac..d6ec446bd2 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -74,7 +74,7 @@ SCHEMAS = [
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 13
+SCHEMA_VERSION = 14
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -633,11 +633,11 @@ def prepare_database(db_conn):
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
- if v == 10:
+ if v in (10, 14,):
raise UpgradeDatabaseException(
"No delta for version 10"
)
- sql_script = read_schema("delta/v%d" % (v))
+ sql_script = read_schema("delta/v%d" % (v,))
c.executescript(sql_script)
db_conn.commit()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c98dd36aed..3725c9795d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -450,7 +450,8 @@ class SQLBaseStore(object):
Args:
table : string giving the table name
- keyvalues : dict of column names and values to select the rows with
+ keyvalues : dict of column names and values to select the rows with,
+ or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
@@ -469,13 +470,20 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
- )
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+ txn.execute(sql, keyvalues.values())
+ else:
+ sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ ", ".join(retcols),
+ table
+ )
+ txn.execute(sql)
- txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index dc3666efd4..e30265750a 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,22 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import simplejson
+from simplejson import JSONDecodeError
from twisted.internet import defer
+from synapse.api.constants import Membership
from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService
+from synapse.storage.roommember import RoomsForUser
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+def log_failure(failure):
+ logger.error("Failed to detect application services: %s", failure.value)
+ logger.error(failure.getTraceback())
+
+
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
self.services_cache = []
self.cache_defer = self._populate_cache()
+ self.cache_defer.addErrback(log_failure)
@defer.inlineCallbacks
def unregister_app_service(self, token):
@@ -128,11 +138,11 @@ class ApplicationServiceStore(SQLBaseStore):
)
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces:
- for regex in service.namespaces[ns_str]:
+ for regex_obj in service.namespaces[ns_str]:
txn.execute(
"INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)",
- (as_id, ns_int, regex)
+ (as_id, ns_int, simplejson.dumps(regex_obj))
)
return True
@@ -151,8 +161,31 @@ class ApplicationServiceStore(SQLBaseStore):
defer.returnValue(self.services_cache)
@defer.inlineCallbacks
+ def get_app_service_by_user_id(self, user_id):
+ """Retrieve an application service from their user ID.
+
+ All application services have associated with them a particular user ID.
+ There is no distinguishing feature on the user ID which indicates it
+ represents an application service. This function allows you to map from
+ a user ID to an application service.
+
+ Args:
+ user_id(str): The user ID to see if it is an application service.
+ Returns:
+ synapse.appservice.ApplicationService or None.
+ """
+
+ yield self.cache_defer # make sure the cache is ready
+
+ for service in self.services_cache:
+ if service.sender == user_id:
+ defer.returnValue(service)
+ return
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True):
- """Get the application service with the given token.
+ """Get the application service with the given appservice token.
Args:
token (str): The application service token.
@@ -173,6 +206,77 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table.
+ def get_app_service_rooms(self, service):
+ """Get a list of RoomsForUser for this application service.
+
+ Application services may be "interested" in lots of rooms depending on
+ the room ID, the room aliases, or the members in the room. This function
+ takes all of these into account and returns a list of RoomsForUser which
+ represent the entire list of room IDs that this application service
+ wants to know about.
+
+ Args:
+ service: The application service to get a room list for.
+ Returns:
+ A list of RoomsForUser.
+ """
+ return self.runInteraction(
+ "get_app_service_rooms",
+ self._get_app_service_rooms_txn,
+ service,
+ )
+
+ def _get_app_service_rooms_txn(self, txn, service):
+ # get all rooms matching the room ID regex.
+ room_entries = self._simple_select_list_txn(
+ txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
+ )
+ matching_room_list = set([
+ r["room_id"] for r in room_entries if
+ service.is_interested_in_room(r["room_id"])
+ ])
+
+ # resolve room IDs for matching room alias regex.
+ room_alias_mappings = self._simple_select_list_txn(
+ txn=txn, table="room_aliases", keyvalues=None,
+ retcols=["room_id", "room_alias"]
+ )
+ matching_room_list |= set([
+ r["room_id"] for r in room_alias_mappings if
+ service.is_interested_in_alias(r["room_alias"])
+ ])
+
+ # get all rooms for every user for this AS. This is scoped to users on
+ # this HS only.
+ user_list = self._simple_select_list_txn(
+ txn=txn, table="users", keyvalues=None, retcols=["name"]
+ )
+ user_list = [
+ u["name"] for u in user_list if
+ service.is_interested_in_user(u["name"])
+ ]
+ rooms_for_user_matching_user_id = set() # RoomsForUser list
+ for user_id in user_list:
+ # FIXME: This assumes this store is linked with RoomMemberStore :(
+ rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
+ txn=txn,
+ user_id=user_id,
+ membership_list=[Membership.JOIN]
+ )
+ rooms_for_user_matching_user_id |= set(rooms_for_user)
+
+ # make RoomsForUser tuples for room ids and aliases which are not in the
+ # main rooms_for_user_list - e.g. they are rooms which do not have AS
+ # registered users in it.
+ known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
+ missing_rooms_for_user = [
+ RoomsForUser(r, service.sender, "join") for r in
+ matching_room_list if r not in known_room_ids
+ ]
+ rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
+
+ return rooms_for_user_matching_user_id
+
@defer.inlineCallbacks
def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database."""
@@ -215,10 +319,12 @@ class ApplicationServiceStore(SQLBaseStore):
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
- res["regex"]
+ simplejson.loads(res["regex"])
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
+ except JSONDecodeError:
+ logger.error("Bad regex object '%s'", res["regex"])
# TODO get last successful txn id f.e. service
for service in services.values():
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3fbc090224..2deda8ac50 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -64,6 +64,9 @@ class EventFederationStore(SQLBaseStore):
for f in front:
txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()])
+
+ new_front -= results
+
front = new_front
results.update(front)
@@ -378,3 +381,51 @@ class EventFederationStore(SQLBaseStore):
event_results += new_front
return self._get_events_txn(txn, event_results)
+
+ def get_missing_events(self, room_id, earliest_events, latest_events,
+ limit, min_depth):
+ return self.runInteraction(
+ "get_missing_events",
+ self._get_missing_events,
+ room_id, earliest_events, latest_events, limit, min_depth
+ )
+
+ def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
+ limit, min_depth):
+
+ earliest_events = set(earliest_events)
+ front = set(latest_events) - earliest_events
+
+ event_results = set()
+
+ query = (
+ "SELECT prev_event_id FROM event_edges "
+ "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+ "LIMIT ?"
+ )
+
+ while front and len(event_results) < limit:
+ new_front = set()
+ for event_id in front:
+ txn.execute(
+ query,
+ (room_id, event_id, limit - len(event_results))
+ )
+
+ for e_id, in txn.fetchall():
+ new_front.add(e_id)
+
+ new_front -= earliest_events
+ new_front -= event_results
+
+ front = new_front
+ event_results |= new_front
+
+ events = self._get_events_txn(txn, event_results)
+
+ events = sorted(
+ [ev for ev in events if ev.depth >= min_depth],
+ key=lambda e: e.depth,
+ )
+
+ return events[:limit]
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 58aa376c20..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -180,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list:
return defer.succeed(None)
+ return self.runInteraction(
+ "get_rooms_for_user_where_membership_is",
+ self._get_rooms_for_user_where_membership_is_txn,
+ user_id, membership_list
+ )
+
+ def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+ membership_list):
where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
@@ -187,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id]
args.extend(membership_list)
- def f(txn):
- sql = (
- "SELECT m.room_id, m.sender, m.membership"
- " FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership"
+ " FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id"
+ " WHERE %s"
+ ) % (where_clause,)
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- f
- )
+ txn.execute(sql, args)
+ return [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 3ccb6f8a61..09bc522210 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function
@@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
class StreamStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
+ # NB this lives here instead of appservice.py so we can reuse the
+ # 'private' StreamToken class in this file.
+ if limit:
+ limit = max(limit, MAX_STREAM_SIZE)
+ else:
+ limit = MAX_STREAM_SIZE
+
+ # From and to keys should be integers from ordering.
+ from_id = _StreamToken.parse_stream_token(from_key)
+ to_id = _StreamToken.parse_stream_token(to_key)
+
+ if from_key == to_key:
+ defer.returnValue(([], to_key))
+ return
+
+ # select all the events between from/to with a sensible limit
+ sql = (
+ "SELECT e.event_id, e.room_id, e.type, s.state_key, "
+ "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
+ "e.event_id = s.event_id "
+ "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
+ "ORDER BY stream_ordering ASC LIMIT %(limit)d "
+ ) % {
+ "limit": limit
+ }
+
+ def f(txn):
+ # pull out all the events between the tokens
+ txn.execute(sql, (from_id.stream, to_id.stream,))
+ rows = self.cursor_to_dict(txn)
+
+ # Logic:
+ # - We want ALL events which match the AS room_id regex
+ # - We want ALL events which match the rooms represented by the AS
+ # room_alias regex
+ # - We want ALL events for rooms that AS users have joined.
+ # This is currently supported via get_app_service_rooms (which is
+ # used for the Notifier listener rooms). We can't reasonably make a
+ # SQL query for these room IDs, so we'll pull all the events between
+ # from/to and filter in python.
+ rooms_for_as = self._get_app_service_rooms_txn(txn, service)
+ room_ids_for_as = [r.room_id for r in rooms_for_as]
+
+ def app_service_interested(row):
+ if row["room_id"] in room_ids_for_as:
+ return True
+
+ if row["type"] == EventTypes.Member:
+ if service.is_interested_in_user(row.get("state_key")):
+ return True
+ return False
+
+ ret = self._get_events_txn(
+ txn,
+ # apply the filter on the room id list
+ [
+ r["event_id"] for r in rows
+ if app_service_interested(r)
+ ],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows)
+
+ if rows:
+ key = "s%d" % max(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = to_key
+
+ return ret, key
+
+ results = yield self.runInteraction("get_appservice_room_stream", f)
+ defer.returnValue(results)
+
@log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False):
@@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
self._set_before_and_after(ret, rows)
if rows:
- key = "s%d" % max([r["stream_ordering"] for r in rows])
-
+ key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
new file mode 100644
index 0000000000..d4457af950
--- /dev/null
+++ b/synapse/util/ratelimitutils.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import LimitExceededError
+
+from synapse.util.async import sleep
+
+import collections
+import contextlib
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationRateLimiter(object):
+ def __init__(self, clock, window_size, sleep_limit, sleep_msec,
+ reject_limit, concurrent_requests):
+ """
+ Args:
+ clock (Clock)
+ window_size (int): The window size in milliseconds.
+ sleep_limit (int): The number of requests received in the last
+ `window_size` milliseconds before we artificially start
+ delaying processing of requests.
+ sleep_msec (int): The number of milliseconds to delay processing
+ of incoming requests by.
+ reject_limit (int): The maximum number of requests that are can be
+ queued for processing before we start rejecting requests with
+ a 429 Too Many Requests response.
+ concurrent_requests (int): The number of concurrent requests to
+ process.
+ """
+ self.clock = clock
+
+ self.window_size = window_size
+ self.sleep_limit = sleep_limit
+ self.sleep_msec = sleep_msec
+ self.reject_limit = reject_limit
+ self.concurrent_requests = concurrent_requests
+
+ self.ratelimiters = {}
+
+ def ratelimit(self, host):
+ """Used to ratelimit an incoming request from given host
+
+ Example usage:
+
+ with rate_limiter.ratelimit(origin) as wait_deferred:
+ yield wait_deferred
+ # Handle request ...
+
+ Args:
+ host (str): Origin of incoming request.
+
+ Returns:
+ _PerHostRatelimiter
+ """
+ return self.ratelimiters.setdefault(
+ host,
+ _PerHostRatelimiter(
+ clock=self.clock,
+ window_size=self.window_size,
+ sleep_limit=self.sleep_limit,
+ sleep_msec=self.sleep_msec,
+ reject_limit=self.reject_limit,
+ concurrent_requests=self.concurrent_requests,
+ )
+ ).ratelimit()
+
+
+class _PerHostRatelimiter(object):
+ def __init__(self, clock, window_size, sleep_limit, sleep_msec,
+ reject_limit, concurrent_requests):
+ self.clock = clock
+
+ self.window_size = window_size
+ self.sleep_limit = sleep_limit
+ self.sleep_msec = sleep_msec
+ self.reject_limit = reject_limit
+ self.concurrent_requests = concurrent_requests
+
+ self.sleeping_requests = set()
+ self.ready_request_queue = collections.OrderedDict()
+ self.current_processing = set()
+ self.request_times = []
+
+ def is_empty(self):
+ time_now = self.clock.time_msec()
+ self.request_times[:] = [
+ r for r in self.request_times
+ if time_now - r < self.window_size
+ ]
+
+ return not (
+ self.ready_request_queue
+ or self.sleeping_requests
+ or self.current_processing
+ or self.request_times
+ )
+
+ @contextlib.contextmanager
+ def ratelimit(self):
+ # `contextlib.contextmanager` takes a generator and turns it into a
+ # context manager. The generator should only yield once with a value
+ # to be returned by manager.
+ # Exceptions will be reraised at the yield.
+
+ request_id = object()
+ ret = self._on_enter(request_id)
+ try:
+ yield ret
+ finally:
+ self._on_exit(request_id)
+
+ def _on_enter(self, request_id):
+ time_now = self.clock.time_msec()
+ self.request_times[:] = [
+ r for r in self.request_times
+ if time_now - r < self.window_size
+ ]
+
+ queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+ if queue_size > self.reject_limit:
+ raise LimitExceededError(
+ retry_after_ms=int(
+ self.window_size / self.sleep_limit
+ ),
+ )
+
+ self.request_times.append(time_now)
+
+ def queue_request():
+ if len(self.current_processing) > self.concurrent_requests:
+ logger.debug("Ratelimit [%s]: Queue req", id(request_id))
+ queue_defer = defer.Deferred()
+ self.ready_request_queue[request_id] = queue_defer
+ return queue_defer
+ else:
+ return defer.succeed(None)
+
+ logger.debug(
+ "Ratelimit [%s]: len(self.request_times)=%d",
+ id(request_id), len(self.request_times),
+ )
+
+ if len(self.request_times) > self.sleep_limit:
+ logger.debug(
+ "Ratelimit [%s]: sleeping req",
+ id(request_id),
+ )
+ ret_defer = sleep(self.sleep_msec/1000.0)
+
+ self.sleeping_requests.add(request_id)
+
+ def on_wait_finished(_):
+ logger.debug(
+ "Ratelimit [%s]: Finished sleeping",
+ id(request_id),
+ )
+ self.sleeping_requests.discard(request_id)
+ queue_defer = queue_request()
+ return queue_defer
+
+ ret_defer.addBoth(on_wait_finished)
+ else:
+ ret_defer = queue_request()
+
+ def on_start(r):
+ logger.debug(
+ "Ratelimit [%s]: Processing req",
+ id(request_id),
+ )
+ self.current_processing.add(request_id)
+ return r
+
+ def on_err(r):
+ self.current_processing.discard(request_id)
+ return r
+
+ def on_both(r):
+ # Ensure that we've properly cleaned up.
+ self.sleeping_requests.discard(request_id)
+ self.ready_request_queue.pop(request_id, None)
+ return r
+
+ ret_defer.addCallbacks(on_start, on_err)
+ ret_defer.addBoth(on_both)
+ return ret_defer
+
+ def _on_exit(self, request_id):
+ logger.debug(
+ "Ratelimit [%s]: Processed req",
+ id(request_id),
+ )
+ self.current_processing.discard(request_id)
+ try:
+ request_id, deferred = self.ready_request_queue.popitem()
+ self.current_processing.add(request_id)
+ deferred.callback(None)
+ except KeyError:
+ pass
|