diff options
Diffstat (limited to 'synapse')
78 files changed, 1916 insertions, 1085 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index d366b69dab..f32c28be02 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.18.3" +__version__ = "0.18.5-rc2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 69b3392735..a99986714d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -39,6 +39,9 @@ AuthEventTypes = ( EventTypes.ThirdPartyInvite, ) +# guests always get this device id. +GUEST_DEVICE_ID = "guest_device" + class Auth(object): """ @@ -51,17 +54,6 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 - # Docs for these currently lives at - # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst - # In addition, we have type == delete_pusher which grants access only to - # delete pushers. - self._KNOWN_CAVEAT_PREFIXES = set([ - "gen = ", - "guest = ", - "type = ", - "time < ", - "user_id = ", - ]) @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): @@ -685,31 +677,28 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): - """ Get a registered user's ID. + """ Validate access token and get user_id from it Args: token (str): The access token to get the user by. + rights (str): The operation being performed; the access token must + allow this. Returns: dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self.get_user_from_macaroon(token, rights) - except AuthError: - # TODO(daniel): Remove this fallback when all existing access tokens - # have been re-issued as macaroons. - if self.hs.config.expire_access_token: - raise - ret = yield self._look_up_user_by_access_token(token) - - defer.returnValue(ret) + macaroon = pymacaroons.Macaroon.deserialize(token) + except Exception: # deserialize can throw more-or-less anything + # doesn't look like a macaroon: treat it as an opaque token which + # must be in the database. + # TODO: it would be nice to get rid of this, but apparently some + # people use access tokens which aren't macaroons + r = yield self._look_up_user_by_access_token(token) + defer.returnValue(r) - @defer.inlineCallbacks - def get_user_from_macaroon(self, macaroon_str, rights="access"): try: - macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_id = self.get_user_id_from_macaroon(macaroon) user = UserID.from_string(user_id) @@ -724,11 +713,36 @@ class Auth(object): guest = True if guest: + # Guest access tokens are not stored in the database (there can + # only be one access token per guest, anyway). + # + # In order to prevent guest access tokens being used as regular + # user access tokens (and hence getting around the invalidation + # process), we look up the user id and check that it is indeed + # a guest user. + # + # It would of course be much easier to store guest access + # tokens in the database as well, but that would break existing + # guest tokens. + stored_user = yield self.store.get_user_by_id(user_id) + if not stored_user: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Unknown user_id %s" % user_id, + errcode=Codes.UNKNOWN_TOKEN + ) + if not stored_user["is_guest"]: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Guest access token used for regular user", + errcode=Codes.UNKNOWN_TOKEN + ) ret = { "user": user, "is_guest": True, "token_id": None, - "device_id": None, + # all guests get the same device id + "device_id": GUEST_DEVICE_ID, } elif rights == "delete_pusher": # We don't store these tokens in the database @@ -750,7 +764,7 @@ class Auth(object): # macaroon. They probably should be. # TODO: build the dictionary from the macaroon once the # above are fixed - ret = yield self._look_up_user_by_access_token(macaroon_str) + ret = yield self._look_up_user_by_access_token(token) if ret["user"] != user: logger.error( "Macaroon user (%s) != DB user (%s)", @@ -798,27 +812,38 @@ class Auth(object): Args: macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token required (e.g. "access", "refresh", + type_string(str): The kind of token required (e.g. "access", "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. - This should really always be True, but no clients currently implement - token refresh, so we can't enforce expiry yet. user_id (str): The user_id required """ v = pymacaroons.Verifier() + + # the verifier runs a test for every caveat on the macaroon, to check + # that it is met for the current request. Each caveat must match at + # least one of the predicates specified by satisfy_exact or + # specify_general. v.satisfy_exact("gen = 1") v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") + + # verify_expiry should really always be True, but there exist access + # tokens in the wild which expire when they should not, so we can't + # enforce expiry yet (so we have to allow any caveat starting with + # 'time < ' in access tokens). + # + # On the other hand, short-term login tokens (as used by CAS login, for + # example) have an expiry time which we do want to enforce. + if verify_expiry: v.satisfy_general(self._verify_expiry) else: v.satisfy_general(lambda c: c.startswith("time < ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + # access_tokens include a nonce for uniqueness: any value is acceptable + v.satisfy_general(lambda c: c.startswith("nonce = ")) - v = pymacaroons.Verifier() - v.satisfy_general(self._verify_recognizes_caveats) v.verify(macaroon, self.hs.config.macaroon_secret_key) def _verify_expiry(self, caveat): @@ -829,15 +854,6 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - def _verify_recognizes_caveats(self, caveat): - first_space = caveat.find(" ") - if first_space < 0: - return False - second_space = caveat.find(" ", first_space + 1) - if second_space < 0: - return False - return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES - @defer.inlineCallbacks def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0041646858..921c457738 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -39,6 +39,7 @@ class Codes(object): CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" MISSING_PARAM = "M_MISSING_PARAM" + INVALID_PARAM = "M_INVALID_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3b3ef70750..fb291d7fb9 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -71,6 +71,21 @@ class Filtering(object): if key in user_filter_json["room"]: self._check_definition(user_filter_json["room"][key]) + if "event_fields" in user_filter_json: + if type(user_filter_json["event_fields"]) != list: + raise SynapseError(400, "event_fields must be a list of strings") + for field in user_filter_json["event_fields"]: + if not isinstance(field, basestring): + raise SynapseError(400, "Event field must be a string") + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + if r'\\' in field: + raise SynapseError( + 400, r'The escape character \ cannot itself be escaped' + ) + def _check_definition_room_lists(self, definition): """Check that "rooms" and "not_rooms" are lists of room ids if they are present @@ -152,6 +167,7 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self.event_fields = filter_json.get("event_fields", []) def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) @@ -186,6 +202,26 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) + def blocks_all_presence(self): + return ( + self._presence_filter.filters_all_types() or + self._presence_filter.filters_all_senders() + ) + + def blocks_all_room_ephemeral(self): + return ( + self._room_ephemeral_filter.filters_all_types() or + self._room_ephemeral_filter.filters_all_senders() or + self._room_ephemeral_filter.filters_all_rooms() + ) + + def blocks_all_room_timeline(self): + return ( + self._room_timeline_filter.filters_all_types() or + self._room_timeline_filter.filters_all_senders() or + self._room_timeline_filter.filters_all_rooms() + ) + class Filter(object): def __init__(self, filter_json): @@ -202,6 +238,15 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + def filters_all_types(self): + return "*" in self.not_types + + def filters_all_senders(self): + return "*" in self.not_senders + + def filters_all_rooms(self): + return "*" in self.not_rooms + def check(self, event): """Checks whether the filter matches the given event. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py new file mode 100644 index 0000000000..80ea4c8062 --- /dev/null +++ b/synapse/app/federation_sender.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.crypto import context_factory +from synapse.http.site import SynapseSite +from synapse.federation import send_queue +from synapse.federation.units import Edu +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.storage.engines import create_engine +from synapse.storage.presence import UserPresenceState +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from synapse import events + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc +import ujson as json + +logger = logging.getLogger("synapse.app.appservice") + + +class FederationSenderSlaveStore( + SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, + SlavedRegistrationStore, +): + pass + + +class FederationSenderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse federation_sender now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + send_handler = FederationSenderHandler(self) + + send_handler.on_start() + + while True: + try: + args = store.stream_positions() + args.update((yield send_handler.stream_positions())) + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + yield send_handler.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse federation sender", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.federation_sender" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + if config.send_federation: + sys.stderr.write( + "\nThe send_federation must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``send_federation: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.send_federation = True + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ps = FederationSenderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-federation-sender", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +class FederationSenderHandler(object): + """Processes the replication stream and forwards the appropriate entries + to the federation sender. + """ + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation_sender = hs.get_federation_sender() + + self._room_serials = {} + self._room_typing = {} + + def on_start(self): + # There may be some events that are persisted but haven't been sent, + # so send them now. + self.federation_sender.notify_new_events( + self.store.get_room_max_stream_ordering() + ) + + @defer.inlineCallbacks + def stream_positions(self): + stream_id = yield self.store.get_federation_out_pos("federation") + defer.returnValue({ + "federation": stream_id, + + # Ack stuff we've "processed", this should only be called from + # one process. + "federation_ack": stream_id, + }) + + @defer.inlineCallbacks + def process_replication(self, result): + # The federation stream contains things that we want to send out, e.g. + # presence, typing, etc. + fed_stream = result.get("federation") + if fed_stream: + latest_id = int(fed_stream["position"]) + + # The federation stream containis a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + presence_to_send = {} + keyed_edus = {} + edus = {} + failures = {} + device_destinations = set() + + # Parse the rows in the stream + for row in fed_stream["rows"]: + position, typ, content_js = row + content = json.loads(content_js) + + if typ == send_queue.PRESENCE_TYPE: + destination = content["destination"] + state = UserPresenceState.from_dict(content["state"]) + + presence_to_send.setdefault(destination, []).append(state) + elif typ == send_queue.KEYED_EDU_TYPE: + key = content["key"] + edu = Edu(**content["edu"]) + + keyed_edus.setdefault( + edu.destination, {} + )[(edu.destination, tuple(key))] = edu + elif typ == send_queue.EDU_TYPE: + edu = Edu(**content) + + edus.setdefault(edu.destination, []).append(edu) + elif typ == send_queue.FAILURE_TYPE: + destination = content["destination"] + failure = content["failure"] + + failures.setdefault(destination, []).append(failure) + elif typ == send_queue.DEVICE_MESSAGE_TYPE: + device_destinations.add(content["destination"]) + else: + raise Exception("Unrecognised federation type: %r", typ) + + # We've finished collecting, send everything off + for destination, states in presence_to_send.items(): + self.federation_sender.send_presence(destination, states) + + for destination, edu_map in keyed_edus.items(): + for key, edu in edu_map.items(): + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in edus.items(): + for edu in edu_list: + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in failures.items(): + for failure in failure_list: + self.federation_sender.send_failure(destination, failure) + + for destination in device_destinations: + self.federation_sender.send_device_messages(destination) + + # Record where we are in the stream. + yield self.store.update_federation_out_pos( + "federation", latest_id + ) + + # We also need to poke the federation sender when new events happen + event_stream = result.get("events") + if event_stream: + latest_pos = event_stream["position"] + self.federation_sender.notify_new_events(latest_pos) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index af9f17bf7b..1ab5593c6e 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -64,11 +64,12 @@ class Config(object): if isinstance(value, int) or isinstance(value, long): return value second = 1000 - hour = 60 * 60 * second + minute = 60 * second + hour = 60 * minute day = 24 * hour week = 7 * day year = 365 * day - sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year} + sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} size = 1 suffix = value[-1] if suffix in sizes: diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dc68683fbc..ec72c95436 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -50,6 +50,7 @@ handlers: console: class: logging.StreamHandler formatter: precise + filters: [context] loggers: synapse: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 1f438d2bb3..83762d089a 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -27,17 +27,23 @@ class PasswordAuthProviderConfig(Config): ldap_config = config.get("ldap_config", {}) self.ldap_enabled = ldap_config.get("enabled", False) if self.ldap_enabled: - from synapse.util.ldap_auth_provider import LdapAuthProvider + from ldap_auth_provider import LdapAuthProvider parsed_config = LdapAuthProvider.parse_config(ldap_config) self.password_providers.append((LdapAuthProvider, parsed_config)) providers = config.get("password_providers", []) for provider in providers: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + # This is for backwards compat when the ldap auth provider resided + # in this package. + if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": + from ldap_auth_provider import LdapAuthProvider + provider_class = LdapAuthProvider + else: + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) @@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config): def default_config(self, **kwargs): return """\ # password_providers: - # - module: "synapse.util.ldap_auth_provider.LdapAuthProvider" + # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true # uri: "ldap://ldap.example.com:389" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc3f879857..87e500c97a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -32,7 +32,6 @@ class RegistrationConfig(Config): ) self.registration_shared_secret = config.get("registration_shared_secret") - self.user_creation_max_duration = int(config["user_creation_max_duration"]) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] @@ -55,11 +54,6 @@ class RegistrationConfig(Config): # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" - # Sets the expiry for the short term user creation in - # milliseconds. For instance the bellow duration is two weeks - # in milliseconds. - user_creation_max_duration: 1209600000 - # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. diff --git a/synapse/config/server.py b/synapse/config/server.py index ed5417d0c3..634d8e6fe5 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,6 +30,11 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 0e9fd902af..5bbaef8187 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,6 +16,17 @@ from synapse.api.constants import EventTypes from . import EventBase +from frozendict import frozendict + +import re + +# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' +# (?<!stuff) matches if the current position in the string is not preceded +# by a match for 'stuff'. +# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as +# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" +SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.') + def prune_event(event): """ Returns a pruned version of the given event, which removes all keys we @@ -97,6 +108,83 @@ def prune_event(event): ) +def _copy_field(src, dst, field): + """Copy the field in 'src' to 'dst'. + + For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] + then dst={"foo":{"bar":5}}. + + Args: + src(dict): The dict to read from. + dst(dict): The dict to modify. + field(list<str>): List of keys to drill down to in 'src'. + """ + if len(field) == 0: # this should be impossible + return + if len(field) == 1: # common case e.g. 'origin_server_ts' + if field[0] in src: + dst[field[0]] = src[field[0]] + return + + # Else is a nested field e.g. 'content.body' + # Pop the last field as that's the key to move across and we need the + # parent dict in order to access the data. Drill down to the right dict. + key_to_move = field.pop(-1) + sub_dict = src + for sub_field in field: # e.g. sub_field => "content" + if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + sub_dict = sub_dict[sub_field] + else: + return + + if key_to_move not in sub_dict: + return + + # Insert the key into the output dictionary, creating nested objects + # as required. We couldn't do this any earlier or else we'd need to delete + # the empty objects if the key didn't exist. + sub_out_dict = dst + for sub_field in field: + sub_out_dict = sub_out_dict.setdefault(sub_field, {}) + sub_out_dict[key_to_move] = sub_dict[key_to_move] + + +def only_fields(dictionary, fields): + """Return a new dict with only the fields in 'dictionary' which are present + in 'fields'. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + dictionary(dict): The dictionary to read from. + fields(list<str>): A list of fields to copy over. Only shallow refs are + taken. + Returns: + dict: A new dictionary with only the given fields. If fields was empty, + the same dictionary is returned. + """ + if len(fields) == 0: + return dictionary + + # for each field, convert it: + # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] + split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] + + # for each element of the output array of arrays: + # remove escaping so we can use the right key names. + split_fields[:] = [ + [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + ] + + output = {} + for field_array in split_fields: + _copy_field(dictionary, output, field_array) + return output + + def format_event_raw(d): return d @@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None): + token_id=None, only_event_fields=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, d["unsigned"]["transaction_id"] = txn_id if as_client_event: - return event_format(d) - else: - return d + d = event_format(d) + + if only_event_fields: + if (not isinstance(only_event_fields, list) or + not all(isinstance(f, basestring) for f in only_event_fields)): + raise TypeError("only_event_fields must be a list of strings") + d = only_fields(d, only_event_fields) + + return d diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 979fdf2431..2e32d245ba 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,10 +17,9 @@ """ from .replication import ReplicationLayer -from .transport.client import TransportLayerClient -def initialize_http_replication(homeserver): - transport = TransportLayerClient(homeserver) +def initialize_http_replication(hs): + transport = hs.get_federation_transport_client() - return ReplicationLayer(homeserver, transport) + return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 94e76b1978..b255709165 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,7 +18,6 @@ from twisted.internet import defer from .federation_base import FederationBase from synapse.api.constants import Membership -from .units import Edu from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, @@ -45,10 +44,6 @@ logger = logging.getLogger(__name__) # synapse.federation.federation_client is a silly name metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations") - -sent_edus_counter = metrics.register_counter("sent_edus") - sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) @@ -93,63 +88,6 @@ class FederationClient(FederationBase): self._get_pdu_cache.start() @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - sent_pdus_destination_dist.inc_by(len(destinations)) - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - def send_presence(self, destination, states): - if destination != self.server_name: - self._transaction_queue.enqueue_presence(destination, states) - - @log_function - def send_edu(self, destination, edu_type, content, key=None): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - sent_edus_counter.inc() - - self._transaction_queue.enqueue_edu(edu, key=key) - - @log_function - def send_device_messages(self, destination): - """Sends the device messages in the local database to the remote - destination""" - self._transaction_queue.enqueue_device_messages(destination) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False): """Sends a federation Query to a remote homeserver of the given type diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index ea66a5dcbc..62d865ec4b 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -20,8 +20,6 @@ a given transport. from .federation_client import FederationClient from .federation_server import FederationServer -from .transaction_queue import TransactionQueue - from .persistence import TransactionActions import logging @@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer): self._clock = hs.get_clock() self.transaction_actions = TransactionActions(self.store) - self._transaction_queue = TransactionQueue(hs, transport_layer) - - self._order = 0 self.hs = hs diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py new file mode 100644 index 0000000000..5c9f7a86f0 --- /dev/null +++ b/synapse/federation/send_queue.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A federation sender that forwards things to be sent across replication to +a worker process. + +It assumes there is a single worker process feeding off of it. + +Each row in the replication stream consists of a type and some json, where the +types indicate whether they are presence, or edus, etc. + +Ephemeral or non-event data are queued up in-memory. When the worker requests +updates since a particular point, all in-memory data since before that point is +dropped. We also expire things in the queue after 5 minutes, to ensure that a +dead worker doesn't cause the queues to grow limitlessly. + +Events are replicated via a separate events stream. +""" + +from .units import Edu + +from synapse.util.metrics import Measure +import synapse.metrics + +from blist import sorteddict +import ujson + + +metrics = synapse.metrics.get_metrics_for(__name__) + + +PRESENCE_TYPE = "p" +KEYED_EDU_TYPE = "k" +EDU_TYPE = "e" +FAILURE_TYPE = "f" +DEVICE_MESSAGE_TYPE = "d" + + +class FederationRemoteSendQueue(object): + """A drop in replacement for TransactionQueue""" + + def __init__(self, hs): + self.server_name = hs.hostname + self.clock = hs.get_clock() + + self.presence_map = {} + self.presence_changed = sorteddict() + + self.keyed_edu = {} + self.keyed_edu_changed = sorteddict() + + self.edus = sorteddict() + + self.failures = sorteddict() + + self.device_messages = sorteddict() + + self.pos = 1 + self.pos_time = sorteddict() + + # EVERYTHING IS SAD. In particular, python only makes new scopes when + # we make a new function, so we need to make a new function so the inner + # lambda binds to the queue rather than to the name of the queue which + # changes. ARGH. + def register(name, queue): + metrics.register_callback( + queue_name + "_size", + lambda: len(queue), + ) + + for queue_name in [ + "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", + "edus", "failures", "device_messages", "pos_time", + ]: + register(queue_name, getattr(self, queue_name)) + + self.clock.looping_call(self._clear_queue, 30 * 1000) + + def _next_pos(self): + pos = self.pos + self.pos += 1 + self.pos_time[self.clock.time_msec()] = pos + return pos + + def _clear_queue(self): + """Clear the queues for anything older than N minutes""" + + FIVE_MINUTES_AGO = 5 * 60 * 1000 + now = self.clock.time_msec() + + keys = self.pos_time.keys() + time = keys.bisect_left(now - FIVE_MINUTES_AGO) + if not keys[:time]: + return + + position_to_delete = max(keys[:time]) + for key in keys[:time]: + del self.pos_time[key] + + self._clear_queue_before_pos(position_to_delete) + + def _clear_queue_before_pos(self, position_to_delete): + """Clear all the queues from before a given position""" + with Measure(self.clock, "send_queue._clear"): + # Delete things out of presence maps + keys = self.presence_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_changed[key] + + user_ids = set( + user_id for uids in self.presence_changed.values() for _, user_id in uids + ) + + to_del = [ + user_id for user_id in self.presence_map if user_id not in user_ids + ] + for user_id in to_del: + del self.presence_map[user_id] + + # Delete things out of keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.keyed_edu_changed[key] + + live_keys = set() + for edu_key in self.keyed_edu_changed.values(): + live_keys.add(edu_key) + + to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] + for edu_key in to_del: + del self.keyed_edu[edu_key] + + # Delete things out of edu map + keys = self.edus.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.edus[key] + + # Delete things out of failure map + keys = self.failures.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.failures[key] + + # Delete things out of device map + keys = self.device_messages.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.device_messages[key] + + def notify_new_events(self, current_id): + """As per TransactionQueue""" + # We don't need to replicate this as it gets sent down a different + # stream. + pass + + def send_edu(self, destination, edu_type, content, key=None): + """As per TransactionQueue""" + pos = self._next_pos() + + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + if key: + assert isinstance(key, tuple) + self.keyed_edu[(destination, key)] = edu + self.keyed_edu_changed[pos] = (destination, key) + else: + self.edus[pos] = edu + + def send_presence(self, destination, states): + """As per TransactionQueue""" + pos = self._next_pos() + + self.presence_map.update({ + state.user_id: state + for state in states + }) + + self.presence_changed[pos] = [ + (destination, state.user_id) for state in states + ] + + def send_failure(self, failure, destination): + """As per TransactionQueue""" + pos = self._next_pos() + + self.failures[pos] = (destination, str(failure)) + + def send_device_messages(self, destination): + """As per TransactionQueue""" + pos = self._next_pos() + self.device_messages[pos] = destination + + def get_current_token(self): + return self.pos - 1 + + def get_replication_rows(self, token, limit, federation_ack=None): + """ + Args: + token (int) + limit (int) + federation_ack (int): Optional. The position where the worker is + explicitly acknowledged it has handled. Allows us to drop + data from before that point + """ + # TODO: Handle limit. + + # To handle restarts where we wrap around + if token > self.pos: + token = -1 + + rows = [] + + # There should be only one reader, so lets delete everything its + # acknowledged its seen. + if federation_ack: + self._clear_queue_before_pos(federation_ack) + + # Fetch changed presence + keys = self.presence_changed.keys() + i = keys.bisect_right(token) + dest_user_ids = set( + (pos, dest_user_id) + for pos in keys[i:] + for dest_user_id in self.presence_changed[pos] + ) + + for (key, (dest, user_id)) in dest_user_ids: + rows.append((key, PRESENCE_TYPE, ujson.dumps({ + "destination": dest, + "state": self.presence_map[user_id].as_dict(), + }))) + + # Fetch changes keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_right(token) + keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + + for (pos, (destination, edu_key)) in keyed_edus: + rows.append( + (pos, KEYED_EDU_TYPE, ujson.dumps({ + "key": edu_key, + "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), + })) + ) + + # Fetch changed edus + keys = self.edus.keys() + i = keys.bisect_right(token) + edus = set((k, self.edus[k]) for k in keys[i:]) + + for (pos, edu) in edus: + rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + + # Fetch changed failures + keys = self.failures.keys() + i = keys.bisect_right(token) + failures = set((k, self.failures[k]) for k in keys[i:]) + + for (pos, (destination, failure)) in failures: + rows.append((pos, FAILURE_TYPE, ujson.dumps({ + "destination": destination, + "failure": failure, + }))) + + # Fetch changed device messages + keys = self.device_messages.keys() + i = keys.bisect_right(token) + device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + + for (pos, destination) in device_messages: + rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ + "destination": destination, + }))) + + # Sort rows based on pos + rows.sort() + + return rows diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f8ca93e4c3..51b656d74a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -19,6 +19,7 @@ from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn @@ -26,6 +27,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state import synapse.metrics @@ -36,6 +38,12 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") +sent_pdus_destination_dist = client_metrics.register_distribution( + "sent_pdu_destinations" +) +sent_edus_counter = client_metrics.register_counter("sent_edus") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -44,13 +52,14 @@ class TransactionQueue(object): It batches pending PDUs into single transactions. """ - def __init__(self, hs, transport_layer): + def __init__(self, hs): self.server_name = hs.hostname self.store = hs.get_datastore() + self.state = hs.get_state_handler() self.transaction_actions = TransactionActions(self.store) - self.transport_layer = transport_layer + self.transport_layer = hs.get_federation_transport_client() self.clock = hs.get_clock() @@ -95,6 +104,11 @@ class TransactionQueue(object): # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) + self._order = 1 + + self._is_processing = False + self._last_poked_id = -1 + def can_send_to(self, destination): """Can we send messages to the given server? @@ -115,11 +129,61 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - def enqueue_pdu(self, pdu, destinations, order): + @defer.inlineCallbacks + def notify_new_events(self, current_id): + """This gets called when we have some new events we might want to + send out to other servers. + """ + self._last_poked_id = max(current_id, self._last_poked_id) + + if self._is_processing: + return + + try: + self._is_processing = True + while True: + last_token = yield self.store.get_federation_out_pos("events") + next_token, events = yield self.store.get_all_new_events_stream( + last_token, self._last_poked_id, limit=20, + ) + + logger.debug("Handling %s -> %s", last_token, next_token) + + if not events and next_token >= self._last_poked_id: + break + + for event in events: + users_in_room = yield self.state.get_current_user_in_room( + event.room_id, latest_event_ids=[event.event_id], + ) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(event.state_key)) + + logger.debug("Sending %s to %r", event, destinations) + + self._send_pdu(event, destinations) + + yield self.store.update_federation_out_pos( + "events", next_token + ) + + finally: + self._is_processing = False + + def _send_pdu(self, pdu, destinations): # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. + order = self._order + self._order += 1 + destinations = set(destinations) destinations = set( dest for dest in destinations if self.can_send_to(dest) @@ -130,6 +194,8 @@ class TransactionQueue(object): if not destinations: return + sent_pdus_destination_dist.inc_by(len(destinations)) + for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( (pdu, order) @@ -139,7 +205,10 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_presence(self, destination, states): + def send_presence(self, destination, states): + if not self.can_send_to(destination): + return + self.pending_presence_by_dest.setdefault(destination, {}).update({ state.user_id: state for state in states }) @@ -148,12 +217,19 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu, key=None): - destination = edu.destination + def send_edu(self, destination, edu_type, content, key=None): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) if not self.can_send_to(destination): return + sent_edus_counter.inc() + if key: self.pending_edus_keyed_by_dest.setdefault( destination, {} @@ -165,7 +241,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_failure(self, failure, destination): + def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return @@ -180,7 +256,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_device_messages(self, destination): + def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -191,6 +267,9 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) + def get_current_token(self): + return 0 + @defer.inlineCallbacks def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) @@ -383,6 +462,13 @@ class TransactionQueue(object): code = e.code response = e.response + if e.code == 429 or 500 <= e.code: + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + raise e + logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 63d05f2531..5ad408f549 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler from .identity import IdentityHandler -from .receipts import ReceiptsHandler from .search import SearchHandler @@ -56,7 +55,6 @@ class Handlers(object): self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) - self.receipts_handler = ReceiptsHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3851b35889..3b146f09d6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,8 @@ class AuthHandler(BaseHandler): for module, config in hs.config.password_providers ] + logger.info("Extra password_providers: %r", self.password_providers) + self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() @@ -160,7 +162,15 @@ class AuthHandler(BaseHandler): for f in flows: if len(set(f) - set(creds.keys())) == 0: - logger.info("Auth completed with creds: %r", creds) + # it's very useful to know what args are stored, but this can + # include the password in the case of registering, so only log + # the keys (confusingly, clientdict may contain a password + # param, creds is just what the user authed as for UI auth + # and is not sensitive). + logger.info( + "Auth completed with creds: %r. Client dict has keys: %r", + creds, clientdict.keys() + ) defer.returnValue((True, creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) @@ -378,12 +388,10 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ - Gets login tuple for the user with the given user ID. - - Creates a new access/refresh token for the user. + Creates a new access token for the user with the given user ID. The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. @@ -398,16 +406,13 @@ class AuthHandler(BaseHandler): initial_display_name (str): display name to associate with the device if it needs re-registering Returns: - A tuple of: The access token for the user's session. - The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - refresh_token = yield self.issue_refresh_token(user_id, device_id) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -418,7 +423,7 @@ class AuthHandler(BaseHandler): user_id, device_id, initial_display_name ) - defer.returnValue((access_token, refresh_token)) + defer.returnValue(access_token) @defer.inlineCallbacks def check_user_exists(self, user_id): @@ -529,35 +534,19 @@ class AuthHandler(BaseHandler): device_id) defer.returnValue(access_token) - @defer.inlineCallbacks - def issue_refresh_token(self, user_id, device_id=None): - refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token, - device_id) - defer.returnValue(refresh_token) - - def generate_access_token(self, user_id, extra_caveats=None, - duration_in_ms=(60 * 60 * 1000)): + def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_refresh_token(self, user_id): - m = self._generate_base_macaroon(user_id) - m.add_first_party_caveat("type = refresh") - # Important to add a nonce, because otherwise every refresh token for a - # user will be the same. - m.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - return m.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index c5368e5df2..f7fad15c62 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -34,9 +34,9 @@ class DeviceMessageHandler(object): self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler( + hs.get_replication_layer().register_edu_handler( "m.direct_to_device", self.on_direct_to_device_edu ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fd11935b40..b63a660c06 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -111,6 +111,11 @@ class E2eKeysHandler(object): failures[destination] = { "status": 503, "message": "Not ready for retry", } + except Exception as e: + # include ConnectionRefused and other errors + failures[destination] = { + "status": 503, "message": e.message + } yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) @@ -222,6 +227,11 @@ class E2eKeysHandler(object): failures[destination] = { "status": 503, "message": "Not ready for retry", } + except Exception as e: + # include ConnectionRefused and other errors + failures[destination] = { + "status": 503, "message": e.message + } yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(claim_client_keys)(destination) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2d801bad47..771ab3bc43 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -80,22 +80,6 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up self.room_queues = {} - def handle_new_event(self, event, destinations): - """ Takes in an event from the client to server side, that has already - been authed and handled by the state module, and sends it to any - remote home servers that may be interested. - - Args: - event: The event to send - destinations: A list of destinations to send it to - - Returns: - Deferred: Resolved when it has successfully been queued for - processing. - """ - - return self.replication_layer.send_pdu(event, destinations) - @log_function @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): @@ -830,25 +814,6 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - new_pdu = event - - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ) - - destinations.discard(origin) - - logger.debug( - "on_send_join_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids @@ -1055,24 +1020,6 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - new_pdu = event - - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ) - destinations.discard(origin) - - logger.debug( - "on_send_leave_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - defer.returnValue(None) @defer.inlineCallbacks diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbfa5a0281..e0ade4c164 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( + receipts = yield self.store.get_linearized_receipts_for_room( room_id, - now_token.receipt_key + to_key=now_token.receipt_key, ) + if not receipts: + receipts = [] defer.returnValue(receipts) presence, receipts, (messages, token) = yield defer.gatherResults( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4d0515ddfb..7a57a69bd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.push.action_generator import ActionGenerator from synapse.types import ( - UserID, RoomAlias, RoomStreamToken, get_domain_from_id + UserID, RoomAlias, RoomStreamToken, ) from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.logcontext import preserve_fn @@ -605,13 +605,6 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = [ - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ] - @defer.inlineCallbacks def _notify(): yield run_on_reactor() @@ -624,7 +617,3 @@ class MessageHandler(BaseHandler): # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) - - preserve_fn(federation_handler.handle_new_event)( - event, destinations=destinations, - ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b047ae2250..1b89dc6274 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -91,28 +91,29 @@ class PresenceHandler(object): self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.federation = hs.get_replication_layer() + self.replication = hs.get_replication_layer() + self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence", self.incoming_presence ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e536a909d0..50aa513935 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_replication_layer() - self.federation.register_edu_handler( + self.federation = hs.get_federation_sender() + hs.get_replication_layer().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() @@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler): if not res: # res will be None if this read receipt is 'old' - defer.returnValue(False) + continue stream_id, max_persisted_id = res @@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler): if max_batch_id is None or max_persisted_id > max_batch_id: max_batch_id = max_persisted_id + if min_batch_id is None: + # no new receipts + defer.returnValue(False) + affected_room_ids = list(set([r["room_id"] for r in receipts])) with PreserveLoggingContext(): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7e119f13b1..286f0cef0a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler): "User ID already taken.", errcode=Codes.USER_IN_USE, ) - user_data = yield self.auth.get_user_from_macaroon(guest_access_token) + user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, @@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token( - user_id, None, duration_in_ms) + token = self.auth_handler().generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 59e4d1cd15..5f18007e90 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler): "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": False, + "guest_can_join": True, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": True, + "guest_can_join": True, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, "history_visibility": "shared", "original_invitees_have_ops": False, + "guest_can_join": False, }, } @@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler): content={"history_visibility": config["history_visibility"]} ) + if config["guest_can_join"]: + if (EventTypes.GuestAccess, '') not in initial_state: + yield send( + etype=EventTypes.GuestAccess, + content={"guest_access": "can_join"} + ) + for (etype, state_key), content in initial_state.items(): yield send( etype=etype, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f910ff814..b62773dcbe 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -277,6 +277,7 @@ class SyncHandler(object): """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() + block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -293,7 +294,7 @@ class SyncHandler(object): else: recents = [] - if not limited: + if not limited or block_all_timeline: defer.returnValue(TimelineBatch( events=recents, prev_batch=now_token, @@ -509,6 +510,7 @@ class SyncHandler(object): Returns: Deferred(SyncResult) """ + logger.info("Calculating sync response for %r", sync_config.user) # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability @@ -531,9 +533,14 @@ class SyncHandler(object): ) newly_joined_rooms, newly_joined_users = res - yield self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_users + block_all_presence_data = ( + since_token is None and + sync_config.filter_collection.blocks_all_presence() ) + if not block_all_presence_data: + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) yield self._generate_sync_entry_for_to_device(sync_result_builder) @@ -709,13 +716,20 @@ class SyncHandler(object): `(newly_joined_rooms, newly_joined_users)` """ user_id = sync_result_builder.sync_config.user.to_string() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, - now_token=sync_result_builder.now_token, - since_token=sync_result_builder.since_token, + block_all_room_ephemeral = ( + sync_result_builder.since_token is None and + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) - sync_result_builder.now_token = now_token + + if block_all_room_ephemeral: + ephemeral_by_room = {} + else: + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 27ee715ff0..0eea7f8f9c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -55,9 +55,9 @@ class TypingHandler(object): self.clock = hs.get_clock() self.wheel_timer = WheelTimer(bucket_size=5000) - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler("m.typing", self._recv_edu) + hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index d0556ae347..d5970c05a8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from signedjson.sign import sign_json +import cgi import simplejson as json import logging import random @@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) defer.returnValue(json.loads(body)) @@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -525,3 +511,29 @@ def _flatten_response_never_received(e): ) else: return "%s: %s" % (type(e).__name__, e.message,) + + +def check_content_type_is_json(headers): + """ + Check that a set of HTTP headers have a Content-Type header, and that it + is application/json. + + Args: + headers (twisted.web.http_headers.Headers): headers to check + + Raises: + RuntimeError if the + + """ + c_type = headers.getRawHeaders("Content-Type") + if c_type is None: + raise RuntimeError( + "No Content-Type header" + ) + + c_type = c_type[0] # only the first header + val, options = cgi.parse_header(c_type) + if val != "application/json": + raise RuntimeError( + "Content-Type not application/json: was '%s'" % c_type + ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 48653ae843..acbd4bb5ae 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.util import DeferredTimedOutError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred from synapse.util.logcontext import PreserveLoggingContext, preserve_fn @@ -143,6 +144,12 @@ class Notifier(object): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + else: + self.federation_sender = None + self.state_handler = hs.get_state_handler() self.clock.looping_call( @@ -220,6 +227,9 @@ class Notifier(object): # poke any interested application service. self.appservice_handler.notify_interested_services(room_stream_id) + if self.federation_sender: + self.federation_sender.notify_new_events(room_stream_id) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -285,14 +295,7 @@ class Notifier(object): result = None if timeout: - # Will be set to a _NotificationListener that we'll be waiting on. - # Allows us to cancel it. - listener = None - - def timed_out(): - if listener: - listener.deferred.cancel() - timer = self.clock.call_later(timeout / 1000., timed_out) + end_time = self.clock.time_msec() + timeout prev_token = from_token while not result: @@ -303,6 +306,10 @@ class Notifier(object): if result: break + now = self.clock.time_msec() + if end_time <= now: + break + # Now we wait for the _NotifierUserStream to be told there # is a new token. # We need to supply the token we supplied to callback so @@ -310,11 +317,14 @@ class Notifier(object): prev_token = current_token listener = user_stream.new_listener(prev_token) with PreserveLoggingContext(): - yield listener.deferred + yield self.clock.time_bound_deferred( + listener.deferred, + time_out=(end_time - now) / 1000. + ) + except DeferredTimedOutError: + break except defer.CancelledError: break - - self.clock.cancel_call_later(timer, ignore_errs=True) else: current_token = user_stream.current_token result = yield callback(from_token, current_token) @@ -483,22 +493,27 @@ class Notifier(object): """ listener = _NotificationListener(None) - def timed_out(): - listener.deferred.cancel() + end_time = self.clock.time_msec() + timeout - timer = self.clock.call_later(timeout / 1000., timed_out) while True: listener.deferred = self.replication_deferred.observe() result = yield callback() if result: break + now = self.clock.time_msec() + if end_time <= now: + break + try: with PreserveLoggingContext(): - yield listener.deferred + yield self.clock.time_bound_deferred( + listener.deferred, + time_out=(end_time - now) / 1000. + ) + except DeferredTimedOutError: + break except defer.CancelledError: break - self.clock.cancel_call_later(timer, ignore_errs=True) - defer.returnValue(result) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index be55598c43..78b095c903 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,12 +87,12 @@ class BulkPushRuleEvaluator: condition_cache = {} for uid, rules in self.rules_by_user.items(): - display_name = None - member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) - if member_ev_id: - member_ev = yield self.store.get_event(member_ev_id, allow_none=True) - if member_ev: - display_name = member_ev.content.get("displayname", None) + display_name = room_members.get(uid, {}).get("display_name", None) + if not display_name: + # Handle the case where we are pushing a membership event to + # that user, as they might not be already joined. + if event.type == EventTypes.Member and event.state_key == uid: + display_name = event.content.get("displayname", None) filtered = filtered_by_user[uid] if len(filtered) == 0: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index b9e41770ee..3742a25b37 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = { "Jinja2>=2.8": ["Jinja2>=2.8"], "bleach>=1.4.2": ["bleach>=1.4.2"], }, - "ldap": { - "ldap3>=1.0": ["ldap3>=1.0"], + "matrix-synapse-ldap3": { + "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], }, "psutil": { "psutil>=2.0.0": ["psutil>=2.0.0"], @@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False): def github_link(project, version, egg): return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) + DEPENDENCY_LINKS = { } @@ -156,6 +157,7 @@ def list_requirements(): result.append(requirement) return result + if __name__ == "__main__": import sys sys.stdout.writelines(req + "\n" for req in list_requirements()) diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py new file mode 100644 index 0000000000..c05a50d7a6 --- /dev/null +++ b/synapse/replication/expire_cache.py @@ -0,0 +1,60 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class ExpireCacheResource(Resource): + """ + HTTP endpoint for expiring storage caches. + + POST /_synapse/replication/expire_cache HTTP/1.1 + Content-Type: application/json + + { + "invalidate": [ + { + "name": "func_name", + "keys": ["key1", "key2"] + } + ] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.store = hs.get_datastore() + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + for row in content["invalidate"]: + name = row["name"] + keys = tuple(row["keys"]) + + getattr(self.store, name).invalidate(keys) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 5a14c51d23..d79b421cba 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request from synapse.replication.pusher_resource import PusherResource from synapse.replication.presence_resource import PresenceResource +from synapse.replication.expire_cache import ExpireCacheResource from synapse.api.errors import SynapseError from twisted.web.resource import Resource @@ -44,6 +45,7 @@ STREAM_NAMES = ( ("caches",), ("to_device",), ("public_rooms",), + ("federation",), ) @@ -116,11 +118,14 @@ class ReplicationResource(Resource): self.sources = hs.get_event_sources() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() + self.federation_sender = hs.get_federation_sender() self.notifier = hs.notifier self.clock = hs.get_clock() + self.config = hs.get_config() self.putChild("remove_pushers", PusherResource(hs)) self.putChild("syncing_users", PresenceResource(hs)) + self.putChild("expire_cache", ExpireCacheResource(hs)) def render_GET(self, request): self._async_render_GET(request) @@ -134,6 +139,7 @@ class ReplicationResource(Resource): pushers_token = self.store.get_pushers_stream_token() caches_token = self.store.get_cache_stream_token() public_rooms_token = self.store.get_current_public_room_stream_id() + federation_token = self.federation_sender.get_current_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -148,6 +154,7 @@ class ReplicationResource(Resource): caches_token, int(stream_token.to_device_key), int(public_rooms_token), + int(federation_token), )) @request_handler() @@ -164,8 +171,13 @@ class ReplicationResource(Resource): } request_streams["streams"] = parse_string(request, "streams") + federation_ack = parse_integer(request, "federation_ack", None) + def replicate(): - return self.replicate(request_streams, limit) + return self.replicate( + request_streams, limit, + federation_ack=federation_ack + ) writer = yield self.notifier.wait_for_replication(replicate, timeout) result = writer.finish() @@ -183,7 +195,7 @@ class ReplicationResource(Resource): finish_request(request) @defer.inlineCallbacks - def replicate(self, request_streams, limit): + def replicate(self, request_streams, limit, federation_ack=None): writer = _Writer() current_token = yield self.current_replication_token() logger.debug("Replicating up to %r", current_token) @@ -202,6 +214,7 @@ class ReplicationResource(Resource): yield self.caches(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams) + self.federation(writer, current_token, limit, request_streams, federation_ack) self.streams(writer, current_token, request_streams) logger.debug("Replicated %d rows", writer.total) @@ -465,6 +478,23 @@ class ReplicationResource(Resource): "position", "room_id", "visibility" ), position=upto_token) + def federation(self, writer, current_token, limit, request_streams, federation_ack): + if self.config.send_federation: + return + + current_position = current_token.federation + + federation = request_streams.get("federation") + + if federation is not None and federation != current_position: + federation_rows = self.federation_sender.get_replication_rows( + federation, limit, federation_ack=federation_ack, + ) + upto_token = _position_from_rows(federation_rows, current_position) + writer.write_header_and_rows("federation", federation_rows, ( + "position", "type", "content", + ), position=upto_token) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -497,6 +527,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", + "federation", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f19540d6bb..18076e0f3b 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore): else: self._cache_id_gen = None + self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" + self.http_client = hs.get_simple_http_client() + def stream_positions(self): pos = {} if self._cache_id_gen: @@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore): logger.info("Got unexpected cache_func: %r", cache_func) self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + txn.call_after(cache_func.invalidate, keys) + txn.call_after(self._send_invalidation_poke, cache_func, keys) + + @defer.inlineCallbacks + def _send_invalidation_poke(self, cache_func, keys): + try: + yield self.http_client.post_json_get_json(self.expire_cache_url, { + "invalidate": [{ + "name": cache_func.__name__, + "keys": list(keys), + }] + }) + except: + logger.exception("Failed to poke on expire_cache") diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 3bfd5e8213..cc860f9f9b 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore): "DeviceInboxStreamChangeCache", self._device_inbox_id_gen.get_current_token() ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + self._device_inbox_id_gen.get_current_token() + ) get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ + get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ delete_messages_for_device = DataStore.delete_messages_for_device.__func__ + delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__ def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() @@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore): self._device_inbox_id_gen.advance(int(stream["position"])) for row in stream["rows"]: stream_id = row[0] - user_id = row[1] - self._device_inbox_stream_cache.entity_has_changed( - user_id, stream_id - ) + entity = row[1] + + if entity.startswith("@"): + self._device_inbox_stream_cache.entity_has_changed( + entity, stream_id + ) + else: + self._device_federation_outbox_stream_cache.entity_has_changed( + entity, stream_id + ) return super(SlavedDeviceInboxStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 0c26e96e98..64f18bbb3e 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore from synapse.util.caches.stream_change_cache import StreamChangeCache import ujson as json +import logging + + +logger = logging.getLogger(__name__) + # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore): EventFederationStore.__dict__["_get_forward_extremeties_for_room"] ) + get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ + + get_federation_out_pos = DataStore.get_federation_out_pos.__func__ + update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() @@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore): stream = result.get("events") if stream: self._stream_id_gen.advance(int(stream["position"])) + + if stream["rows"]: + logger.info("Got %d event rows", len(stream["rows"])) + for row in stream["rows"]: self._process_replication_row( row, backfilled=False, state_resets=state_resets diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 6f2ba98af5..fbb58f35da 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.transactions import TransactionStore @@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore class TransactionStore(BaseSlavedStore): get_destination_retry_timings = TransactionStore.__dict__[ "get_destination_retry_timings" - ].orig + ] _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ + set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__ + _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__ - # For now, don't record the destination rety timings - def set_destination_retry_timings(*args, **kwargs): - return defer.succeed(None) + prep_send_transaction = DataStore.prep_send_transaction.__func__ + delivered_txn = DataStore.delivered_txn.__func__ diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py new file mode 100644 index 0000000000..351170edbc --- /dev/null +++ b/synapse/rest/client/transactions.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains logic for storing HTTP PUT transactions. This is used +to ensure idempotency when performing PUTs using the REST API.""" +import logging + +from synapse.api.auth import get_access_token_from_request +from synapse.util.async import ObservableDeferred + +logger = logging.getLogger(__name__) + + +def get_transaction_key(request): + """A helper function which returns a transaction key that can be used + with TransactionCache for idempotent requests. + + Idempotency is based on the returned key being the same for separate + requests to the same endpoint. The key is formed from the HTTP request + path and the access_token for the requesting user. + + Args: + request (twisted.web.http.Request): The incoming request. Must + contain an access_token. + Returns: + str: A transaction key + """ + token = get_access_token_from_request(request) + return request.path + "/" + token + + +CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins + + +class HttpTransactionCache(object): + + def __init__(self, clock): + self.clock = clock + self.transactions = { + # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) + } + # Try to clean entries every 30 mins. This means entries will exist + # for at *LEAST* 30 mins, and at *MOST* 60 mins. + self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + + def fetch_or_execute_request(self, request, fn, *args, **kwargs): + """A helper function for fetch_or_execute which extracts + a transaction key from the given request. + + See: + fetch_or_execute + """ + return self.fetch_or_execute( + get_transaction_key(request), fn, *args, **kwargs + ) + + def fetch_or_execute(self, txn_key, fn, *args, **kwargs): + """Fetches the response for this transaction, or executes the given function + to produce a response for this transaction. + + Args: + txn_key (str): A key to ensure idempotency should fetch_or_execute be + called again at a later point in time. + fn (function): A function which returns a tuple of + (response_code, response_dict). + *args: Arguments to pass to fn. + **kwargs: Keyword arguments to pass to fn. + Returns: + Deferred which resolves to a tuple of (response_code, response_dict). + """ + try: + return self.transactions[txn_key][0].observe() + except (KeyError, IndexError): + pass # execute the function instead. + + deferred = fn(*args, **kwargs) + observable = ObservableDeferred(deferred) + self.transactions[txn_key] = (observable, self.clock.time_msec()) + return observable.observe() + + def _cleanup(self): + now = self.clock.time_msec() + for key in self.transactions.keys(): + ts = self.transactions[key][1] + if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period + del self.transactions[key] diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index c2a8447860..c7aa0bbf59 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -18,7 +18,8 @@ from synapse.http.servlet import RestServlet from synapse.api.urls import CLIENT_PREFIX -from .transactions import HttpTransactionStore +from synapse.rest.client.transactions import HttpTransactionCache + import re import logging @@ -59,4 +60,4 @@ class ClientV1RestServlet(RestServlet): self.hs = hs self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() - self.txns = HttpTransactionStore() + self.txns = HttpTransactionCache(hs.get_clock()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 345018a8fc..093bc072f4 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet): password=login_submission["password"], ) device_id = yield self._register_device(user_id, login_submission) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, "device_id": device_id, } @@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet): yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) device_id = yield self._register_device(user_id, login_submission) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, "device_id": device_id, } @@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device( registered_user_id, login_submission ) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + registered_user_id, device_id, + login_submission.get("initial_device_display_name"), ) + result = { "user_id": registered_user_id, "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, } else: diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index b5a76fefac..ecf7e311a9 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet): def __init__(self, hs): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() - self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet): if "displayname" not in user_json: raise SynapseError(400, "Expected 'displayname' key.") - if "duration_seconds" not in user_json: - raise SynapseError(400, "Expected 'duration_seconds' key.") - localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") - duration_seconds = 0 - try: - duration_seconds = int(user_json["duration_seconds"]) - except ValueError: - raise SynapseError(400, "Failed to parse 'duration_seconds'") - if duration_seconds > self.direct_user_creation_max_duration: - duration_seconds = self.direct_user_creation_max_duration password_hash = user_json["password_hash"].encode("utf-8") \ if user_json.get("password_hash") else None @@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet): requester=requester, localpart=localpart, displayname=displayname, - duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 010fbc7c32..a0bba1fa3b 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -53,19 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): client_path_patterns("/createRoom(?:/.*)?$"), self.on_OPTIONS) - @defer.inlineCallbacks def on_PUT(self, request, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request + ) @defer.inlineCallbacks def on_POST(self, request): @@ -214,19 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id, event_type, txn_id): return (200, "Not implemented") - @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_id, event_type, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_type, txn_id + ) # TODO: Needs unit testing for room ID + alias joins @@ -283,19 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): defer.returnValue((200, {"room_id": room_id})) - @defer.inlineCallbacks def on_PUT(self, request, room_identifier, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_identifier, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) # TODO: Needs unit testing @@ -396,6 +369,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +class JoinedRoomMemberListRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") + + def __init__(self, hs): + super(JoinedRoomMemberListRestServlet, self).__init__(hs) + self.state = hs.get_state_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + yield self.auth.get_user_by_req(request) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + defer.returnValue((200, { + "joined": users_with_profile + })) + + # TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") @@ -537,22 +528,11 @@ class RoomForgetRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) - @defer.inlineCallbacks def on_PUT(self, request, room_id, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST( - request, room_id, txn_id + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, txn_id ) - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) - # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -623,22 +603,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet): return False return True - @defer.inlineCallbacks def on_PUT(self, request, room_id, membership_action, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST( - request, room_id, membership_action, txn_id + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, membership_action, txn_id ) - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) - class RoomRedactEventRestServlet(ClientV1RestServlet): def __init__(self, hs): @@ -669,19 +638,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): defer.returnValue((200, {"event_id": event.event_id})) - @defer.inlineCallbacks def on_PUT(self, request, room_id, event_id, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass - - response = yield self.on_POST(request, room_id, event_id, txn_id) - - self.txns.store_client_transaction(request, txn_id, response) - defer.returnValue(response) + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_id, txn_id + ) class RoomTypingRestServlet(ClientV1RestServlet): @@ -750,6 +710,22 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) +class JoinedRoomsRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/joined_rooms$") + + def __init__(self, hs): + super(JoinedRoomsRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + rooms = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = set(r.room_id for r in rooms) # Ensure they're unique. + defer.returnValue((200, {"joined_rooms": list(room_ids)})) + + def register_txn_path(servlet, regex_string, http_server, with_get=False): """Registers a transaction-based path. @@ -785,6 +761,7 @@ def register_servlets(hs, http_server): RoomStateEventRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) + JoinedRoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server) @@ -796,4 +773,5 @@ def register_servlets(hs, http_server): RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) RoomEventContext(hs).register(http_server) diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py deleted file mode 100644 index 2f2c9d0881..0000000000 --- a/synapse/rest/client/v1/transactions.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains logic for storing HTTP PUT transactions. This is used -to ensure idempotency when performing PUTs using the REST API.""" -import logging - -from synapse.api.auth import get_access_token_from_request - -logger = logging.getLogger(__name__) - - -# FIXME: elsewhere we use FooStore to indicate something in the storage layer... -class HttpTransactionStore(object): - - def __init__(self): - # { key : (txn_id, response) } - self.transactions = {} - - def get_response(self, key, txn_id): - """Retrieve a response for this request. - - Args: - key (str): A transaction-independent key for this request. Usually - this is a combination of the path (without the transaction id) - and the user's access token. - txn_id (str): The transaction ID for this request - Returns: - A tuple of (HTTP response code, response content) or None. - """ - try: - logger.debug("get_response TxnId: %s", txn_id) - (last_txn_id, response) = self.transactions[key] - if txn_id == last_txn_id: - logger.info("get_response: Returning a response for %s", txn_id) - return response - except KeyError: - pass - return None - - def store_response(self, key, txn_id, response): - """Stores an HTTP response tuple. - - Args: - key (str): A transaction-independent key for this request. Usually - this is a combination of the path (without the transaction id) - and the user's access token. - txn_id (str): The transaction ID for this request. - response (tuple): A tuple of (HTTP response code, response content) - """ - logger.debug("store_response TxnId: %s", txn_id) - self.transactions[key] = (txn_id, response) - - def store_client_transaction(self, request, txn_id, response): - """Stores the request/response pair of an HTTP transaction. - - Args: - request (twisted.web.http.Request): The twisted HTTP request. This - request must have the transaction ID as the last path segment. - response (tuple): A tuple of (response code, response dict) - txn_id (str): The transaction ID for this request. - """ - self.store_response(self._get_key(request), txn_id, response) - - def get_client_transaction(self, request, txn_id): - """Retrieves a stored response if there was one. - - Args: - request (twisted.web.http.Request): The twisted HTTP request. This - request must have the transaction ID as the last path segment. - txn_id (str): The transaction ID for this request. - Returns: - The response tuple. - Raises: - KeyError if the transaction was not found. - """ - response = self.get_response(self._get_key(request), txn_id) - if response is None: - raise KeyError("Transaction not found.") - return response - - def _get_key(self, request): - token = get_access_token_from_request(request) - path_without_txn_id = request.path.rsplit("/", 1)[0] - return path_without_txn_id + "/" + token diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 3ba0b0fc07..a1feaf3d54 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) devices = yield self.device_handler.get_devices_by_user( requester.user.to_string() ) @@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( requester.user.to_string(), device_id, @@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) body = servlet.parse_json_object_from_request(request) yield self.device_handler.update_device( diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f185f9a774..08b7c99d57 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -150,7 +150,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.query_devices(body, timeout) @@ -158,7 +158,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id @@ -204,7 +204,7 @@ class OneTimeKeyServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) result = yield self.e2e_keys_handler.claim_one_time_keys( {"one_time_keys": {user_id: {device_id: algorithm}}}, @@ -214,7 +214,7 @@ class OneTimeKeyServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.claim_one_time_keys( diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 891cef99c6..1fbff2edd8 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet): super(ReceiptRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.receipts_handler = hs.get_handlers().receipts_handler + self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 68d18a9b82..3e7a285e10 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError @@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + body = parse_json_object_from_request(request) + kind = "user" if "kind" in request.args: kind = request.args["kind"][0] if kind == "guest": - ret = yield self._do_guest_registration() + ret = yield self._do_guest_registration(body) defer.returnValue(ret) return elif kind != "user": @@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - body = parse_json_object_from_request(request) - # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None @@ -169,6 +170,17 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) + if ( + 'initial_device_display_name' in body and + 'password' not in body + ): + # ignore 'initial_device_display_name' if sent without + # a password to work around a client bug where it sent + # the 'initial_device_display_name' param alone, wiping out + # the original registration params + logger.warn("Ignoring initial_device_display_name without password") + del body['initial_device_display_name'] + session_id = self.auth_handler.get_session_id(body) registered_user_id = None if session_id: @@ -362,8 +374,7 @@ class RegisterRestServlet(RestServlet): def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user - Allocates device_id if one was not given; also creates access_token - and refresh_token. + Allocates device_id if one was not given; also creates access_token. Args: (str) user_id: full canonical @user:id @@ -374,8 +385,8 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token, refresh_token = ( - yield self.auth_handler.get_login_tuple_for_user_id( + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( user_id, device_id=device_id, initial_display_name=params.get("initial_device_display_name") ) @@ -385,7 +396,6 @@ class RegisterRestServlet(RestServlet): "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, - "refresh_token": refresh_token, "device_id": device_id, }) @@ -410,20 +420,28 @@ class RegisterRestServlet(RestServlet): ) @defer.inlineCallbacks - def _do_guest_registration(self): + def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: defer.returnValue((403, "Guest access is disabled")) user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True ) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + access_token = self.auth_handler.generate_access_token( user_id, ["guest = true"] ) - # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok - # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, + "device_id": device_id, "access_token": access_token, "home_server": self.hs.hostname, })) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 5975164b37..d607bd2970 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request -from synapse.rest.client.v1.transactions import HttpTransactionStore +from synapse.rest.client.transactions import HttpTransactionCache from ._base import client_v2_patterns @@ -40,19 +40,17 @@ class SendToDeviceRestServlet(servlet.RestServlet): super(SendToDeviceRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.txns = HttpTransactionStore() + self.txns = HttpTransactionCache(hs.get_clock()) self.device_message_handler = hs.get_device_message_handler() - @defer.inlineCallbacks def on_PUT(self, request, message_type, txn_id): - try: - defer.returnValue( - self.txns.get_client_transaction(request, txn_id) - ) - except KeyError: - pass + return self.txns.fetch_or_execute_request( + request, self._put, request, message_type, txn_id + ) - requester = yield self.auth.get_user_by_req(request) + @defer.inlineCallbacks + def _put(self, request, message_type, txn_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -63,7 +61,6 @@ class SendToDeviceRestServlet(servlet.RestServlet): ) response = (200, {}) - self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 6fc63715aa..7199ec883a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id, filter.event_fields ) invited = self.encode_invited( @@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id, filter.event_fields ) response_content = { @@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": formatted} - def encode_joined(self, rooms, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id + room, time_now, token_id, only_fields=event_fields ) return joined @@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id, joined=False + room, time_now, token_id, joined=False, only_fields=event_fields ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True, only_fields=None): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet): of transaction IDs joined (bool): True if the user is joined to this room - will mean we handle ephemeral events - + only_fields(list<str>): Optional. The list of event fields to include. Returns: dict[str, object]: the room, encoded in our response format """ @@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet): return serialize_event( event, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + only_event_fields=only_fields, ) state_dict = room.state diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 0d312c91d4..6e76b9e9c2 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - self.hs = hs - self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_object_from_request(request) - try: - old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_auth_handler() - refresh_result = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token - ) - (user_id, new_refresh_token, device_id) = refresh_result - new_access_token = yield auth_handler.issue_access_token( - user_id, device_id - ) - defer.returnValue((200, { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - })) - except KeyError: - raise SynapseError(400, "Missing required key 'refresh_token'.") - except StoreError: - raise AuthError(403, "Did not recognize refresh token") + raise AuthError(403, "tokenrefresh is no longer supported.") def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 33f35fb44e..6a5a57102f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -543,5 +543,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # We always add an ellipsis because at the very least # we chopped mid paragraph. - description = new_desc.strip() + "…" + description = new_desc.strip() + u"…" return description if description else None diff --git a/synapse/server.py b/synapse/server.py index 374124a147..0bfb411269 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.federation import initialize_http_replication +from synapse.federation.send_queue import FederationRemoteSendQueue +from synapse.federation.transport.client import TransportLayerClient +from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler @@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler +from synapse.handlers.receipts import ReceiptsHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -124,6 +128,9 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'federation_transport_client', + 'federation_sender', + 'receipts_handler', ] def __init__(self, hostname, **kwargs): @@ -265,9 +272,30 @@ class HomeServer(object): def build_media_repository(self): return MediaRepository(self) + def build_federation_transport_client(self): + return TransportLayerClient(self) + + def build_federation_sender(self): + if self.should_send_federation(): + return TransactionQueue(self) + elif not self.config.worker_app: + return FederationRemoteSendQueue(self) + else: + raise Exception("Workers cannot send federation traffic") + + def build_receipts_handler(self): + return ReceiptsHandler(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) + def should_send_federation(self): + "Should this server be sending federation traffic directly?" + return self.config.send_federation and ( + not self.config.worker_app + or self.config.worker_app == "synapse.app.federation_sender" + ) + def _make_dependency_method(depname): def _get(hs): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9996f195a0..fe936b3e62 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") @@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore, ) self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() super(DataStore, self).__init__(hs) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d828d6ee1d..b62c459d8b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -561,12 +561,17 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + sql = ( - "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" + "SELECT %(retcol)s FROM %(table)s %(where)s" ) % { "retcol": retcol, "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + "where": where, } txn.execute(sql, keyvalues.values()) @@ -744,10 +749,15 @@ class SQLBaseStore(object): @staticmethod def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - update_sql = "UPDATE %s SET %s WHERE %s" % ( + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( table, ", ".join("%s = ?" % (k,) for k in updatevalues), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + where, ) txn.execute( diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 3d5994a580..514570561f 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore): def get_app_services(self): return self.services_cache + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service + """ + for service in self.services_cache: + if service.is_interested_in_user(user_id): + return True + return False + def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index f640e73714..87398d60bc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -269,27 +269,29 @@ class DeviceInboxStore(SQLBaseStore): return defer.succeed([]) def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id FROM device_inbox" + "SELECT stream_id, user_id" + " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY stream_id" " ORDER BY stream_id ASC" - " LIMIT ?" ) - txn.execute(sql, (last_pos, current_pos, limit)) - stream_ids = txn.fetchall() - if not stream_ids: - return [] - max_stream_id_in_limit = stream_ids[-1] + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() sql = ( - "SELECT stream_id, user_id, device_id, message_json" - " FROM device_inbox" + "SELECT stream_id, destination" + " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" ) - txn.execute(sql, (last_pos, max_stream_id_in_limit)) - return txn.fetchall() + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn.fetchall()) + + return rows return self.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 9cd923eb93..7de3e8c58c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore): columns=["user_id", "stream_ordering"], ) + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: @@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore): topological_ordering, stream_ordering ) + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 sql = ( - "SELECT sum(notif), sum(highlight)" + "SELECT count(*)" " FROM event_push_actions ea" " WHERE" " user_id = ?" @@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id)) row = txn.fetchone() - if row: - return { - "notify_count": row[0] or 0, - "highlight_count": row[1] or 0, - } - else: - return {"notify_count": 0, "highlight_count": 0} + notify_count = row[0] if row else 0 + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 49aeb953bd..ecb79c07ef 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -54,6 +54,7 @@ def encode_json(json_object): else: return json.dumps(json_object, ensure_ascii=False) + # These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 5248736816..a2ccc66ea7 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -16,6 +16,7 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.errors import SynapseError, Codes from synapse.util.caches.descriptors import cachedInlineCallbacks import simplejson as json @@ -24,6 +25,13 @@ import simplejson as json class FilteringStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_user_filter(self, user_localpart, filter_id): + # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail + # with a coherent error message rather than 500 M_UNKNOWN. + try: + int(filter_id) + except ValueError: + raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) + def_json = yield self._simple_select_one_onecol( table="user_filters", keyvalues={ diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6576a30098..e46ae6502e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 38 +SCHEMA_VERSION = 39 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 21d0696640..7460f98a1f 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState", status_msg (str): User set status message. """ + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + def copy_and_replace(self, **kwargs): return self._replace(**kwargs) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 49721656b6..cbec255966 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore): event=event, ) - local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = set( + u for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + ) # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate, + local_users_in_room, + on_invalidate=cache_context.invalidate, ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 9747a04a9a..f72d15f5ed 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_current_token() + max_persisted_id = self._receipts_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index e404fa72de..983a8ec52b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): desc="add_access_token_to_user", ) - @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token, device_id=None): - """Adds a refresh token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new refresh token to add. - device_id (str): ID of the device to associate with the access - token - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._refresh_tokens_id_gen.get_next() - - yield self._simple_insert( - "refresh_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - }, - desc="add_refresh_token_to_user", - ) - def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): @@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): token ) - def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new one. - - Doing so invalidates the old refresh token - refresh tokens are single - use. - - Args: - refresh_token (str): The refresh token of a user. - token_generator (fn: str -> str): Function which, when given a - user ID, returns a unique refresh token for that user. This - function must never return the same value twice. - Returns: - tuple of (user_id, new_refresh_token, device_id) - Raises: - StoreError if no user was found with that refresh token. - """ - return self.runInteraction( - "exchange_refresh_token", - self._exchange_refresh_token, - refresh_token, - token_generator - ) - - def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" - txn.execute(sql, (old_token,)) - rows = self.cursor_to_dict(txn) - if not rows: - raise StoreError(403, "Did not recognize refresh token") - user_id = rows[0]["user_id"] - device_id = rows[0]["device_id"] - - # TODO(danielwh): Maybe perform a validation on the macaroon that - # macaroon.user_id == user_id. - - new_token = token_generator(user_id) - sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" - txn.execute(sql, (new_token, old_token,)) - - return user_id, new_token, device_id - @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 866d64e679..b2a45a38c1 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging +import ujson as json logger = logging.getLogger(__name__) @@ -34,7 +35,15 @@ RoomsForUser = namedtuple( ) +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" + + class RoomMemberStore(SQLBaseStore): + def __init__(self, hs): + super(RoomMemberStore, self).__init__(hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. @@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore): "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), } for event in events ] @@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore): table="room_memberships", column="event_id", iterable=member_event_ids, - retcols=['user_id'], + retcols=['user_id', 'display_name', 'avatar_url'], keyvalues={ "membership": Membership.JOIN, }, @@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore): desc="_get_joined_users_from_context", ) - users_in_room = set(row["user_id"] for row in rows) + users_in_room = { + row["user_id"]: { + "display_name": row["display_name"], + "avatar_url": row["avatar_url"], + } + for row in rows + } + if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room.add(event.state_key) + users_in_room[event.state_key] = { + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } defer.returnValue(users_in_room) @@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(True) defer.returnValue(False) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = (""" + SELECT stream_ordering, event_id, room_id, content + FROM events + INNER JOIN room_memberships USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + content = json.loads(row["content"]) + except: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append(( + display_name, avatar_url, event_id, room_id + )) + + to_update_sql = (""" + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """) + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index:index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(to_update) + + result = yield self.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + + defer.returnValue(result) diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql new file mode 100644 index 0000000000..00be801e90 --- /dev/null +++ b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/schema/delta/39/event_push_index.sql new file mode 100644 index 0000000000..de2ad93e5c --- /dev/null +++ b/synapse/storage/schema/delta/39/event_push_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/schema/delta/39/federation_out_position.sql new file mode 100644 index 0000000000..5af814290b --- /dev/null +++ b/synapse/storage/schema/delta/39/federation_out_position.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + CREATE TABLE federation_stream_position( + type TEXT NOT NULL, + stream_id INTEGER NOT NULL + ); + + INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); + INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/schema/delta/39/membership_profile.sql new file mode 100644 index 0000000000..1bf911c8ab --- /dev/null +++ b/synapse/storage/schema/delta/39/membership_profile.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_memberships ADD COLUMN display_name TEXT; +ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT; + +INSERT into background_updates (update_name, progress_json) + VALUES ('room_membership_profile_update', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49abf0ac74..23e7ad9922 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -653,7 +653,10 @@ class StateStore(SQLBaseStore): else: state_dict = results[group] - state_dict.update(group_state_dict) + state_dict.update({ + (intern_string(k[0]), intern_string(k[1])): v + for k, v in group_state_dict.items() + }) self._state_group_cache.update( cache_seq_num, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 888b1cb35d..2dc24951c4 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore): def get_room_max_stream_ordering(self): return self._stream_id_gen.get_current_token() + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() + def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore): "token": end_token, }, } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events""" + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn, + ) + + events = yield self._get_events(event_ids) + + defer.returnValue((upper_bound, events)) + + def get_federation_out_pos(self, typ): + return self._simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ}, + desc="get_federation_out_pos" + ) + + def update_federation_out_pos(self, typ, stream_id): + return self._simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index adab520c78..809fdd311f 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -200,25 +200,48 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) + self.database_engine.lock_table(txn, "destinations") - self._simple_upsert_txn( + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + + # We need to be careful here as the data may have changed from under us + # due to a worker setting the timings. + + prev_row = self._simple_select_one_txn( txn, - "destinations", + table="destinations", keyvalues={ "destination": destination, }, - values={ - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - insertion_values={ - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - } + retcols=("retry_last_ts", "retry_interval"), + allow_none=True, ) + if not prev_row: + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } + ) + elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + self._simple_update_one_txn( + txn, + "destinations", + keyvalues={ + "destination": destination, + }, + updatevalues={ + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 2b3f0bef3c..30fc480108 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -24,6 +24,11 @@ import logging logger = logging.getLogger(__name__) +class DeferredTimedOutError(SynapseError): + def __init__(self): + super(SynapseError).__init__(504, "Timed out") + + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -34,7 +39,7 @@ class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. - TODO(paul): Also move the sleep() functionallity into it + TODO(paul): Also move the sleep() functionality into it """ def time(self): @@ -46,6 +51,14 @@ class Clock(object): return int(self.time() * 1000) def looping_call(self, f, msec): + """Call a function repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f(function): The function to call repeatedly. + msec(float): How long to wait between calls in milliseconds. + """ l = task.LoopingCall(f) l.start(msec / 1000.0, now=False) return l @@ -81,7 +94,7 @@ class Clock(object): def timed_out_fn(): try: - ret_deferred.errback(SynapseError(504, "Timed out")) + ret_deferred.errback(DeferredTimedOutError()) except: pass diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 3fd5c3d9fd..d668e5a6b8 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -76,15 +76,26 @@ class JsonEncodedObject(object): d.update(self.unrecognized_keys) return d + def get_internal_dict(self): + d = { + k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + if k in self.valid_keys + } + d.update(self.unrecognized_keys) + return d + def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) -def _encode(obj): +def _encode(obj, internal=False): if type(obj) is list: - return [_encode(o) for o in obj] + return [_encode(o, internal=internal) for o in obj] if isinstance(obj, JsonEncodedObject): - return obj.get_dict() + if internal: + return obj.get_internal_dict() + else: + return obj.get_dict() return obj diff --git a/synapse/util/ldap_auth_provider.py b/synapse/util/ldap_auth_provider.py deleted file mode 100644 index 1b989248fb..0000000000 --- a/synapse/util/ldap_auth_provider.py +++ /dev/null @@ -1,369 +0,0 @@ - -from twisted.internet import defer - -from synapse.config._base import ConfigError -from synapse.types import UserID - -import ldap3 -import ldap3.core.exceptions - -import logging - -try: - import ldap3 - import ldap3.core.exceptions -except ImportError: - ldap3 = None - pass - - -logger = logging.getLogger(__name__) - - -class LDAPMode(object): - SIMPLE = "simple", - SEARCH = "search", - - LIST = (SIMPLE, SEARCH) - - -class LdapAuthProvider(object): - __version__ = "0.1" - - def __init__(self, config, account_handler): - self.account_handler = account_handler - - if not ldap3: - raise RuntimeError( - 'Missing ldap3 library. This is required for LDAP Authentication.' - ) - - self.ldap_mode = config.mode - self.ldap_uri = config.uri - self.ldap_start_tls = config.start_tls - self.ldap_base = config.base - self.ldap_attributes = config.attributes - if self.ldap_mode == LDAPMode.SEARCH: - self.ldap_bind_dn = config.bind_dn - self.ldap_bind_password = config.bind_password - self.ldap_filter = config.filter - - @defer.inlineCallbacks - def check_password(self, user_id, password): - """ Attempt to authenticate a user against an LDAP Server - and register an account if none exists. - - Returns: - True if authentication against LDAP was successful - """ - localpart = UserID.from_string(user_id).localpart - - try: - server = ldap3.Server(self.ldap_uri) - logger.debug( - "Attempting LDAP connection with %s", - self.ldap_uri - ) - - if self.ldap_mode == LDAPMode.SIMPLE: - result, conn = self._ldap_simple_bind( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP authentication method simple bind returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - elif self.ldap_mode == LDAPMode.SEARCH: - result, conn = self._ldap_authenticated_search( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP auth method authenticated search returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - else: - raise RuntimeError( - 'Invalid LDAP mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) - - try: - logger.info( - "User authenticated against LDAP server: %s", - conn - ) - except NameError: - logger.warn( - "Authentication method yielded no LDAP connection, aborting!" - ) - defer.returnValue(False) - - # check if user with user_id exists - if (yield self.account_handler.check_user_exists(user_id)): - # exists, authentication complete - conn.unbind() - defer.returnValue(True) - - else: - # does not exist, fetch metadata for account creation from - # existing ldap connection - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - - if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: - query = "(&{filter}{user_filter})".format( - filter=query, - user_filter=self.ldap_filter - ) - logger.debug( - "ldap registration filter: %s", - query - ) - - conn.search( - search_base=self.ldap_base, - search_filter=query, - attributes=[ - self.ldap_attributes['name'], - self.ldap_attributes['mail'] - ] - ) - - if len(conn.response) == 1: - attrs = conn.response[0]['attributes'] - mail = attrs[self.ldap_attributes['mail']][0] - name = attrs[self.ldap_attributes['name']][0] - - # create account - user_id, access_token = ( - yield self.account_handler.register(localpart=localpart) - ) - - # TODO: bind email, set displayname with data from ldap directory - - logger.info( - "Registration based on LDAP data was successful: %d: %s (%s, %)", - user_id, - localpart, - name, - mail - ) - - defer.returnValue(True) - else: - if len(conn.response) == 0: - logger.warn("LDAP registration failed, no result.") - else: - logger.warn( - "LDAP registration failed, too many results (%s)", - len(conn.response) - ) - - defer.returnValue(False) - - defer.returnValue(False) - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during ldap authentication: %s", e) - defer.returnValue(False) - - @staticmethod - def parse_config(config): - class _LdapConfig(object): - pass - - ldap_config = _LdapConfig() - - ldap_config.enabled = config.get("enabled", False) - - ldap_config.mode = LDAPMode.SIMPLE - - # verify config sanity - _require_keys(config, [ - "uri", - "base", - "attributes", - ]) - - ldap_config.uri = config["uri"] - ldap_config.start_tls = config.get("start_tls", False) - ldap_config.base = config["base"] - ldap_config.attributes = config["attributes"] - - if "bind_dn" in config: - ldap_config.mode = LDAPMode.SEARCH - _require_keys(config, [ - "bind_dn", - "bind_password", - ]) - - ldap_config.bind_dn = config["bind_dn"] - ldap_config.bind_password = config["bind_password"] - ldap_config.filter = config.get("filter", None) - - # verify attribute lookup - _require_keys(config['attributes'], [ - "uid", - "name", - "mail", - ]) - - return ldap_config - - def _ldap_simple_bind(self, server, localpart, password): - """ Attempt a simple bind with the credentials - given by the user against the LDAP server. - - Returns True, LDAP3Connection - if the bind was successful - Returns False, None - if an error occured - """ - - try: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base - ) - conn = ldap3.Connection(server, bind_dn, password, - authentication=ldap3.AUTH_SIMPLE) - logger.debug( - "Established LDAP connection in simple bind mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in simple bind mode through StartTLS: %s", - conn - ) - - if conn.bind(): - # GOOD: bind okay - logger.debug("LDAP Bind successful in simple bind mode.") - return True, conn - - # BAD: bind failed - logger.info( - "Binding against LDAP failed for '%s' failed: %s", - localpart, conn.result['description'] - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - def _ldap_authenticated_search(self, server, localpart, password): - """ Attempt to login with the preconfigured bind_dn - and then continue searching and filtering within - the base_dn - - Returns (True, LDAP3Connection) - if a single matching DN within the base was found - that matched the filter expression, and with which - a successful bind was achieved - - The LDAP3Connection returned is the instance that was used to - verify the password not the one using the configured bind_dn. - Returns (False, None) - if an error occured - """ - - try: - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password - ) - logger.debug( - "Established LDAP connection in search mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in search mode through StartTLS: %s", - conn - ) - - if not conn.bind(): - logger.warn( - "Binding against LDAP with `bind_dn` failed: %s", - conn.result['description'] - ) - conn.unbind() - return False, None - - # construct search_filter like (uid=localpart) - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - if self.ldap_filter: - # combine with the AND expression - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug( - "LDAP search filter: %s", - query - ) - conn.search( - search_base=self.ldap_base, - search_filter=query - ) - - if len(conn.response) == 1: - # GOOD: found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('LDAP search found dn: %s', user_dn) - - # unbind and simple bind with user_dn to verify the password - # Note: do not use rebind(), for some reason it did not verify - # the password for me! - conn.unbind() - return self._ldap_simple_bind(server, localpart, password) - else: - # BAD: found 0 or > 1 results, abort! - if len(conn.response) == 0: - logger.info( - "LDAP search returned no results for '%s'", - localpart - ) - else: - logger.info( - "LDAP search returned too many (%s) results for '%s'", - len(conn.response), localpart - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - -def _require_keys(config, required): - missing = [key for key in required if key not in config] - if missing: - raise ConfigError( - "LDAP enabled but missing required config values: {}".format( - ", ".join(missing) - ) - ) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 49527f4d21..e2de7fce91 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -121,15 +121,9 @@ class RetryDestinationLimiter(object): pass def __exit__(self, exc_type, exc_val, exc_tb): - def err(failure): - logger.exception( - "Failed to store set_destination_retry_timings", - failure.value - ) - valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): - valid_err_code = 0 <= exc_val.code < 500 + valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 if exc_type is None or valid_err_code: # We connected successfully. @@ -151,6 +145,15 @@ class RetryDestinationLimiter(object): retry_last_ts = int(self.clock.time_msec()) - self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval - ).addErrback(err) + @defer.inlineCallbacks + def store_retry_timings(): + try: + yield self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ) + except: + logger.exception( + "Failed to store set_destination_retry_timings", + ) + + store_retry_timings() |