diff options
Diffstat (limited to 'synapse')
73 files changed, 1736 insertions, 890 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 74c5584a1a..f46a6df1fb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.7.1-r4" +__version__ = "0.8.0" diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 9485719332..3d43674625 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -18,6 +18,7 @@ CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" FEDERATION_PREFIX = "/_matrix/federation/v1" +STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" SERVER_KEY_PREFIX = "/_matrix/key/v1" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2c463936cb..f96535a978 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -17,7 +17,9 @@ import sys sys.dont_write_bytecode = True -from synapse.storage import prepare_database, UpgradeDatabaseException +from synapse.storage import ( + prepare_database, prepare_sqlite3_database, UpgradeDatabaseException, +) from synapse.server import HomeServer @@ -36,7 +38,8 @@ from synapse.http.server_key_resource import LocalKey from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX + SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX, + STATIC_PREFIX ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory @@ -52,6 +55,7 @@ import synapse import logging import os import re +import resource import subprocess import sqlite3 import syweb @@ -81,6 +85,9 @@ class SynapseHomeServer(HomeServer): webclient_path = os.path.join(syweb_path, "webclient") return File(webclient_path) # TODO configurable? + def build_resource_for_static_content(self): + return File("static") + def build_resource_for_content_repo(self): return ContentRepoResource( self, self.upload_dir, self.auth, self.content_addr @@ -124,7 +131,9 @@ class SynapseHomeServer(HomeServer): (SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (MEDIA_PREFIX, self.get_resource_for_media_repository()), (APP_SERVICE_PREFIX, self.get_resource_for_app_services()), + (STATIC_PREFIX, self.get_resource_for_static_content()), ] + if web_client: logger.info("Adding the web client.") desired_tree.append((WEB_CLIENT_PREFIX, @@ -140,8 +149,8 @@ class SynapseHomeServer(HomeServer): # instead, we'll store a copy of this mapping so we can actually add # extra resources to existing nodes. See self._resource_id for the key. resource_mappings = {} - for (full_path, resource) in desired_tree: - logger.info("Attaching %s to path %s", resource, full_path) + for full_path, res in desired_tree: + logger.info("Attaching %s to path %s", res, full_path) last_resource = self.root_resource for path_seg in full_path.split('/')[1:-1]: if path_seg not in last_resource.listNames(): @@ -172,12 +181,12 @@ class SynapseHomeServer(HomeServer): child_name) child_resource = resource_mappings[child_res_id] # steal the children - resource.putChild(child_name, child_resource) + res.putChild(child_name, child_resource) # finally, insert the desired resource in the right place - last_resource.putChild(last_path_seg, resource) + last_resource.putChild(last_path_seg, res) res_id = self._resource_id(last_resource, last_path_seg) - resource_mappings[res_id] = resource + resource_mappings[res_id] = res return self.root_resource @@ -272,6 +281,20 @@ def get_version_string(): return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") +def change_resource_limit(soft_file_no): + try: + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + + if not soft_file_no: + soft_file_no = hard + + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) + + logger.info("Set file limit to: %d", soft_file_no) + except (ValueError, resource.error) as e: + logger.warn("Failed to set file limit: %s", e) + + def setup(): config = HomeServerConfig.load_config( "Synapse Homeserver", @@ -317,6 +340,7 @@ def setup(): try: with sqlite3.connect(db_name) as db_conn: + prepare_sqlite3_database(db_conn) prepare_database(db_conn) except UpgradeDatabaseException: sys.stderr.write( @@ -348,10 +372,11 @@ def setup(): if config.daemonize: print config.pid_file + daemon = Daemonize( app="synapse-homeserver", pid=config.pid_file, - action=run, + action=lambda: run(config), auto_close_fds=False, verbose=True, logger=logger, @@ -359,11 +384,13 @@ def setup(): daemon.start() else: - reactor.run() + run(config) -def run(): +def run(config): with LoggingContext("run"): + change_resource_limit(config.soft_file_limit) + reactor.run() diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 381b4cfc4a..a268a6bcc4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -46,22 +46,34 @@ class ApplicationService(object): def _check_namespaces(self, namespaces): # Sanity check that it is of the form: # { - # users: ["regex",...], - # aliases: ["regex",...], - # rooms: ["regex",...], + # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], + # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], + # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } if not namespaces: return None for ns in ApplicationService.NS_LIST: + if ns not in namespaces: + namespaces[ns] = [] + continue + if type(namespaces[ns]) != list: - raise ValueError("Bad namespace value for '%s'", ns) - for regex in namespaces[ns]: - if not isinstance(regex, basestring): - raise ValueError("Expected string regex for ns '%s'", ns) + raise ValueError("Bad namespace value for '%s'" % ns) + for regex_obj in namespaces[ns]: + if not isinstance(regex_obj, dict): + raise ValueError("Expected dict regex for ns '%s'" % ns) + if not isinstance(regex_obj.get("exclusive"), bool): + raise ValueError( + "Expected bool for 'exclusive' in ns '%s'" % ns + ) + if not isinstance(regex_obj.get("regex"), basestring): + raise ValueError( + "Expected string for 'regex' in ns '%s'" % ns + ) return namespaces - def _matches_regex(self, test_string, namespace_key): + def _matches_regex(self, test_string, namespace_key, return_obj=False): if not isinstance(test_string, basestring): logger.error( "Expected a string to test regex against, but got %s", @@ -69,11 +81,19 @@ class ApplicationService(object): ) return False - for regex in self.namespaces[namespace_key]: - if re.match(regex, test_string): + for regex_obj in self.namespaces[namespace_key]: + if re.match(regex_obj["regex"], test_string): + if return_obj: + return regex_obj return True return False + def _is_exclusive(self, ns_key, test_string): + regex_obj = self._matches_regex(test_string, ns_key, return_obj=True) + if regex_obj: + return regex_obj["exclusive"] + return False + def _matches_user(self, event, member_list): if (hasattr(event, "sender") and self.is_interested_in_user(event.sender)): @@ -143,5 +163,14 @@ class ApplicationService(object): def is_interested_in_room(self, room_id): return self._matches_regex(room_id, ApplicationService.NS_ROOMS) + def is_exclusive_user(self, user_id): + return self._is_exclusive(ApplicationService.NS_USERS, user_id) + + def is_exclusive_alias(self, alias): + return self._is_exclusive(ApplicationService.NS_ALIASES, alias) + + def is_exclusive_room(self, room_id): + return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + def __str__(self): return "ApplicationService: %s" % (self.__dict__,) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b0fe217459..c024535f52 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -22,11 +22,12 @@ from .repository import ContentRepositoryConfig from .captcha import CaptchaConfig from .email import EmailConfig from .voip import VoipConfig +from .registration import RegistrationConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - EmailConfig, VoipConfig): + EmailConfig, VoipConfig, RegistrationConfig,): pass diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 17c7e64ce7..862c07ef8c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -22,6 +22,12 @@ class RatelimitConfig(Config): self.rc_messages_per_second = args.rc_messages_per_second self.rc_message_burst_count = args.rc_message_burst_count + self.federation_rc_window_size = args.federation_rc_window_size + self.federation_rc_sleep_limit = args.federation_rc_sleep_limit + self.federation_rc_sleep_delay = args.federation_rc_sleep_delay + self.federation_rc_reject_limit = args.federation_rc_reject_limit + self.federation_rc_concurrent = args.federation_rc_concurrent + @classmethod def add_arguments(cls, parser): super(RatelimitConfig, cls).add_arguments(parser) @@ -34,3 +40,33 @@ class RatelimitConfig(Config): "--rc-message-burst-count", type=float, default=10, help="number of message a client can send before being throttled" ) + + rc_group.add_argument( + "--federation-rc-window-size", type=int, default=10000, + help="The federation window size in milliseconds", + ) + + rc_group.add_argument( + "--federation-rc-sleep-limit", type=int, default=10, + help="The number of federation requests from a single server" + " in a window before the server will delay processing the" + " request.", + ) + + rc_group.add_argument( + "--federation-rc-sleep-delay", type=int, default=500, + help="The duration in milliseconds to delay processing events from" + " remote servers by if they go over the sleep limit.", + ) + + rc_group.add_argument( + "--federation-rc-reject-limit", type=int, default=50, + help="The maximum number of concurrent federation requests allowed" + " from a single server", + ) + + rc_group.add_argument( + "--federation-rc-concurrent", type=int, default=3, + help="The number of federation requests to concurrently process" + " from a single server", + ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py new file mode 100644 index 0000000000..cca8ab5676 --- /dev/null +++ b/synapse/config/registration.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class RegistrationConfig(Config): + + def __init__(self, args): + super(RegistrationConfig, self).__init__(args) + self.disable_registration = args.disable_registration + + @classmethod + def add_arguments(cls, parser): + super(RegistrationConfig, cls).add_arguments(parser) + reg_group = parser.add_argument_group("registration") + reg_group.add_argument( + "--disable-registration", + action='store_true', + help="Disable registration of new users." + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index 31e44cc857..b042d4eed9 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,7 +30,7 @@ class ServerConfig(Config): self.pid_file = self.abspath(args.pid_file) self.webclient = True self.manhole = args.manhole - self.no_tls = args.no_tls + self.soft_file_limit = args.soft_file_limit if not args.content_addr: host = args.server_name @@ -75,8 +75,12 @@ class ServerConfig(Config): server_group.add_argument("--content-addr", default=None, help="The host and scheme to use for the " "content repository") - server_group.add_argument("--no-tls", action='store_true', - help="Don't bind to the https port.") + server_group.add_argument("--soft-file-limit", type=int, default=0, + help="Set the soft limit on the number of " + "file descriptors synapse can use. " + "Zero is used to indicate synapse " + "should set the soft limit to the hard" + "limit.") def read_signing_key(self, signing_key_path): signing_keys = self.read_file(signing_key_path, "signing_key") diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 384b29e7ba..034f9a7bf0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -28,9 +28,16 @@ class TlsConfig(Config): self.tls_certificate = self.read_tls_certificate( args.tls_certificate_path ) - self.tls_private_key = self.read_tls_private_key( - args.tls_private_key_path - ) + + self.no_tls = args.no_tls + + if self.no_tls: + self.tls_private_key = None + else: + self.tls_private_key = self.read_tls_private_key( + args.tls_private_key_path + ) + self.tls_dh_params_path = self.check_file( args.tls_dh_params_path, "tls_dh_params" ) @@ -45,6 +52,8 @@ class TlsConfig(Config): help="PEM encoded private key for TLS") tls_group.add_argument("--tls-dh-params-path", help="PEM dh parameters for ephemeral keys") + tls_group.add_argument("--no-tls", action='store_true', + help="Don't bind to the https port.") def read_tls_certificate(self, cert_path): cert_pem = self.read_file(cert_path, "tls_certificate") diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a2b822719f..65162d21b7 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -28,7 +28,7 @@ class VoipConfig(Config): super(VoipConfig, cls).add_arguments(parser) group = parser.add_argument_group("voip") group.add_argument( - "--turn-uris", type=str, default=None, + "--turn-uris", type=str, default=None, action='append', help="The public URIs of the TURN server to give to clients" ) group.add_argument( diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 24d4abf3e9..2f8618a0df 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -38,7 +38,10 @@ class ServerContextFactory(ssl.ContextFactory): logger.exception("Failed to enable eliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate(config.tls_certificate) - context.use_privatekey(config.tls_private_key) + + if not config.no_tls: + context.use_privatekey(config.tls_private_key) + context.load_tmp_dh(config.tls_dh_params_path) context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 828aced44a..f4db7b8a05 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -50,18 +50,27 @@ class Keyring(object): ) try: verify_key = yield self.get_server_verify_key(server_name, key_ids) - except IOError: + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) raise SynapseError( 502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED, ) - except: + except Exception as e: + logger.warn( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) raise SynapseError( 401, "No key for %s with id %s" % (server_name, key_ids), Codes.UNAUTHORIZED, ) + try: verify_signed_json(json_object, server_name, verify_key) except: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index cd3c962d50..f131941f45 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,14 +19,18 @@ from twisted.internet import defer from .federation_base import FederationBase from .units import Edu -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import ( + CodeMessageException, HttpResponseException, SynapseError, +) from synapse.util.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import itertools import logging +import random logger = logging.getLogger(__name__) @@ -439,6 +443,116 @@ class FederationClient(FederationBase): defer.returnValue(ret) + @defer.inlineCallbacks + def get_missing_events(self, destination, room_id, earliest_events_ids, + latest_events, limit, min_depth): + """Tries to fetch events we are missing. This is called when we receive + an event without having received all of its ancestors. + + Args: + destination (str) + room_id (str) + earliest_events_ids (list): List of event ids. Effectively the + events we expected to receive, but haven't. `get_missing_events` + should only return events that didn't happen before these. + latest_events (list): List of events we have received that we don't + have all previous events for. + limit (int): Maximum number of events to return. + min_depth (int): Minimum depth of events tor return. + """ + try: + content = yield self.transport_layer.get_missing_events( + destination=destination, + room_id=room_id, + earliest_events=earliest_events_ids, + latest_events=[e.event_id for e in latest_events], + limit=limit, + min_depth=min_depth, + ) + + events = [ + self.event_from_pdu_json(e) + for e in content.get("events", []) + ] + + signed_events = yield self._check_sigs_and_hash_and_fetch( + destination, events, outlier=True + ) + + have_gotten_all_from_destination = True + except HttpResponseException as e: + if not e.code == 400: + raise + + # We are probably hitting an old server that doesn't support + # get_missing_events + signed_events = [] + have_gotten_all_from_destination = False + + if len(signed_events) >= limit: + defer.returnValue(signed_events) + + servers = yield self.store.get_joined_hosts_for_room(room_id) + + servers = set(servers) + servers.discard(self.server_name) + + failed_to_fetch = set() + + while len(signed_events) < limit: + # Are we missing any? + + seen_events = set(earliest_events_ids) + seen_events.update(e.event_id for e in signed_events) + + missing_events = {} + for e in itertools.chain(latest_events, signed_events): + if e.depth > min_depth: + missing_events.update({ + e_id: e.depth for e_id, _ in e.prev_events + if e_id not in seen_events + and e_id not in failed_to_fetch + }) + + if not missing_events: + break + + have_seen = yield self.store.have_events(missing_events) + + for k in have_seen: + missing_events.pop(k, None) + + if not missing_events: + break + + # Okay, we haven't gotten everything yet. Lets get them. + ordered_missing = sorted(missing_events.items(), key=lambda x: x[0]) + + if have_gotten_all_from_destination: + servers.discard(destination) + + def random_server_list(): + srvs = list(servers) + random.shuffle(srvs) + return srvs + + deferreds = [ + self.get_pdu( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id, depth in ordered_missing[:limit - len(signed_events)] + ] + + res = yield defer.DeferredList(deferreds, consumeErrors=True) + for (result, val), (e_id, _) in zip(res, ordered_missing): + if result: + signed_events.append(val) + else: + failed_to_fetch.add(e_id) + + defer.returnValue(signed_events) + def event_from_pdu_json(self, pdu_json, outlier=False): event = FrozenEvent( pdu_json diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 22b9663831..9c7dcdba96 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -112,17 +112,20 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) with PreserveLoggingContext(): - dl = [] + results = [] + for pdu in pdu_list: d = self._handle_new_pdu(transaction.origin, pdu) - def handle_failure(failure): - failure.trap(FederationError) - self.send_failure(failure.value, transaction.origin) - - d.addErrback(handle_failure) - - dl.append(d) + try: + yield d + results.append({}) + except FederationError as e: + self.send_failure(e, transaction.origin) + results.append({"error": str(e)}) + except Exception as e: + results.append({"error": str(e)}) + logger.exception("Failed to handle PDU") if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -135,21 +138,11 @@ class FederationServer(FederationBase): for failure in getattr(transaction, "pdu_failures", []): logger.info("Got failure %r", failure) - results = yield defer.DeferredList(dl, consumeErrors=True) - - ret = [] - for r in results: - if r[0]: - ret.append({}) - else: - logger.exception(r[1]) - ret.append({"error": str(r[1].value)}) - - logger.debug("Returning: %s", str(ret)) + logger.debug("Returning: %s", str(results)) response = { "pdus": dict(zip( - (p.event_id for p in pdu_list), ret + (p.event_id for p in pdu_list), results )), } @@ -305,6 +298,20 @@ class FederationServer(FederationBase): (200, send_content) ) + @defer.inlineCallbacks + @log_function + def on_get_missing_events(self, origin, room_id, earliest_events, + latest_events, limit, min_depth): + missing_events = yield self.handler.on_get_missing_events( + origin, room_id, earliest_events, latest_events, limit, min_depth + ) + + time_now = self._clock.time_msec() + + defer.returnValue({ + "events": [ev.get_pdu_json(time_now) for ev in missing_events], + }) + @log_function def _get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. @@ -331,7 +338,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, origin, pdu, max_recursion=10): + def _handle_new_pdu(self, origin, pdu, get_missing=True): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu( origin, pdu.event_id, do_auth=False @@ -383,48 +390,54 @@ class FederationServer(FederationBase): pdu.room_id, min_depth ) + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this # message, to work around the fact that some events will # reference really really old events we really don't want to # send to the clients. pdu.internal_metadata.outlier = True - elif min_depth and pdu.depth > min_depth and max_recursion > 0: - for event_id, hashes in pdu.prev_events: - if event_id not in have_seen: - logger.debug( - "_handle_new_pdu requesting pdu %s", - event_id + elif min_depth and pdu.depth > min_depth: + if get_missing and prevs - seen: + latest_tuples = yield self.store.get_latest_events_in_room( + pdu.room_id + ) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(e_id for e_id, _, _ in latest_tuples) + latest |= seen + + missing_events = yield self.get_missing_events( + origin, + pdu.room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + ) + + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + + for e in missing_events: + yield self._handle_new_pdu( + origin, + e, + get_missing=False ) - try: - new_pdu = yield self.federation_client.get_pdu( - [origin, pdu.origin], - event_id=event_id, - ) - - if new_pdu: - yield self._handle_new_pdu( - origin, - new_pdu, - max_recursion=max_recursion-1 - ) - - logger.debug("Processed pdu %s", event_id) - else: - logger.warn("Failed to get PDU %s", event_id) - fetch_state = True - except: - # TODO(erikj): Do some more intelligent retries. - logger.exception("Failed to get PDU") - fetch_state = True - else: - prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) - if prevs - seen: - fetch_state = True - else: - fetch_state = True + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) + + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if prevs - seen: + fetch_state = True if fetch_state: # We need to get the state at this event, since we haven't diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7d30c924d1..741a4e7a1a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -224,6 +224,8 @@ class TransactionQueue(object): ] try: + self.pending_transactions[destination] = 1 + limiter = yield get_retry_limiter( destination, self._clock, @@ -239,8 +241,6 @@ class TransactionQueue(object): len(pending_failures) ) - self.pending_transactions[destination] = 1 - logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( @@ -287,7 +287,7 @@ class TransactionQueue(object): code = 200 if response: - for e_id, r in getattr(response, "pdus", {}).items(): + for e_id, r in response.get("pdus", {}).items(): if "error" in r: logger.warn( "Transaction returned error for %s: %s", diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index 6800ac46c5..2a671b9aec 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol. from .server import TransportLayerServer from .client import TransportLayerClient +from synapse.util.ratelimitutils import FederationRateLimiter + class TransportLayer(TransportLayerServer, TransportLayerClient): """This is a basic implementation of the transport layer that translates @@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient): send requests """ self.keyring = homeserver.get_keyring() + self.clock = homeserver.get_clock() self.server_name = server_name self.server = server self.client = client self.request_handler = None self.received_handler = None + + self.ratelimiter = FederationRateLimiter( + self.clock, + window_size=homeserver.config.federation_rc_window_size, + sleep_limit=homeserver.config.federation_rc_sleep_limit, + sleep_msec=homeserver.config.federation_rc_sleep_delay, + reject_limit=homeserver.config.federation_rc_reject_limit, + concurrent_requests=homeserver.config.federation_rc_concurrent, + ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8b137e7128..80d03012b7 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -219,3 +219,22 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def get_missing_events(self, destination, room_id, earliest_events, + latest_events, limit, min_depth): + path = PREFIX + "/get_missing_events/%s" % (room_id,) + + content = yield self.client.post_json( + destination=destination, + path=path, + data={ + "limit": int(limit), + "min_depth": int(min_depth), + "earliest_events": earliest_events, + "latest_events": latest_events, + } + ) + + defer.returnValue(content) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2ffb37aa18..ece6dbcf62 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -98,15 +98,23 @@ class TransportLayerServer(object): def new_handler(request, *args, **kwargs): try: (origin, content) = yield self._authenticate_request(request) - response = yield handler( - origin, content, request.args, *args, **kwargs - ) + with self.ratelimiter.ratelimit(origin) as d: + yield d + response = yield handler( + origin, content, request.args, *args, **kwargs + ) except: logger.exception("_authenticate_request failed") raise defer.returnValue(response) return new_handler + def rate_limit_origin(self, handler): + def new_handler(origin, *args, **kwargs): + response = yield handler(origin, *args, **kwargs) + defer.returnValue(response) + return new_handler() + @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -234,6 +242,7 @@ class TransportLayerServer(object): ) ) ) + self.server.register_path( "POST", re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), @@ -245,6 +254,17 @@ class TransportLayerServer(object): ) ) + self.server.register_path( + "POST", + re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"), + self._with_authentication( + lambda origin, content, query, room_id: + self._get_missing_events( + origin, content, room_id, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -344,3 +364,22 @@ class TransportLayerServer(object): ) defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + @log_function + def _get_missing_events(self, origin, content, room_id): + limit = int(content.get("limit", 10)) + min_depth = int(content.get("min_depth", 0)) + earliest_events = content.get("earliest_events", []) + latest_events = content.get("latest_events", []) + + content = yield self.request_handler.on_get_missing_events( + origin, + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + min_depth=min_depth, + limit=limit, + ) + + defer.returnValue((200, content)) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 20ab9e269c..f76febee8f 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -160,7 +160,7 @@ class DirectoryHandler(BaseHandler): if not room_id: raise SynapseError( 404, - "Room alias %r not found" % (room_alias.to_string(),), + "Room alias %s not found" % (room_alias.to_string(),), Codes.NOT_FOUND ) @@ -232,13 +232,23 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): + # Any application service "interested" in an alias they are regexing on + # can modify the alias. + # Users can only modify the alias if ALL the interested services have + # non-exclusive locks on the alias (or there are no interested services) services = yield self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] + for service in interested_services: if user_id == service.sender: - # this user IS the app service + # this user IS the app service so they can do whatever they like defer.returnValue(True) return - defer.returnValue(len(interested_services) == 0) + elif service.is_exclusive_alias(alias.to_string()): + # another service has an exclusive lock on this alias. + defer.returnValue(False) + return + # either no interested services, or no service with an exclusive lock + defer.returnValue(True) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 025e7e7e62..d3297b7292 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -23,6 +23,7 @@ from synapse.events.utils import serialize_event from ._base import BaseHandler import logging +import random logger = logging.getLogger(__name__) @@ -69,12 +70,17 @@ class EventStreamHandler(BaseHandler): ) self._streams_per_user[auth_user] += 1 - if pagin_config.from_token is None: - pagin_config.from_token = None - rm_handler = self.hs.get_handlers().room_member_handler room_ids = yield rm_handler.get_rooms_for_user(auth_user) + if timeout: + # If they've set a timeout set a minimum limit. + timeout = max(timeout, 500) + + # Add some randomness to this value to try and mitigate against + # thundering herds on restart. + timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7deed16f9c..ae4e9b316d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -581,12 +581,13 @@ class FederationHandler(BaseHandler): defer.returnValue(event) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id): + def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): yield run_on_reactor() - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") + if do_auth: + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") state_groups = yield self.store.get_state_groups( [event_id] @@ -789,6 +790,29 @@ class FederationHandler(BaseHandler): defer.returnValue(ret) @defer.inlineCallbacks + def on_get_missing_events(self, origin, room_id, earliest_events, + latest_events, limit, min_depth): + in_room = yield self.auth.check_host_in_room( + room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + + limit = min(limit, 20) + min_depth = max(min_depth, 0) + + missing_events = yield self.store.get_missing_events( + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + limit=limit, + min_depth=min_depth, + ) + + defer.returnValue(missing_events) + + @defer.inlineCallbacks @log_function def do_auth(self, origin, event, context, auth_events): # Check if we have all the auth events. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 03b2159c53..2ddf9d5378 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -212,10 +212,16 @@ class ProfileHandler(BaseHandler): ) msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_event({ - "type": EventTypes.Member, - "room_id": j.room_id, - "state_key": user.to_string(), - "content": content, - "sender": user.to_string() - }, ratelimit=False) + try: + yield msg_handler.create_and_send_event({ + "type": EventTypes.Member, + "room_id": j.room_id, + "state_key": user.to_string(), + "content": content, + "sender": user.to_string() + }, ratelimit=False) + except Exception as e: + logger.warn( + "Failed to update join event for room %s - %s", + j.room_id, str(e.message) + ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 516a936cee..cda4a8502a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -201,11 +201,12 @@ class RegistrationHandler(BaseHandler): interested_services = [ s for s in services if s.is_interested_in_user(user_id) ] - if len(interested_services) > 0: - raise SynapseError( - 400, "This user ID is reserved by an application service.", - errcode=Codes.EXCLUSIVE - ) + for service in interested_services: + if service.is_exclusive_user(user_id): + raise SynapseError( + 400, "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) def _generate_token(self, user_id): # urlsafe variant uses _ and - so use . as the separator and replace diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 914742d913..80f7ee3f12 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -510,9 +510,16 @@ class RoomMemberHandler(BaseHandler): def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): """Returns a list of roomids that the user has any of the given membership states in.""" - rooms = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user.to_string(), membership_list=membership_list + + app_service = yield self.store.get_app_service_by_user_id( + user.to_string() ) + if app_service: + rooms = yield self.store.get_app_service_rooms(app_service) + else: + rooms = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user.to_string(), membership_list=membership_list + ) # For some reason the list of events contains duplicates # TODO(paul): work out why because I really don't think it should @@ -559,13 +566,24 @@ class RoomEventSource(object): to_key = yield self.get_current_key() - events, end_key = yield self.store.get_room_events_stream( - user_id=user.to_string(), - from_key=from_key, - to_key=to_key, - room_id=None, - limit=limit, + app_service = yield self.store.get_app_service_by_user_id( + user.to_string() ) + if app_service: + events, end_key = yield self.store.get_appservice_room_stream( + service=app_service, + from_key=from_key, + to_key=to_key, + limit=limit, + ) + else: + events, end_key = yield self.store.get_room_events_stream( + user_id=user.to_string(), + from_key=from_key, + to_key=to_key, + room_id=None, + limit=limit, + ) defer.returnValue((events, end_key)) diff --git a/synapse/http/client.py b/synapse/http/client.py index 22b7145ac1..b53a07aa2d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -143,7 +143,7 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) - json_str = json.dumps(json_body) + json_str = encode_canonical_json(json_body) response = yield self.agent.request( "PUT", diff --git a/synapse/http/server.py b/synapse/http/server.py index 74a101a5d7..767c3ef79b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -124,27 +124,29 @@ class JsonResource(HttpServer, resource.Resource): # and path regex match for path_entry in self.path_regexs.get(request.method, []): m = path_entry.pattern.match(request.path) - if m: - # We found a match! Trigger callback and then return the - # returned response. We pass both the request and any - # matched groups from the regex to the callback. - - args = [ - urllib.unquote(u).decode("UTF-8") for u in m.groups() - ] - - logger.info( - "Received request: %s %s", - request.method, request.path - ) - - code, response = yield path_entry.callback( - request, - *args - ) - - self._send_response(request, code, response) - return + if not m: + continue + + # We found a match! Trigger callback and then return the + # returned response. We pass both the request and any + # matched groups from the regex to the callback. + + args = [ + urllib.unquote(u).decode("UTF-8") for u in m.groups() + ] + + logger.info( + "Received request: %s %s", + request.method, request.path + ) + + code, response = yield path_entry.callback( + request, + *args + ) + + self._send_response(request, code, response) + return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. raise UnrecognizedRequestError() diff --git a/synapse/notifier.py b/synapse/notifier.py index 2475f3ffbe..df13e8ddb6 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -36,8 +36,10 @@ class _NotificationListener(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user, rooms, from_token, limit, timeout, deferred): + def __init__(self, user, rooms, from_token, limit, timeout, deferred, + appservice=None): self.user = user + self.appservice = appservice self.from_token = from_token self.limit = limit self.timeout = timeout @@ -61,10 +63,14 @@ class _NotificationListener(object): pass for room in self.rooms: - lst = notifier.rooms_to_listeners.get(room, set()) + lst = notifier.room_to_listeners.get(room, set()) lst.discard(self) notifier.user_to_listeners.get(self.user, set()).discard(self) + if self.appservice: + notifier.appservice_to_listeners.get( + self.appservice, set() + ).discard(self) class Notifier(object): @@ -77,8 +83,9 @@ class Notifier(object): def __init__(self, hs): self.hs = hs - self.rooms_to_listeners = {} + self.room_to_listeners = {} self.user_to_listeners = {} + self.appservice_to_listeners = {} self.event_sources = hs.get_event_sources() @@ -109,11 +116,22 @@ class Notifier(object): room_source = self.event_sources.sources["room"] - listeners = self.rooms_to_listeners.get(room_id, set()).copy() + listeners = self.room_to_listeners.get(room_id, set()).copy() for user in extra_users: listeners |= self.user_to_listeners.get(user, set()).copy() + for appservice in self.appservice_to_listeners: + # TODO (kegan): Redundant appservice listener checks? + # App services will already be in the room_to_listeners set, but + # that isn't enough. They need to be checked here in order to + # receive *invites* for users they are interested in. Does this + # make the room_to_listeners check somewhat obselete? + if appservice.is_interested(event): + listeners |= self.appservice_to_listeners.get( + appservice, set() + ).copy() + logger.debug("on_new_room_event listeners %s", listeners) # TODO (erikj): Can we make this more efficient by hitting the @@ -166,7 +184,7 @@ class Notifier(object): listeners |= self.user_to_listeners.get(user, set()).copy() for room in rooms: - listeners |= self.rooms_to_listeners.get(room, set()).copy() + listeners |= self.room_to_listeners.get(room, set()).copy() @defer.inlineCallbacks def notify(listener): @@ -280,6 +298,10 @@ class Notifier(object): if not from_token: from_token = yield self.event_sources.get_current_token() + appservice = yield self.hs.get_datastore().get_app_service_by_user_id( + user.to_string() + ) + listener = _NotificationListener( user, rooms, @@ -287,6 +309,7 @@ class Notifier(object): limit, timeout, deferred, + appservice=appservice ) def _timeout_listener(): @@ -314,11 +337,16 @@ class Notifier(object): @log_function def _register_with_keys(self, listener): for room in listener.rooms: - s = self.rooms_to_listeners.setdefault(room, set()) + s = self.room_to_listeners.setdefault(room, set()) s.add(listener) self.user_to_listeners.setdefault(listener.user, set()).add(listener) + if listener.appservice: + self.appservice_to_listeners.setdefault( + listener.appservice, set() + ).add(listener) + @defer.inlineCallbacks @log_function def _check_for_updates(self, listener): @@ -352,5 +380,5 @@ class Notifier(object): def _user_joined_room(self, user, room_id): new_listeners = self.user_to_listeners.get(user, set()) - listeners = self.rooms_to_listeners.setdefault(room_id, set()) + listeners = self.room_to_listeners.setdefault(room_id, set()) listeners |= new_listeners diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 0fb3e4f7f3..3da0ce8703 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -32,7 +32,7 @@ class Pusher(object): INITIAL_BACKOFF = 1000 MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - DEFAULT_ACTIONS = ['notify'] + DEFAULT_ACTIONS = ['dont-notify'] INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @@ -72,16 +72,14 @@ class Pusher(object): # let's assume you probably know about messages you sent yourself defer.returnValue(['dont_notify']) - if ev['type'] == 'm.room.member': - if ev['state_key'] != self.user_name: - defer.returnValue(['dont_notify']) - - rawrules = yield self.store.get_push_rules_for_user_name(self.user_name) + rawrules = yield self.store.get_push_rules_for_user(self.user_name) for r in rawrules: r['conditions'] = json.loads(r['conditions']) r['actions'] = json.loads(r['actions']) + enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name) + user = UserID.from_string(self.user_name) rules = baserules.list_with_base_rules(rawrules, user) @@ -107,6 +105,8 @@ class Pusher(object): room_member_count += 1 for r in rules: + if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]: + continue matches = True conditions = r['conditions'] @@ -117,7 +117,11 @@ class Pusher(object): ev, c, display_name=my_display_name, room_member_count=room_member_count ) - # ignore rules with no actions (we have an explict 'dont_notify' + logger.debug( + "Rule %s %s", + r['rule_id'], "matches" if matches else "doesn't match" + ) + # ignore rules with no actions (we have an explict 'dont_notify') if len(actions) == 0: logger.warn( "Ignoring rule id %s with no actions for user %s" % diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 162d265f66..6e333a3d21 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -32,12 +32,14 @@ def make_base_rules(user, kind): if kind == 'override': rules = make_base_override_rules() + elif kind == 'underride': + rules = make_base_underride_rules(user) elif kind == 'content': rules = make_base_content_rules(user) for r in rules: r['priority_class'] = PRIORITY_CLASS_MAP[kind] - r['default'] = True + r['default'] = True # Deprecated, left for backwards compat return rules @@ -45,6 +47,7 @@ def make_base_rules(user, kind): def make_base_content_rules(user): return [ { + 'rule_id': 'global/content/.m.rule.contains_user_name', 'conditions': [ { 'kind': 'event_match', @@ -57,6 +60,8 @@ def make_base_content_rules(user): { 'set_tweak': 'sound', 'value': 'default', + }, { + 'set_tweak': 'highlight' } ] }, @@ -66,6 +71,40 @@ def make_base_content_rules(user): def make_base_override_rules(): return [ { + 'rule_id': 'global/override/.m.rule.call', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.call.invite', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'ring' + }, { + 'set_tweak': 'highlight', + 'value': 'false' + } + ] + }, + { + 'rule_id': 'global/override/.m.rule.suppress_notices', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.msgtype', + 'pattern': 'm.notice', + } + ], + 'actions': [ + 'dont_notify', + ] + }, + { + 'rule_id': 'global/override/.m.rule.contains_display_name', 'conditions': [ { 'kind': 'contains_display_name' @@ -76,10 +115,13 @@ def make_base_override_rules(): { 'set_tweak': 'sound', 'value': 'default' + }, { + 'set_tweak': 'highlight' } ] }, { + 'rule_id': 'global/override/.m.rule.room_one_to_one', 'conditions': [ { 'kind': 'room_member_count', @@ -91,6 +133,76 @@ def make_base_override_rules(): { 'set_tweak': 'sound', 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': 'false' + } + ] + } + ] + + +def make_base_underride_rules(user): + return [ + { + 'rule_id': 'global/underride/.m.rule.invite_for_me', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + }, + { + 'kind': 'event_match', + 'key': 'content.membership', + 'pattern': 'invite', + }, + { + 'kind': 'event_match', + 'key': 'state_key', + 'pattern': user.to_string(), + }, + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': 'false' + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.member_event', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': 'false' + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.message', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.message', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': 'false' } ] } diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 202fcb42f4..a02fed57b4 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -88,6 +88,7 @@ class HttpPusher(Pusher): } if event['type'] == 'm.room.member': d['notification']['membership'] = event['content']['membership'] + d['notification']['user_is_target'] = event['state_key'] == self.user_name if 'content' in event: d['notification']['content'] = event['content'] @@ -108,7 +109,7 @@ class HttpPusher(Pusher): try: resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) except: - logger.exception("Failed to push %s ", self.url) + logger.warn("Failed to push %s ", self.url) defer.returnValue(False) rejected = [] if 'rejected' in resp: diff --git a/synapse/rest/appservice/v1/register.py b/synapse/rest/appservice/v1/register.py index 3bd0c1220c..a4f6159773 100644 --- a/synapse/rest/appservice/v1/register.py +++ b/synapse/rest/appservice/v1/register.py @@ -48,18 +48,12 @@ class RegisterRestServlet(AppServiceRestServlet): 400, "Missed required keys: as_token(str) / url(str)." ) - namespaces = { - "users": [], - "rooms": [], - "aliases": [] - } - - if "namespaces" in params: - self._parse_namespace(namespaces, params["namespaces"], "users") - self._parse_namespace(namespaces, params["namespaces"], "rooms") - self._parse_namespace(namespaces, params["namespaces"], "aliases") - - app_service = ApplicationService(as_token, as_url, namespaces) + try: + app_service = ApplicationService( + as_token, as_url, params["namespaces"] + ) + except ValueError as e: + raise SynapseError(400, e.message) app_service = yield self.handler.register(app_service) hs_token = app_service.hs_token @@ -68,23 +62,6 @@ class RegisterRestServlet(AppServiceRestServlet): "hs_token": hs_token })) - def _parse_namespace(self, target_ns, origin_ns, ns): - if ns not in target_ns or ns not in origin_ns: - return # nothing to parse / map through to. - - possible_regex_list = origin_ns[ns] - if not type(possible_regex_list) == list: - raise SynapseError(400, "Namespace %s isn't an array." % ns) - - for regex in possible_regex_list: - if not isinstance(regex, basestring): - raise SynapseError( - 400, "Regex '%s' isn't a string in namespace %s" % - (regex, ns) - ) - - target_ns[ns] = origin_ns[ns] - class UnregisterRestServlet(AppServiceRestServlet): """Handles AS registration with the home server. diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b012f31084..fef0eb6572 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -50,6 +50,10 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + if 'attr' in spec: + self.set_rule_attr(user.to_string(), spec, content) + defer.returnValue((200, {})) + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], @@ -110,7 +114,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name( + rawrules = yield self.hs.get_datastore().get_push_rules_for_user( user.to_string() ) @@ -124,6 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) + enabled_map = yield self.hs.get_datastore().\ + get_push_rules_enabled_for_user(user.to_string()) + for r in ruleslist: rulearray = None @@ -149,6 +156,9 @@ class PushRuleRestServlet(ClientV1RestServlet): template_rule = _rule_to_template(r) if template_rule: + template_rule['enabled'] = True + if r['rule_id'] in enabled_map: + template_rule['enabled'] = enabled_map[r['rule_id']] rulearray.append(template_rule) path = request.postpath[1:] @@ -189,6 +199,25 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def set_rule_attr(self, user_name, spec, val): + if spec['attr'] == 'enabled': + if not isinstance(val, bool): + raise SynapseError(400, "Value for 'enabled' must be boolean") + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + self.hs.get_datastore().set_push_rule_enabled( + user_name, namespaced_rule_id, val + ) + else: + raise UnrecognizedRequestError() + + def get_rule_attr(self, user_name, namespaced_rule_id, attr): + if attr == 'enabled': + return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( + user_name, namespaced_rule_id + ) + else: + raise UnrecognizedRequestError() + def _rule_spec_from_path(path): if len(path) < 2: @@ -214,7 +243,7 @@ def _rule_spec_from_path(path): template = path[0] path = path[1:] - if len(path) == 0: + if len(path) == 0 or len(path[0]) == 0: raise UnrecognizedRequestError() rule_id = path[0] @@ -226,6 +255,12 @@ def _rule_spec_from_path(path): } if device: spec['profile_tag'] = device + + path = path[1:] + + if len(path) > 0 and len(path[0]) > 0: + spec['attr'] = path[0] + return spec @@ -275,7 +310,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass - elif isinstance(a, dict) and 'set_sound' in a: + elif isinstance(a, dict) and 'set_tweak' in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -319,10 +354,23 @@ def _filter_ruleset_with_path(ruleset, path): if path[0] == '': return ruleset[template_kind] rule_id = path[0] + + the_rule = None for r in ruleset[template_kind]: if r['rule_id'] == rule_id: - return r - raise NotFoundError + the_rule = r + if the_rule is None: + raise NotFoundError + + path = path[1:] + if len(path) == 0: + return the_rule + + attr = path[0] + if attr in the_rule: + return the_rule[attr] + else: + raise UnrecognizedRequestError() def _priority_class_from_spec(spec): @@ -339,7 +387,7 @@ def _priority_class_from_spec(spec): def _priority_class_to_template_name(pc): if pc > PRIORITY_CLASS_MAP['override']: # per-device - prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + prio_class_index = pc - len(PRIORITY_CLASS_MAP) return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] else: return PRIORITY_CLASS_INVERSE_MAP[pc] @@ -399,9 +447,6 @@ class InvalidRuleException(Exception): def _parse_json(request): try: content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) return content except ValueError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8d2115082b..f5acfb945f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -59,6 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet): # } # TODO: persistent storage self.sessions = {} + self.disable_registration = hs.config.disable_registration def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -107,6 +108,11 @@ class RegisterRestServlet(ClientV1RestServlet): try: login_type = register_json["type"] + + is_application_server = login_type == LoginType.APPLICATION_SERVICE + if self.disable_registration and not is_application_server: + raise SynapseError(403, "Registration has been disabled") + stages = { LoginType.RECAPTCHA: self._do_recaptcha, LoginType.PASSWORD: self._do_password, diff --git a/synapse/server.py b/synapse/server.py index ba2b2593f1..cb8610a1b4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -73,6 +73,7 @@ class BaseHomeServer(object): 'resource_for_client', 'resource_for_client_v2_alpha', 'resource_for_federation', + 'resource_for_static_content', 'resource_for_web_client', 'resource_for_content_repo', 'resource_for_server_key', diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d16e7b8fac..a3ff995695 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -45,36 +45,19 @@ from syutil.jsonutil import encode_canonical_json from synapse.crypto.event_signing import compute_event_reference_hash +import fnmatch +import imp import logging import os +import re logger = logging.getLogger(__name__) -SCHEMAS = [ - "transactions", - "users", - "profiles", - "presence", - "im", - "room_aliases", - "keys", - "redactions", - "state", - "event_edges", - "event_signatures", - "pusher", - "media_repository", - "application_services", - "filtering", - "rejections", -] - - -# Remember to update this number every time an incompatible change is made to -# database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 13 +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 14 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -576,28 +559,15 @@ class DataStore(RoomMemberStore, RoomStore, ) -def schema_path(schema): - """ Get a filesystem path for the named database schema - - Args: - schema: Name of the database schema. - Returns: - A filesystem path pointing at a ".sql" file. - - """ - schemaPath = os.path.join(dir_path, "schema", schema + ".sql") - return schemaPath - - -def read_schema(schema): +def read_schema(path): """ Read the named database schema. Args: - schema: Name of the datbase schema. + path: Path of the database schema. Returns: A string containing the database schema. """ - with open(schema_path(schema)) as schema_file: + with open(path) as schema_file: return schema_file.read() @@ -610,49 +580,275 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn): - """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we - don't have to worry about overwriting existing content. + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database(cur, user_version, delta_files, upgraded) + else: + _setup_new_database(cur) + + cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. """ - c = db_conn.cursor() - c.execute("PRAGMA user_version") - row = c.fetchone() + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) - if row and row[0]: - user_version = row[0] + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) - if user_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - elif user_version < SCHEMA_VERSION: - logger.info( - "Upgrading database from version %d", - user_version + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + sql_script = "BEGIN TRANSACTION;\n" + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + sql_script += read_schema(sql_loc) + sql_script += "\n" + sql_script += "COMMIT TRANSACTION;" + cur.executescript(sql_script) + + cur.execute( + "INSERT OR REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (max_current_ver, False) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) ) - # Run every version since after the current version. - for v in range(user_version + 1, SCHEMA_VERSION + 1): - if v == 10: - raise UpgradeDatabaseException( - "No delta for version 10" + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file ) - sql_script = read_schema("delta/v%d" % (v)) - c.executescript(sql_script) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur) + elif ext == ".sql": + # A plain old .sql file, just read and execute it + delta_schema = read_schema(absolute_path) + logger.debug("Applying schema %s", relative_path) + cur.executescript(delta_schema) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + (v, relative_path) + ) + + cur.execute( + "INSERT OR REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (v, True) + ) - db_conn.commit() - else: - logger.info("Database is at version %r", user_version) - - else: - sql_script = "BEGIN TRANSACTION;\n" - for sql_loc in SCHEMAS: - logger.debug("Applying schema %r", sql_loc) - sql_script += read_schema(sql_loc) - sql_script += "\n" - sql_script += "COMMIT TRANSACTION;" - c.executescript(sql_script) - db_conn.commit() - c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) - c.close() +def _get_or_create_schema_state(txn): + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + txn.executescript(create_schema) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + "SELECT file FROM applied_schema_deltas WHERE version >= ?", + (current_version,) + ) + return current_version, txn.fetchall(), upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "INSERT OR REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index be9934c66f..3725c9795d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache from twisted.internet import defer -import collections +from collections import namedtuple, OrderedDict import simplejson as json import sys import time @@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") +# TODO(paul): +# * more generic key management +# * export monitoring stats +# * consider other eviction strategies - LRU? +def cached(max_entries=1000): + """ A method decorator that applies a memoizing cache around the function. + + The function is presumed to take one additional argument, which is used as + the key for the cache. Cache hits are served directly from the cache; + misses use the function body to generate the value. + + The wrapped function has an additional member, a callable called + "invalidate". This can be used to remove individual entries from the cache. + + The wrapped function has another additional callable, called "prefill", + which can be used to insert values into the cache specifically, without + calling the calculation function. + """ + def wrap(orig): + cache = OrderedDict() + + def prefill(key, value): + while len(cache) > max_entries: + cache.popitem(last=False) + + cache[key] = value + + @defer.inlineCallbacks + def wrapped(self, key): + if key in cache: + defer.returnValue(cache[key]) + + ret = yield orig(self, key) + prefill(key, ret) + defer.returnValue(ret) + + def invalidate(key): + cache.pop(key, None) + + wrapped.invalidate = invalidate + wrapped.prefill = prefill + return wrapped + + return wrap + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" @@ -404,7 +450,8 @@ class SQLBaseStore(object): Args: table : string giving the table name - keyvalues : dict of column names and values to select the rows with + keyvalues : dict of column names and values to select the rows with, + or None to not apply a WHERE clause. retcols : list of strings giving the names of the columns to return """ return self.runInteraction( @@ -423,13 +470,20 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) - ) + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + txn.execute(sql, keyvalues.values()) + else: + sql = "SELECT %s FROM %s ORDER BY rowid asc" % ( + ", ".join(retcols), + table + ) + txn.execute(sql) - txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn) def _simple_update_one(self, table, keyvalues, updatevalues, @@ -586,8 +640,9 @@ class SQLBaseStore(object): start_time = time.time() * 1000 update_counter = self._get_event_counters.update + cache = self._get_event_cache.setdefault(event_id, {}) + try: - cache = self._get_event_cache.setdefault(event_id, {}) # Separate cache entries for each way to invoke _get_event_txn return cache[(check_redacted, get_prev_content, allow_rejected)] except KeyError: @@ -786,7 +841,7 @@ class JoinHelper(object): for table in self.tables: res += [f for f in table.fields if f not in res] - self.EntryType = collections.namedtuple("JoinHelperEntry", res) + self.EntryType = namedtuple("JoinHelperEntry", res) def get_fields(self, **prefixes): """Get a string representing a list of fields for use in SELECT diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index d941b1f387..e30265750a 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -13,34 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import simplejson +from simplejson import JSONDecodeError from twisted.internet import defer +from synapse.api.constants import Membership from synapse.api.errors import StoreError from synapse.appservice import ApplicationService +from synapse.storage.roommember import RoomsForUser from ._base import SQLBaseStore logger = logging.getLogger(__name__) -class ApplicationServiceCache(object): - """Caches ApplicationServices and provides utility functions on top. - - This class is designed to be invoked on incoming events in order to avoid - hammering the database every time to extract a list of application service - regexes. - """ - - def __init__(self): - self.services = [] +def log_failure(failure): + logger.error("Failed to detect application services: %s", failure.value) + logger.error(failure.getTraceback()) class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) - self.cache = ApplicationServiceCache() + self.services_cache = [] self.cache_defer = self._populate_cache() + self.cache_defer.addErrback(log_failure) @defer.inlineCallbacks def unregister_app_service(self, token): @@ -56,7 +54,7 @@ class ApplicationServiceStore(SQLBaseStore): token, ) # update cache TODO: Should this be in the txn? - for service in self.cache.services: + for service in self.services_cache: if service.token == token: service.url = None service.namespaces = None @@ -110,13 +108,13 @@ class ApplicationServiceStore(SQLBaseStore): ) # update cache TODO: Should this be in the txn? - for (index, cache_service) in enumerate(self.cache.services): + for (index, cache_service) in enumerate(self.services_cache): if service.token == cache_service.token: - self.cache.services[index] = service + self.services_cache[index] = service logger.info("Updated: %s", service) return # new entry - self.cache.services.append(service) + self.services_cache.append(service) logger.info("Updated(new): %s", service) def _update_app_service_txn(self, txn, service): @@ -140,11 +138,11 @@ class ApplicationServiceStore(SQLBaseStore): ) for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST): if ns_str in service.namespaces: - for regex in service.namespaces[ns_str]: + for regex_obj in service.namespaces[ns_str]: txn.execute( "INSERT INTO application_services_regex(" "as_id, namespace, regex) values(?,?,?)", - (as_id, ns_int, regex) + (as_id, ns_int, simplejson.dumps(regex_obj)) ) return True @@ -160,11 +158,34 @@ class ApplicationServiceStore(SQLBaseStore): @defer.inlineCallbacks def get_app_services(self): yield self.cache_defer # make sure the cache is ready - defer.returnValue(self.cache.services) + defer.returnValue(self.services_cache) + + @defer.inlineCallbacks + def get_app_service_by_user_id(self, user_id): + """Retrieve an application service from their user ID. + + All application services have associated with them a particular user ID. + There is no distinguishing feature on the user ID which indicates it + represents an application service. This function allows you to map from + a user ID to an application service. + + Args: + user_id(str): The user ID to see if it is an application service. + Returns: + synapse.appservice.ApplicationService or None. + """ + + yield self.cache_defer # make sure the cache is ready + + for service in self.services_cache: + if service.sender == user_id: + defer.returnValue(service) + return + defer.returnValue(None) @defer.inlineCallbacks def get_app_service_by_token(self, token, from_cache=True): - """Get the application service with the given token. + """Get the application service with the given appservice token. Args: token (str): The application service token. @@ -176,7 +197,7 @@ class ApplicationServiceStore(SQLBaseStore): yield self.cache_defer # make sure the cache is ready if from_cache: - for service in self.cache.services: + for service in self.services_cache: if service.token == token: defer.returnValue(service) return @@ -185,6 +206,77 @@ class ApplicationServiceStore(SQLBaseStore): # TODO: The from_cache=False impl # TODO: This should be JOINed with the application_services_regex table. + def get_app_service_rooms(self, service): + """Get a list of RoomsForUser for this application service. + + Application services may be "interested" in lots of rooms depending on + the room ID, the room aliases, or the members in the room. This function + takes all of these into account and returns a list of RoomsForUser which + represent the entire list of room IDs that this application service + wants to know about. + + Args: + service: The application service to get a room list for. + Returns: + A list of RoomsForUser. + """ + return self.runInteraction( + "get_app_service_rooms", + self._get_app_service_rooms_txn, + service, + ) + + def _get_app_service_rooms_txn(self, txn, service): + # get all rooms matching the room ID regex. + room_entries = self._simple_select_list_txn( + txn=txn, table="rooms", keyvalues=None, retcols=["room_id"] + ) + matching_room_list = set([ + r["room_id"] for r in room_entries if + service.is_interested_in_room(r["room_id"]) + ]) + + # resolve room IDs for matching room alias regex. + room_alias_mappings = self._simple_select_list_txn( + txn=txn, table="room_aliases", keyvalues=None, + retcols=["room_id", "room_alias"] + ) + matching_room_list |= set([ + r["room_id"] for r in room_alias_mappings if + service.is_interested_in_alias(r["room_alias"]) + ]) + + # get all rooms for every user for this AS. This is scoped to users on + # this HS only. + user_list = self._simple_select_list_txn( + txn=txn, table="users", keyvalues=None, retcols=["name"] + ) + user_list = [ + u["name"] for u in user_list if + service.is_interested_in_user(u["name"]) + ] + rooms_for_user_matching_user_id = set() # RoomsForUser list + for user_id in user_list: + # FIXME: This assumes this store is linked with RoomMemberStore :( + rooms_for_user = self._get_rooms_for_user_where_membership_is_txn( + txn=txn, + user_id=user_id, + membership_list=[Membership.JOIN] + ) + rooms_for_user_matching_user_id |= set(rooms_for_user) + + # make RoomsForUser tuples for room ids and aliases which are not in the + # main rooms_for_user_list - e.g. they are rooms which do not have AS + # registered users in it. + known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id] + missing_rooms_for_user = [ + RoomsForUser(r, service.sender, "join") for r in + matching_room_list if r not in known_room_ids + ] + rooms_for_user_matching_user_id |= set(missing_rooms_for_user) + + return rooms_for_user_matching_user_id + @defer.inlineCallbacks def _populate_cache(self): """Populates the ApplicationServiceCache from the database.""" @@ -227,15 +319,17 @@ class ApplicationServiceStore(SQLBaseStore): try: services[as_token]["namespaces"][ ApplicationService.NS_LIST[ns_int]].append( - res["regex"] + simplejson.loads(res["regex"]) ) except IndexError: logger.error("Bad namespace enum '%s'. %s", ns_int, res) + except JSONDecodeError: + logger.error("Bad regex object '%s'", res["regex"]) # TODO get last successful txn id f.e. service for service in services.values(): logger.info("Found application service: %s", service) - self.cache.services.append(ApplicationService( + self.services_cache.append(ApplicationService( token=service["token"], url=service["url"], namespaces=service["namespaces"], diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 3fbc090224..2deda8ac50 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -64,6 +64,9 @@ class EventFederationStore(SQLBaseStore): for f in front: txn.execute(base_sql, (f,)) new_front.update([r[0] for r in txn.fetchall()]) + + new_front -= results + front = new_front results.update(front) @@ -378,3 +381,51 @@ class EventFederationStore(SQLBaseStore): event_results += new_front return self._get_events_txn(txn, event_results) + + def get_missing_events(self, room_id, earliest_events, latest_events, + limit, min_depth): + return self.runInteraction( + "get_missing_events", + self._get_missing_events, + room_id, earliest_events, latest_events, limit, min_depth + ) + + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, + limit, min_depth): + + earliest_events = set(earliest_events) + front = set(latest_events) - earliest_events + + event_results = set() + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? AND is_state = 0 " + "LIMIT ?" + ) + + while front and len(event_results) < limit: + new_front = set() + for event_id in front: + txn.execute( + query, + (room_id, event_id, limit - len(event_results)) + ) + + for e_id, in txn.fetchall(): + new_front.add(e_id) + + new_front -= earliest_events + new_front -= event_results + + front = new_front + event_results |= new_front + + events = self._get_events_txn(txn, event_results) + + events = sorted( + [ev for ev in events if ev.depth >= min_depth], + key=lambda e: e.depth, + ) + + return events[:limit] diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ae46b39cc1..bbf322cc84 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): @defer.inlineCallbacks - def get_push_rules_for_user_name(self, user_name): + def get_push_rules_for_user(self, user_name): sql = ( "SELECT "+",".join(PushRuleTable.fields)+" " "FROM "+PushRuleTable.table_name+" " @@ -46,6 +46,28 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(dicts) @defer.inlineCallbacks + def get_push_rules_enabled_for_user(self, user_name): + results = yield self._simple_select_list( + PushRuleEnableTable.table_name, + {'user_name': user_name}, + PushRuleEnableTable.fields + ) + defer.returnValue( + {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + ) + + @defer.inlineCallbacks + def get_push_rule_enabled_by_user_rule_id(self, user_name, rule_id): + results = yield self._simple_select_list( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id}, + ['enabled'] + ) + if not results: + defer.returnValue(True) + defer.returnValue(results[0]) + + @defer.inlineCallbacks def add_push_rule(self, before, after, **kwargs): vals = copy.copy(kwargs) if 'conditions' in vals: @@ -193,6 +215,20 @@ class PushRuleStore(SQLBaseStore): {'user_name': user_name, 'rule_id': rule_id} ) + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_name, rule_id, enabled): + if enabled: + yield self._simple_delete_one( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id} + ) + else: + yield self._simple_upsert( + PushRuleEnableTable.table_name, + {'user_name': user_name, 'rule_id': rule_id}, + {'enabled': False} + ) + class RuleNotFoundException(Exception): pass @@ -216,3 +252,13 @@ class PushRuleTable(Table): ] EntryType = collections.namedtuple("PushRuleEntry", fields) + + +class PushRuleEnableTable(Table): + table_name = "push_rules_enable" + + fields = [ + "user_name", + "rule_id", + "enabled" + ] diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9bf608bc90..65ffb4627f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,7 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore +from ._base import SQLBaseStore, cached from synapse.api.constants import Membership from synapse.types import UserID @@ -35,11 +35,6 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def __init__(self, *args, **kw): - super(RoomMemberStore, self).__init__(*args, **kw) - - self._user_rooms_cache = {} - def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ @@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore): txn.execute(sql, (event.room_id, domain)) - self.invalidate_rooms_for_user(target_user_id) + self.get_rooms_for_user.invalidate(target_user_id) @defer.inlineCallbacks def get_room_member(self, user_id, room_id): @@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore): if not membership_list: return defer.succeed(None) + return self.runInteraction( + "get_rooms_for_user_where_membership_is", + self._get_rooms_for_user_where_membership_is_txn, + user_id, membership_list + ) + + def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, + membership_list): where_clause = "user_id = ? AND (%s)" % ( " OR ".join(["membership = ?" for _ in membership_list]), ) @@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore): args = [user_id] args.extend(membership_list) - def f(txn): - sql = ( - "SELECT m.room_id, m.sender, m.membership" - " FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id" - " WHERE %s" - ) % (where_clause,) - - txn.execute(sql, args) - return [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + sql = ( + "SELECT m.room_id, m.sender, m.membership" + " FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id" + " WHERE %s" + ) % (where_clause,) - return self.runInteraction( - "get_rooms_for_user_where_membership_is", - f - ) + txn.execute(sql, args) + return [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] def get_joined_hosts_for_room(self, room_id): return self._simple_select_onecol( @@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore): results = self._parse_events_txn(txn, rows) return results - # TODO(paul): Create a nice @cached decorator to do this - # @cached - # def get_foo(...) - # ... - # invalidate_foo = get_foo.invalidator - - @defer.inlineCallbacks + @cached() def get_rooms_for_user(self, user_id): - # TODO(paul): put some performance counters in here so we can easily - # track what impact this cache is having - if user_id in self._user_rooms_cache: - defer.returnValue(self._user_rooms_cache[user_id]) - - rooms = yield self.get_rooms_for_user_where_membership_is( + return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) - # TODO(paul): Consider applying a maximum size; just evict things at - # random, or consider LRU? - - self._user_rooms_cache[user_id] = rooms - defer.returnValue(rooms) - - def invalidate_rooms_for_user(self, user_id): - if user_id in self._user_rooms_cache: - del self._user_rooms_cache[user_id] - @defer.inlineCallbacks def user_rooms_intersect(self, user_id_list): """ Checks whether all the users whose IDs are given in a list share a diff --git a/synapse/storage/schema/application_services.sql b/synapse/storage/schema/application_services.sql deleted file mode 100644 index e491ad5aec..0000000000 --- a/synapse/storage/schema/application_services.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS application_services( - id INTEGER PRIMARY KEY AUTOINCREMENT, - url TEXT, - token TEXT, - hs_token TEXT, - sender TEXT, - UNIQUE(token) ON CONFLICT ROLLBACK -); - -CREATE TABLE IF NOT EXISTS application_services_regex( - id INTEGER PRIMARY KEY AUTOINCREMENT, - as_id INTEGER NOT NULL, - namespace INTEGER, /* enum[room_id|room_alias|user_id] */ - regex TEXT, - FOREIGN KEY(as_id) REFERENCES application_services(id) -); - - - diff --git a/synapse/storage/schema/delta/v11.sql b/synapse/storage/schema/delta/11/v11.sql index 313592221b..313592221b 100644 --- a/synapse/storage/schema/delta/v11.sql +++ b/synapse/storage/schema/delta/11/v11.sql diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/12/v12.sql index b87ef1fe79..b87ef1fe79 100644 --- a/synapse/storage/schema/delta/v12.sql +++ b/synapse/storage/schema/delta/12/v12.sql diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/13/v13.sql index e491ad5aec..e491ad5aec 100644 --- a/synapse/storage/schema/delta/v13.sql +++ b/synapse/storage/schema/delta/13/v13.sql diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py new file mode 100644 index 0000000000..847b1c5b89 --- /dev/null +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -0,0 +1,23 @@ +import json +import logging + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur): + cur.execute("SELECT id, regex FROM application_services_regex") + for row in cur.fetchall(): + try: + logger.debug("Checking %s..." % row[0]) + json.loads(row[1]) + except ValueError: + # row isn't in json, make it so. + string_regex = row[1] + new_regex = json.dumps({ + "regex": string_regex, + "exclusive": True + }) + cur.execute( + "UPDATE application_services_regex SET regex=? WHERE id=?", + (new_regex, row[0]) + ) diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/schema/delta/14/v14.sql new file mode 100644 index 0000000000..0212726448 --- /dev/null +++ b/synapse/storage/schema/delta/14/v14.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled TINYINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/schema/delta/v2.sql b/synapse/storage/schema/delta/v2.sql deleted file mode 100644 index f740f6dd5d..0000000000 --- a/synapse/storage/schema/delta/v2.sql +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS events( - stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT, - topological_ordering INTEGER NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - room_id TEXT NOT NULL, - content TEXT NOT NULL, - unrecognized_keys TEXT, - processed BOOL NOT NULL, - outlier BOOL NOT NULL, - CONSTRAINT ev_uniq UNIQUE (event_id) -); - -CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id); -CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering); -CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering); -CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id); - -CREATE TABLE IF NOT EXISTS state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - prev_state TEXT -); - -CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id); -CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id); -CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type); -CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key); - - -CREATE TABLE IF NOT EXISTS current_state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE -); - -CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id); -CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id); -CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type); -CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key); - -CREATE TABLE IF NOT EXISTS room_memberships( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - sender TEXT NOT NULL, - room_id TEXT NOT NULL, - membership TEXT NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id); -CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id); - -CREATE TABLE IF NOT EXISTS feedback( - event_id TEXT NOT NULL, - feedback_type TEXT, - target_event_id TEXT, - sender TEXT, - room_id TEXT -); - -CREATE TABLE IF NOT EXISTS topics( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - topic TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS room_names( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - name TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS rooms( - room_id TEXT PRIMARY KEY NOT NULL, - is_public INTEGER, - creator TEXT -); - -CREATE TABLE IF NOT EXISTS room_join_rules( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - join_rule TEXT NOT NULL -); -CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id); -CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id); - - -CREATE TABLE IF NOT EXISTS room_power_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - user_id TEXT NOT NULL, - level INTEGER NOT NULL -); -CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id); -CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id); -CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id); - - -CREATE TABLE IF NOT EXISTS room_default_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id); -CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_add_state_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id); -CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_send_event_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id); -CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_ops_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - ban_level INTEGER, - kick_level INTEGER -); - -CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id); -CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_hosts( - room_id TEXT NOT NULL, - host TEXT NOT NULL, - CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE -); - -CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id); - -PRAGMA user_version = 2; diff --git a/synapse/storage/schema/delta/v3.sql b/synapse/storage/schema/delta/v3.sql deleted file mode 100644 index c67e38ff52..0000000000 --- a/synapse/storage/schema/delta/v3.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE INDEX IF NOT EXISTS room_aliases_alias ON room_aliases(room_alias); -CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); - - -CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); - -DELETE FROM room_aliases WHERE rowid NOT IN (SELECT max(rowid) FROM room_aliases GROUP BY room_alias, room_id); - -CREATE UNIQUE INDEX IF NOT EXISTS room_aliases_uniq ON room_aliases(room_alias, room_id); - -PRAGMA user_version = 3; diff --git a/synapse/storage/schema/delta/v4.sql b/synapse/storage/schema/delta/v4.sql deleted file mode 100644 index d3807b7686..0000000000 --- a/synapse/storage/schema/delta/v4.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS redactions ( - event_id TEXT NOT NULL, - redacts TEXT NOT NULL, - CONSTRAINT ev_uniq UNIQUE (event_id) -); - -CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id); -CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts); - -ALTER TABLE room_ops_levels ADD COLUMN redact_level INTEGER; - -PRAGMA user_version = 4; diff --git a/synapse/storage/schema/delta/v5.sql b/synapse/storage/schema/delta/v5.sql deleted file mode 100644 index 0874a15431..0000000000 --- a/synapse/storage/schema/delta/v5.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS user_ips ( - user TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen INTEGER NOT NULL, - CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE -); - -CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user); - -ALTER TABLE users ADD COLUMN admin BOOL DEFAULT 0 NOT NULL; - -PRAGMA user_version = 5; diff --git a/synapse/storage/schema/delta/v6.sql b/synapse/storage/schema/delta/v6.sql deleted file mode 100644 index a9e0a4fe0d..0000000000 --- a/synapse/storage/schema/delta/v6.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS server_tls_certificates( - server_name TEXT, -- Server name. - fingerprint TEXT, -- Certificate fingerprint. - from_server TEXT, -- Which key server the certificate was fetched from. - ts_added_ms INTEGER, -- When the certifcate was added. - tls_certificate BLOB, -- DER encoded x509 certificate. - CONSTRAINT uniqueness UNIQUE (server_name, fingerprint) -); - -CREATE TABLE IF NOT EXISTS server_signature_keys( - server_name TEXT, -- Server name. - key_id TEXT, -- Key version. - from_server TEXT, -- Which key server the key was fetched form. - ts_added_ms INTEGER, -- When the key was added. - verify_key BLOB, -- NACL verification key. - CONSTRAINT uniqueness UNIQUE (server_name, key_id) -); diff --git a/synapse/storage/schema/delta/v8.sql b/synapse/storage/schema/delta/v8.sql deleted file mode 100644 index 1e9f8b18cb..0000000000 --- a/synapse/storage/schema/delta/v8.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - CREATE TABLE IF NOT EXISTS event_signatures_2 ( - event_id TEXT, - signature_name TEXT, - key_id TEXT, - signature BLOB, - CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id) -); - -INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature) -SELECT event_id, signature_name, key_id, signature FROM event_signatures; - -DROP TABLE event_signatures; -ALTER TABLE event_signatures_2 RENAME TO event_signatures; - -CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( - event_id -); - -PRAGMA user_version = 8; \ No newline at end of file diff --git a/synapse/storage/schema/delta/v9.sql b/synapse/storage/schema/delta/v9.sql deleted file mode 100644 index 455d51a70c..0000000000 --- a/synapse/storage/schema/delta/v9.sql +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2014, 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- To track destination health -CREATE TABLE IF NOT EXISTS destinations( - destination TEXT PRIMARY KEY, - retry_last_ts INTEGER, - retry_interval INTEGER -); - - -CREATE TABLE IF NOT EXISTS local_media_repository ( - media_id TEXT, -- The id used to refer to the media. - media_type TEXT, -- The MIME-type of the media. - media_length INTEGER, -- Length of the media in bytes. - created_ts INTEGER, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - user_id TEXT, -- The user who uploaded the file. - CONSTRAINT uniqueness UNIQUE (media_id) -); - -CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_method TEXT, -- The method used to make the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - CONSTRAINT uniqueness UNIQUE ( - media_id, thumbnail_width, thumbnail_height, thumbnail_type - ) -); - -CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); - -CREATE TABLE IF NOT EXISTS remote_media_cache ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media on that server. - media_type TEXT, -- The MIME-type of the media. - created_ts INTEGER, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - media_length INTEGER, -- Length of the media in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - CONSTRAINT uniqueness UNIQUE (media_origin, media_id) -); - -CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_method TEXT, -- The method used to make the thumbnail - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - CONSTRAINT uniqueness UNIQUE ( - media_origin, media_id, thumbnail_width, thumbnail_height, - thumbnail_type, thumbnail_type - ) -); - -CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); - - -PRAGMA user_version = 9; diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql deleted file mode 100644 index beb39ca201..0000000000 --- a/synapse/storage/schema/filtering.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS user_filters( - user_id TEXT, - filter_id INTEGER, - filter_json TEXT, - FOREIGN KEY(user_id) REFERENCES users(id) -); - -CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id -); diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql index 1e766d6db2..1e766d6db2 100644 --- a/synapse/storage/schema/event_edges.sql +++ b/synapse/storage/schema/full_schemas/11/event_edges.sql diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/full_schemas/11/event_signatures.sql index c28c39c48a..c28c39c48a 100644 --- a/synapse/storage/schema/event_signatures.sql +++ b/synapse/storage/schema/full_schemas/11/event_signatures.sql diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/full_schemas/11/im.sql index dd00c1cd2f..dd00c1cd2f 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/full_schemas/11/im.sql diff --git a/synapse/storage/schema/keys.sql b/synapse/storage/schema/full_schemas/11/keys.sql index a9e0a4fe0d..a9e0a4fe0d 100644 --- a/synapse/storage/schema/keys.sql +++ b/synapse/storage/schema/full_schemas/11/keys.sql diff --git a/synapse/storage/schema/media_repository.sql b/synapse/storage/schema/full_schemas/11/media_repository.sql index afdf48cbfb..afdf48cbfb 100644 --- a/synapse/storage/schema/media_repository.sql +++ b/synapse/storage/schema/full_schemas/11/media_repository.sql diff --git a/synapse/storage/schema/presence.sql b/synapse/storage/schema/full_schemas/11/presence.sql index f9f8db9697..f9f8db9697 100644 --- a/synapse/storage/schema/presence.sql +++ b/synapse/storage/schema/full_schemas/11/presence.sql diff --git a/synapse/storage/schema/profiles.sql b/synapse/storage/schema/full_schemas/11/profiles.sql index f06a528b4d..f06a528b4d 100644 --- a/synapse/storage/schema/profiles.sql +++ b/synapse/storage/schema/full_schemas/11/profiles.sql diff --git a/synapse/storage/schema/redactions.sql b/synapse/storage/schema/full_schemas/11/redactions.sql index 5011d95db8..5011d95db8 100644 --- a/synapse/storage/schema/redactions.sql +++ b/synapse/storage/schema/full_schemas/11/redactions.sql diff --git a/synapse/storage/schema/room_aliases.sql b/synapse/storage/schema/full_schemas/11/room_aliases.sql index 0d2df01603..0d2df01603 100644 --- a/synapse/storage/schema/room_aliases.sql +++ b/synapse/storage/schema/full_schemas/11/room_aliases.sql diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/full_schemas/11/state.sql index 1fe8f1e430..1fe8f1e430 100644 --- a/synapse/storage/schema/state.sql +++ b/synapse/storage/schema/full_schemas/11/state.sql diff --git a/synapse/storage/schema/transactions.sql b/synapse/storage/schema/full_schemas/11/transactions.sql index 2d30f99b06..2d30f99b06 100644 --- a/synapse/storage/schema/transactions.sql +++ b/synapse/storage/schema/full_schemas/11/transactions.sql diff --git a/synapse/storage/schema/users.sql b/synapse/storage/schema/full_schemas/11/users.sql index 08ccfdac0a..08ccfdac0a 100644 --- a/synapse/storage/schema/users.sql +++ b/synapse/storage/schema/full_schemas/11/users.sql diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql deleted file mode 100644 index 3735b11547..0000000000 --- a/synapse/storage/schema/pusher.sql +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Push notification endpoints that users have configured -CREATE TABLE IF NOT EXISTS pushers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - profile_tag varchar(32) NOT NULL, - kind varchar(8) NOT NULL, - app_id varchar(64) NOT NULL, - app_display_name varchar(64) NOT NULL, - device_display_name varchar(128) NOT NULL, - pushkey blob NOT NULL, - ts BIGINT NOT NULL, - lang varchar(8), - data blob, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - FOREIGN KEY(user_name) REFERENCES users(name), - UNIQUE (app_id, pushkey) -); - -CREATE TABLE IF NOT EXISTS push_rules ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - priority_class TINYINT NOT NULL, - priority INTEGER NOT NULL DEFAULT 0, - conditions TEXT NOT NULL, - actions TEXT NOT NULL, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/rejections.sql b/synapse/storage/schema/rejections.sql deleted file mode 100644 index bd2a8b1bb5..0000000000 --- a/synapse/storage/schema/rejections.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS rejections( - event_id TEXT NOT NULL, - reason TEXT NOT NULL, - last_check TEXT NOT NULL, - CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE -); diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql new file mode 100644 index 0000000000..0431e2d051 --- /dev/null +++ b/synapse/storage/schema/schema_version.sql @@ -0,0 +1,30 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS schema_version( + Lock char(1) NOT NULL DEFAULT 'X', -- Makes sure this table only has one row. + version INTEGER NOT NULL, + upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema. + CONSTRAINT schema_version_lock_x CHECK (Lock='X') + CONSTRAINT schema_version_lock_uniq UNIQUE (Lock) +); + +CREATE TABLE IF NOT EXISTS applied_schema_deltas( + version INTEGER NOT NULL, + file TEXT NOT NULL, + CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE +); + +CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 3ccb6f8a61..09bc522210 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -36,6 +36,7 @@ what sort order was used: from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.util.logutils import log_function @@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): class StreamStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_appservice_room_stream(self, service, from_key, to_key, limit=0): + # NB this lives here instead of appservice.py so we can reuse the + # 'private' StreamToken class in this file. + if limit: + limit = max(limit, MAX_STREAM_SIZE) + else: + limit = MAX_STREAM_SIZE + + # From and to keys should be integers from ordering. + from_id = _StreamToken.parse_stream_token(from_key) + to_id = _StreamToken.parse_stream_token(to_key) + + if from_key == to_key: + defer.returnValue(([], to_key)) + return + + # select all the events between from/to with a sensible limit + sql = ( + "SELECT e.event_id, e.room_id, e.type, s.state_key, " + "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON " + "e.event_id = s.event_id " + "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? " + "ORDER BY stream_ordering ASC LIMIT %(limit)d " + ) % { + "limit": limit + } + + def f(txn): + # pull out all the events between the tokens + txn.execute(sql, (from_id.stream, to_id.stream,)) + rows = self.cursor_to_dict(txn) + + # Logic: + # - We want ALL events which match the AS room_id regex + # - We want ALL events which match the rooms represented by the AS + # room_alias regex + # - We want ALL events for rooms that AS users have joined. + # This is currently supported via get_app_service_rooms (which is + # used for the Notifier listener rooms). We can't reasonably make a + # SQL query for these room IDs, so we'll pull all the events between + # from/to and filter in python. + rooms_for_as = self._get_app_service_rooms_txn(txn, service) + room_ids_for_as = [r.room_id for r in rooms_for_as] + + def app_service_interested(row): + if row["room_id"] in room_ids_for_as: + return True + + if row["type"] == EventTypes.Member: + if service.is_interested_in_user(row.get("state_key")): + return True + return False + + ret = self._get_events_txn( + txn, + # apply the filter on the room id list + [ + r["event_id"] for r in rows + if app_service_interested(r) + ], + get_prev_content=True + ) + + self._set_before_and_after(ret, rows) + + if rows: + key = "s%d" % max(r["stream_ordering"] for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = to_key + + return ret, key + + results = yield self.runInteraction("get_appservice_room_stream", f) + defer.returnValue(results) + @log_function def get_room_events_stream(self, user_id, from_key, to_key, room_id, limit=0, with_feedback=False): @@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore): self._set_before_and_after(ret, rows) if rows: - key = "s%d" % max([r["stream_ordering"] for r in rows]) - + key = "s%d" % max(r["stream_ordering"] for r in rows) else: # Assume we didn't get anything because there was nothing to # get. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index e06ef35690..0b8a3b7a07 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore, Table, cached from collections import namedtuple -from twisted.internet import defer - import logging logger = logging.getLogger(__name__) @@ -28,10 +26,6 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - # a write-through cache of DestinationsTable.EntryType indexed by - # destination string - destination_retry_cache = {} - def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -211,6 +205,7 @@ class TransactionStore(SQLBaseStore): return ReceivedTransactionsTable.decode_results(txn.fetchall()) + @cached() def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -221,9 +216,6 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a DestinationsTable.EntryType for the retry scheme """ - if destination in self.destination_retry_cache: - return defer.succeed(self.destination_retry_cache[destination]) - return self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) @@ -250,7 +242,9 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - self.destination_retry_cache[destination] = ( + # As this is the new value, we might as well prefill the cache + self.get_destination_retry_timings.prefill( + destination, DestinationsTable.EntryType( destination, retry_last_ts, diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py index a45c673d32..f115f50e50 100644 --- a/synapse/util/lrucache.py +++ b/synapse/util/lrucache.py @@ -88,11 +88,15 @@ class LruCache(object): else: return default + def cache_len(): + return len(cache) + self.sentinel = object() self.get = cache_get self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + self.len = cache_len def __getitem__(self, key): result = self.get(key, self.sentinel) @@ -108,3 +112,6 @@ class LruCache(object): result = self.pop(key, self.sentinel) if result is self.sentinel: raise KeyError() + + def __len__(self): + return self.len() diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py new file mode 100644 index 0000000000..d4457af950 --- /dev/null +++ b/synapse/util/ratelimitutils.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import LimitExceededError + +from synapse.util.async import sleep + +import collections +import contextlib +import logging + + +logger = logging.getLogger(__name__) + + +class FederationRateLimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + """ + Args: + clock (Clock) + window_size (int): The window size in milliseconds. + sleep_limit (int): The number of requests received in the last + `window_size` milliseconds before we artificially start + delaying processing of requests. + sleep_msec (int): The number of milliseconds to delay processing + of incoming requests by. + reject_limit (int): The maximum number of requests that are can be + queued for processing before we start rejecting requests with + a 429 Too Many Requests response. + concurrent_requests (int): The number of concurrent requests to + process. + """ + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.ratelimiters = {} + + def ratelimit(self, host): + """Used to ratelimit an incoming request from given host + + Example usage: + + with rate_limiter.ratelimit(origin) as wait_deferred: + yield wait_deferred + # Handle request ... + + Args: + host (str): Origin of incoming request. + + Returns: + _PerHostRatelimiter + """ + return self.ratelimiters.setdefault( + host, + _PerHostRatelimiter( + clock=self.clock, + window_size=self.window_size, + sleep_limit=self.sleep_limit, + sleep_msec=self.sleep_msec, + reject_limit=self.reject_limit, + concurrent_requests=self.concurrent_requests, + ) + ).ratelimit() + + +class _PerHostRatelimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.sleeping_requests = set() + self.ready_request_queue = collections.OrderedDict() + self.current_processing = set() + self.request_times = [] + + def is_empty(self): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + return not ( + self.ready_request_queue + or self.sleeping_requests + or self.current_processing + or self.request_times + ) + + @contextlib.contextmanager + def ratelimit(self): + # `contextlib.contextmanager` takes a generator and turns it into a + # context manager. The generator should only yield once with a value + # to be returned by manager. + # Exceptions will be reraised at the yield. + + request_id = object() + ret = self._on_enter(request_id) + try: + yield ret + finally: + self._on_exit(request_id) + + def _on_enter(self, request_id): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) + if queue_size > self.reject_limit: + raise LimitExceededError( + retry_after_ms=int( + self.window_size / self.sleep_limit + ), + ) + + self.request_times.append(time_now) + + def queue_request(): + if len(self.current_processing) > self.concurrent_requests: + logger.debug("Ratelimit [%s]: Queue req", id(request_id)) + queue_defer = defer.Deferred() + self.ready_request_queue[request_id] = queue_defer + return queue_defer + else: + return defer.succeed(None) + + logger.debug( + "Ratelimit [%s]: len(self.request_times)=%d", + id(request_id), len(self.request_times), + ) + + if len(self.request_times) > self.sleep_limit: + logger.debug( + "Ratelimit [%s]: sleeping req", + id(request_id), + ) + ret_defer = sleep(self.sleep_msec/1000.0) + + self.sleeping_requests.add(request_id) + + def on_wait_finished(_): + logger.debug( + "Ratelimit [%s]: Finished sleeping", + id(request_id), + ) + self.sleeping_requests.discard(request_id) + queue_defer = queue_request() + return queue_defer + + ret_defer.addBoth(on_wait_finished) + else: + ret_defer = queue_request() + + def on_start(r): + logger.debug( + "Ratelimit [%s]: Processing req", + id(request_id), + ) + self.current_processing.add(request_id) + return r + + def on_err(r): + self.current_processing.discard(request_id) + return r + + def on_both(r): + # Ensure that we've properly cleaned up. + self.sleeping_requests.discard(request_id) + self.ready_request_queue.pop(request_id, None) + return r + + ret_defer.addCallbacks(on_start, on_err) + ret_defer.addBoth(on_both) + return ret_defer + + def _on_exit(self, request_id): + logger.debug( + "Ratelimit [%s]: Processed req", + id(request_id), + ) + self.current_processing.discard(request_id) + try: + request_id, deferred = self.ready_request_queue.popitem() + self.current_processing.add(request_id) + deferred.callback(None) + except KeyError: + pass |