summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r-- synapse/__init__.py 2
-rw-r--r-- synapse/api/auth.py 22
-rw-r--r-- synapse/config/_base.py 6
-rw-r--r-- synapse/config/appservice.py 2
-rw-r--r-- synapse/config/logger.py 2
-rw-r--r-- synapse/federation/transport/client.py 23
-rw-r--r-- synapse/federation/transport/server.py 14
-rw-r--r-- synapse/handlers/auth.py 39
-rw-r--r-- synapse/handlers/events.py 7
-rw-r--r-- synapse/handlers/initial_sync.py 6
-rw-r--r-- synapse/handlers/profile.py 10
-rw-r--r-- synapse/handlers/register.py 10
-rw-r--r-- synapse/handlers/state_deltas.py 70
-rw-r--r-- synapse/handlers/user_directory.py 51
-rw-r--r-- synapse/http/matrixfederationclient.py 86
-rw-r--r-- synapse/module_api/__init__.py 18
-rw-r--r-- synapse/replication/tcp/protocol.py 35
-rw-r--r-- synapse/replication/tcp/resource.py 1
-rw-r--r-- synapse/replication/tcp/streams.py 11
-rw-r--r-- synapse/rest/client/v1/login.py 49
-rw-r--r-- synapse/storage/receipts.py 4
-rw-r--r-- synapse/storage/state_deltas.py 74
-rw-r--r-- synapse/storage/user_directory.py 94
23 files changed, 443 insertions, 193 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 25c10244d3..6bb5a8b24d 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try: except ImportError: pass -__version__ = "0.99.2" +__version__ = "0.99.3" diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index ab4d2b8f11..f4171da6e3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -614,13 +614,13 @@ class Auth(object): Returns: True if the the sender is allowed to redact the target event if the - target event was created by them. + target event was created by them. False if the sender is allowed to redact the target event with no - further checks. + further checks. Raises: AuthError if the event sender is definitely not allowed to redact - the target event. + the target event. """ return event_auth.check_redaction(room_version, event, auth_events) @@ -736,9 +736,9 @@ class Auth(object): Returns: Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + the user in the room and the membership event ID of the user. If + the user is not in the room and never has been, then + `(Membership.JOIN, None)` is returned. """ try: @@ -770,13 +770,13 @@ class Auth(object): Args: user_id(str|None): If present, checks for presence against existing - MAU cohort + MAU cohort threepid(dict|None): If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. """ # Never fail an auth check for the server notices users or support user diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a219a83550..f7d7f153bb 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -137,7 +137,7 @@ class Config(object): @staticmethod def read_config_file(file_path): with open(file_path) as file_stream: - return yaml.load(file_stream) + return yaml.safe_load(file_stream) def invoke_all(self, name, *args, **kargs): results = [] @@ -318,7 +318,7 @@ class Config(object): ) config_file.write(config_str) - config = yaml.load(config_str) + config = yaml.safe_load(config_str) obj.invoke_all("generate_files", config) print( @@ -390,7 +390,7 @@ class Config(object): server_name=server_name, generate_secrets=False, ) - config = yaml.load(config_string) + config = yaml.safe_load(config_string) config.pop("log_config") config.update(specified_config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 9e64c76544..7e89d345d8 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py
@@ -68,7 +68,7 @@ def load_appservices(hostname, config_files): try: with open(config_file, 'r') as f: appservice = _load_appservice( - hostname, yaml.load(f), config_file + hostname, yaml.safe_load(f), config_file ) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 464c28c2d9..c1febbe9d3 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py
@@ -195,7 +195,7 @@ def setup_logging(config, use_worker_options=False): else: def load_log_config(): with open(log_config, 'r') as f: - logging.config.dictConfig(yaml.load(f)) + logging.config.dictConfig(yaml.safe_load(f)) def sighup(*args): # it might be better to use a file watcher or something for this. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8e2be218e2..e424c40fdf 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -51,9 +51,10 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state/%s/", room_id) + path = _create_v1_path("/state/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -73,9 +74,10 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state_ids/%s/", room_id) + path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -95,8 +97,11 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = _create_v1_path("/event/%s/", event_id) - return self.client.get_json(destination, path=path, timeout=timeout) + path = _create_v1_path("/event/%s", event_id) + return self.client.get_json( + destination, path=path, timeout=timeout, + try_trailing_slash_on_400=True, + ) @log_function def backfill(self, destination, room_id, event_tuples, limit): @@ -121,7 +126,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = _create_v1_path("/backfill/%s/", room_id) + path = _create_v1_path("/backfill/%s", room_id) args = { "v": event_tuples, @@ -132,6 +137,7 @@ class TransportLayerClient(object): destination, path=path, args=args, + try_trailing_slash_on_400=True, ) @defer.inlineCallbacks @@ -167,7 +173,7 @@ class TransportLayerClient(object): # generated by the json_data_callback. 
json_data = transaction.get_dict() - path = _create_v1_path("/send/%s/", transaction.transaction_id) + path = _create_v1_path("/send/%s", transaction.transaction_id) response = yield self.client.put_json( transaction.destination, @@ -176,6 +182,7 @@ class TransportLayerClient(object): json_data_callback=json_data_callback, long_retries=True, backoff_on_404=True, # If we get a 404 the other side has gone + try_trailing_slash_on_400=True, ) defer.returnValue(response) @@ -959,7 +966,7 @@ def _create_v1_path(path, *args): Example: - _create_v1_path("/event/%s/", event_id) + _create_v1_path("/event/%s", event_id) Args: path (str): String template for the path @@ -980,7 +987,7 @@ def _create_v2_path(path, *args): Example: - _create_v2_path("/event/%s/", event_id) + _create_v2_path("/event/%s", event_id) Args: path (str): String template for the path diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 96d680a5ad..efb6bdca48 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -312,7 +312,7 @@ class BaseFederationServlet(object): class FederationSendServlet(BaseFederationServlet): - PATH = "/send/(?P<transaction_id>[^/]*)/" + PATH = "/send/(?P<transaction_id>[^/]*)/?" def __init__(self, handler, server_name, **kwargs): super(FederationSendServlet, self).__init__( @@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet): - PATH = "/event/(?P<event_id>[^/]*)/" + PATH = "/event/(?P<event_id>[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. def on_GET(self, origin, content, query, event_id): @@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet): - PATH = "/state/(?P<context>[^/]*)/" + PATH = "/state/(?P<context>[^/]*)/?" # This is when someone asks for all data for a given context. def on_GET(self, origin, content, query, context): @@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet): class FederationStateIdsServlet(BaseFederationServlet): - PATH = "/state_ids/(?P<room_id>[^/]*)/" + PATH = "/state_ids/(?P<room_id>[^/]*)/?" def on_GET(self, origin, content, query, room_id): return self.handler.on_state_ids_request( @@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/(?P<context>[^/]*)/" + PATH = "/backfill/(?P<context>[^/]*)/?" def on_GET(self, origin, content, query, context): versions = [x.decode('ascii') for x in query[b"v"]] @@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/categories/" + "/groups/(?P<group_id>[^/]*)/categories/?" 
) @defer.inlineCallbacks @@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ PATH = ( - "/groups/(?P<group_id>[^/]*)/roles/" + "/groups/(?P<group_id>[^/]*)/roles/?" ) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index caad9ae2dd..4544de821d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -746,6 +746,42 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks + def check_password_provider_3pid(self, medium, address, password): + """Check if a password provider is able to validate a thirdparty login + + Args: + medium (str): The medium of the 3pid (ex. email). + address (str): The address of the 3pid (ex. jdoe@example.com). + password (str): The password of the user. + + Returns: + Deferred[(str|None, func|None)]: A tuple of `(user_id, + callback)`. If authentication is successful, `user_id` is a `str` + containing the authenticated, canonical user ID. `callback` is + then either a function to be later run after the server has + completed login/registration, or `None`. If authentication was + unsuccessful, `user_id` and `callback` are both `None`. + """ + for provider in self.password_providers: + if hasattr(provider, "check_3pid_auth"): + # This function is able to return a deferred that either + # resolves None, meaning authentication failure, or upon + # success, to a str (which is the user_id) or a tuple of + # (user_id, callback_func), where callback_func should be run + # after we've finished everything else + result = yield provider.check_3pid_auth( + medium, address, password, + ) + if result: + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + result = (result, None) + defer.returnValue(result) + + defer.returnValue((None, None)) + + @defer.inlineCallbacks def _check_local_password(self, user_id, password): """Authenticate a user against the local password database. 
@@ -756,7 +792,8 @@ class AuthHandler(BaseHandler): user_id (unicode): complete @user:id password (unicode): the provided password Returns: - (unicode) the canonical_user_id, or None if unknown user / bad password + Deferred[unicode] the canonical_user_id, or Deferred[None] if + unknown user/bad password Raises: LimitExceededError if the ratelimiter's login requests count for this diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f772e62c28..d883e98381 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
@@ -19,7 +19,7 @@ import random from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.types import UserID @@ -61,6 +61,11 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ + if room_id: + blocked = yield self.store.is_room_blocked(room_id) + if blocked: + raise SynapseError(403, "This room has been blocked on this server") + # send any outstanding server notices to the user. yield self._server_notices_sender.on_user_syncing(auth_user_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 563bb3cea3..7dfae78db0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -262,6 +262,10 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. """ + blocked = yield self.store.is_room_blocked(room_id) + if blocked: + raise SynapseError(403, "This room has been blocked on this server") + user_id = requester.user.to_string() membership, member_event_id = yield self._check_in_room_or_world_readable( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 7b8c77ba4d..2df2eaf609 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -233,8 +233,14 @@ class BaseProfileHandler(BaseHandler): @defer.inlineCallbacks def set_displayname(self, target_user, requester, new_displayname, by_admin=False): - """target_user is the UserID whose displayname is to be changed; - requester is the authenticated user attempting to make this change.""" + """Set the displayname of a user + + Args: + target_user (UserID): the user whose displayname is to be changed. + requester (Requester): The user attempting to make this change. + new_displayname (str): The displayname to give this user. + by_admin (bool): Whether this change was made by an administrator. + """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b0468cd3d5..f4745614f1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -172,7 +172,7 @@ class RegistrationHandler(BaseHandler): api.constants.UserTypes, or None for a normal user. default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. - address (str|None): the IP address used to perform the regitration. + address (str|None): the IP address used to perform the registration. Returns: A tuple of (user_id, access_token). Raises: @@ -719,7 +719,7 @@ class RegistrationHandler(BaseHandler): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. - address (str|None): the IP address used to perform the regitration. + address (str|None): the IP address used to perform the registration. Returns: Deferred @@ -817,9 +817,9 @@ class RegistrationHandler(BaseHandler): access_token (str|None): The access token of the newly logged in device, or None if `inhibit_login` enabled. bind_email (bool): Whether to bind the email with the identity - server + server. bind_msisdn (bool): Whether to bind the msisdn with the identity - server + server. """ if self.hs.config.worker_app: yield self._post_registration_client( @@ -861,7 +861,7 @@ class RegistrationHandler(BaseHandler): """A user consented to the terms on registration Args: - user_id (str): The user ID that consented + user_id (str): The user ID that consented. consent_version (str): version of the policy the user has consented to. """ diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py new file mode 100644
index 0000000000..b268bbcb2c --- /dev/null +++ b/synapse/handlers/state_deltas.py
@@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +class StateDeltasHandler(object): + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + """Given two events check if the `key_name` field in content changed + from not matching `public_value` to doing so. + + For example, check if `history_visibility` (`key_name`) changed from + `shared` to `world_readable` (`public_value`). + + Returns: + None if the field in the events either both match `public_value` + or if neither do, i.e. there has been no change. 
+ True if it didnt match `public_value` but now does + False if it did match `public_value` but now doesn't + """ + prev_event = None + event = None + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + + if not event and not prev_event: + logger.debug("Neither event exists: %r %r", prev_event_id, event_id) + defer.returnValue(None) + + prev_value = None + value = None + + if prev_event: + prev_value = prev_event.content.get(key_name) + + if event: + value = event.content.get(key_name) + + logger.debug("prev_value: %r -> value: %r", prev_value, value) + + if value == public_value and prev_value != public_value: + defer.returnValue(True) + elif value != public_value and prev_value == public_value: + defer.returnValue(False) + else: + defer.returnValue(None) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 7dc0e236e7..b689979b4b 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import get_localpart_from_id @@ -29,7 +30,7 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -class UserDirectoryHandler(object): +class UserDirectoryHandler(StateDeltasHandler): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,6 +42,8 @@ class UserDirectoryHandler(object): """ def __init__(self, hs): + super(UserDirectoryHandler, self).__init__(hs) + self.store = hs.get_datastore() self.state = hs.get_state_handler() self.server_name = hs.hostname @@ -360,7 +363,7 @@ class UserDirectoryHandler(object): @defer.inlineCallbacks def _handle_remove_user(self, room_id, user_id): - """Called when we might need to remove user to directory + """Called when we might need to remove user from directory Args: room_id (str): room_id that user left or stopped being public that @@ -402,47 +405,3 @@ class UserDirectoryHandler(object): if prev_name != new_name or prev_avatar != new_avatar: yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) - - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): - """Given two events check if the `key_name` field in content changed - from not matching `public_value` to doing so. - - For example, check if `history_visibility` (`key_name`) changed from - `shared` to `world_readable` (`public_value`). - - Returns: - None if the field in the events either both match `public_value` - or if neither do, i.e. there has been no change. 
- True if it didnt match `public_value` but now does - False if it did match `public_value` but now doesn't - """ - prev_event = None - event = None - if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - - if event_id: - event = yield self.store.get_event(event_id, allow_none=True) - - if not event and not prev_event: - logger.debug("Neither event exists: %r %r", prev_event_id, event_id) - defer.returnValue(None) - - prev_value = None - value = None - - if prev_event: - prev_value = prev_event.content.get(key_name) - - if event: - value = event.content.get(key_name) - - logger.debug("prev_value: %r -> value: %r", prev_value, value) - - if value == public_value and prev_value != public_value: - defer.returnValue(True) - elif value != public_value and prev_value == public_value: - defer.returnValue(False) - else: - defer.returnValue(None) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1682c9af13..ff63d0b2a8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -189,6 +189,58 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks + def _send_request_with_optional_trailing_slash( + self, + request, + try_trailing_slash_on_400=False, + **send_request_args + ): + """Wrapper for _send_request which can optionally retry the request + upon receiving a combination of a 400 HTTP response code and a + 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 + due to #3622. + + Args: + request (MatrixFederationRequest): details of request to be sent + try_trailing_slash_on_400 (bool): Whether on receiving a 400 + 'M_UNRECOGNIZED' from the server to retry the request with a + trailing slash appended to the request path. + send_request_args (Dict): A dictionary of arguments to pass to + `_send_request()`. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + + Returns: + Deferred[Dict]: Parsed JSON response body. + """ + try: + response = yield self._send_request( + request, **send_request_args + ) + except HttpResponseException as e: + # Received an HTTP error > 300. Check if it meets the requirements + # to retry with a trailing slash + if not try_trailing_slash_on_400: + raise + + if e.code != 400 or e.to_synapse_error().errcode != "M_UNRECOGNIZED": + raise + + # Retry with a trailing slash if we received a 400 with + # 'M_UNRECOGNIZED' which some endpoints can return when omitting a + # trailing slash on Synapse <= v0.99.3. + logger.info("Retrying request with trailing slash") + request.path += "/" + + response = yield self._send_request( + request, **send_request_args + ) + + defer.returnValue(response) + + @defer.inlineCallbacks def _send_request( self, request, @@ -196,7 +248,7 @@ class MatrixFederationHttpClient(object): timeout=None, long_retries=False, ignore_backoff=False, - backoff_on_404=False + backoff_on_404=False, ): """ Sends a request to the given server. 
@@ -473,7 +525,8 @@ class MatrixFederationHttpClient(object): json_data_callback=None, long_retries=False, timeout=None, ignore_backoff=False, - backoff_on_404=False): + backoff_on_404=False, + try_trailing_slash_on_400=False): """ Sends the specifed json data using PUT Args: @@ -493,7 +546,12 @@ class MatrixFederationHttpClient(object): and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as a failure of the server (and should therefore back off future - requests) + requests). + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end + of the request. Workaround for #3622 in Synapse <= v0.99.3. This + will be attempted before backing off if backing off has been + enabled. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The @@ -509,7 +567,6 @@ class MatrixFederationHttpClient(object): RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. 
""" - request = MatrixFederationRequest( method="PUT", destination=destination, @@ -519,17 +576,19 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=backoff_on_404, + ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, - ignore_backoff=ignore_backoff, - backoff_on_404=backoff_on_404, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks @@ -592,7 +651,8 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False): + timeout=None, ignore_backoff=False, + try_trailing_slash_on_400=False): """ GETs some json from the given host homeserver and path Args: @@ -606,6 +666,9 @@ class MatrixFederationHttpClient(object): be retried. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end of + the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. 
@@ -631,16 +694,19 @@ class MatrixFederationHttpClient(object): query=args, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=False, + ignore_backoff=ignore_backoff, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, - ignore_backoff=ignore_backoff, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index fc9a20ff59..235ce8334e 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -73,14 +73,26 @@ class ModuleApi(object): """ return self._auth_handler.check_user_exists(user_id) - def register(self, localpart): - """Registers a new user with given localpart + @defer.inlineCallbacks + def register(self, localpart, displayname=None): + """Registers a new user with given localpart and optional + displayname. + + Args: + localpart (str): The localpart of the new user. + displayname (str|None): The displayname of the new user. If None, + the user's displayname will default to `localpart`. Returns: Deferred: a 2-tuple of (user_id, access_token) """ + # Register the user reg = self.hs.get_registration_handler() - return reg.register(localpart=localpart) + user_id, access_token = yield reg.register( + localpart=localpart, default_display_name=displayname, + ) + + defer.returnValue((user_id, access_token)) @defer.inlineCallbacks def invalidate_access_token(self, access_token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 55630ba9a7..02e5bf6cc8 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -223,14 +223,25 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return # Now lets try and call on_<CMD_NAME> function - try: - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - getattr(self, "on_%s" % (cmd_name,)), - cmd, - ) - except Exception: - logger.exception("[%s] Failed to handle line: %r", self.id(), line) + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), + self.handle_command, + cmd, + ) + + def handle_command(self, cmd): + """Handle a command we have received over the replication stream. + + By default delegates to on_<COMMAND> + + Args: + cmd (synapse.replication.tcp.commands.Command): received command + + Returns: + Deferred + """ + handler = getattr(self, "on_%s" % (cmd.NAME,)) + return handler(cmd) def close(self): logger.warn("[%s] Closing connection", self.id()) @@ -364,8 +375,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.unregisterProducer() def __str__(self): + addr = None + if self.transport: + addr = str(self.transport.getPeer()) return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % ( - self.name, self.conn_id, self.addr, + self.name, self.conn_id, addr, ) def id(self): @@ -381,12 +395,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer, addr): + def __init__(self, server_name, clock, streamer): BaseReplicationStreamProtocol.__init__(self, clock) # Old style class self.server_name = server_name self.streamer = streamer - self.addr = addr # The streams the client has subscribed to and is up to date with self.replication_streams = set() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 47cdf30bd3..7fc346c7b6 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -57,7 +57,6 @@ class ReplicationStreamProtocolFactory(Factory): self.server_name, self.clock, self.streamer, - addr ) diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index c1e626be3f..e23084baae 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py
@@ -23,7 +23,7 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ - +import itertools import logging from collections import namedtuple @@ -195,8 +195,8 @@ class Stream(object): limit=MAX_EVENTS_BEHIND + 1, ) - if len(rows) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) + # never turn more than MAX_EVENTS_BEHIND + 1 into updates. + rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: rows = yield self.update_function( from_token, current_token, @@ -204,6 +204,11 @@ class Stream(object): updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + # check we didn't get more rows than the limit. + # doing it like this allows the update_function to be a generator. + if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: + raise Exception("stream %s has fallen behind" % (self.NAME)) + defer.returnValue((updates, current_token)) def current_token(self): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8d56effbb8..5180e9eaf1 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -201,6 +201,24 @@ class LoginRestServlet(ClientV1RestServlet): # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) address = address.lower() + + # Check for login providers that support 3pid login types + canonical_user_id, callback_3pid = ( + yield self.auth_handler.check_password_provider_3pid( + medium, + address, + login_submission["password"], + ) + ) + if canonical_user_id: + # Authentication through password provider and 3pid succeeded + result = yield self._register_device_with_callback( + canonical_user_id, login_submission, callback_3pid, + ) + defer.returnValue(result) + + # No password providers were able to handle this 3pid + # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( medium, address, ) @@ -223,20 +241,43 @@ class LoginRestServlet(ClientV1RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - auth_handler = self.auth_handler - canonical_user_id, callback = yield auth_handler.validate_login( + canonical_user_id, callback = yield self.auth_handler.validate_login( identifier["user"], login_submission, ) + result = yield self._register_device_with_callback( + canonical_user_id, login_submission, callback, + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _register_device_with_callback( + self, + user_id, + login_submission, + callback=None, + ): + """ Registers a device with a given user_id. Optionally run a callback + function after registration has completed. + + Args: + user_id (str): ID of the user to register. + login_submission (dict): Dictionary of login information. + callback (func|None): Callback function to run after registration. + + Returns: + result (Dict[str,str]): Dictionary of account information after + successful registration. 
+ """ device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - canonical_user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name, ) result = { - "user_id": canonical_user_id, + "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 0fd1ccc40a..89a1f7e3d7 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py
@@ -301,7 +301,9 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return txn.fetchall() + return ( + r[0:5] + (json.loads(r[5]), ) for r in txn + ) return self.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py new file mode 100644
index 0000000000..57bc45cdb9 --- /dev/null +++ b/synapse/storage/state_deltas.py
@@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class StateDeltasStore(SQLBaseStore): + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? 
+ ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index d360e857d1..4d60a5726f 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, JoinRules from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.state import StateFilter +from synapse.storage.state_deltas import StateDeltasStore from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -31,7 +32,7 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryStore(BackgroundUpdateStore): +class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -134,7 +135,12 @@ class UserDirectoryStore(BackgroundUpdateStore): @defer.inlineCallbacks def _populate_user_directory_process_rooms(self, progress, batch_size): - + """ + Args: + progress (dict) + batch_size (int): Maximum number of state events to process + per cycle. + """ state = self.hs.get_state_handler() # If we don't have progress filed, delete everything. @@ -142,13 +148,14 @@ class UserDirectoryStore(BackgroundUpdateStore): yield self.delete_all_from_user_dir() def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. 
sql = """ - SELECT room_id FROM %s + SELECT room_id, events FROM %s ORDER BY events DESC - LIMIT %s + LIMIT 250 """ % ( TEMP_TABLE + "_rooms", - str(batch_size), ) txn.execute(sql) rooms_to_work_on = txn.fetchall() @@ -156,8 +163,6 @@ class UserDirectoryStore(BackgroundUpdateStore): if not rooms_to_work_on: return None - rooms_to_work_on = [x[0] for x in rooms_to_work_on] - # Get how many are left to process, so we can give status on how # far we are in processing txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") @@ -179,7 +184,9 @@ class UserDirectoryStore(BackgroundUpdateStore): % (len(rooms_to_work_on), progress["remaining"]) ) - for room_id in rooms_to_work_on: + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: is_in_room = yield self.is_host_joined(room_id, self.server_name) if is_in_room: @@ -246,7 +253,13 @@ class UserDirectoryStore(BackgroundUpdateStore): progress, ) - defer.returnValue(len(rooms_to_work_on)) + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + defer.returnValue(processed_event_count) + + defer.returnValue(processed_event_count) @defer.inlineCallbacks def _populate_user_directory_process_users(self, progress, batch_size): @@ -488,16 +501,6 @@ class UserDirectoryStore(BackgroundUpdateStore): defer.returnValue(user_ids) - @defer.inlineCallbacks - def get_all_local_users(self): - """Get all local users - """ - sql = """ - SELECT name FROM users - """ - rows = yield self._execute("get_all_local_users", None, sql) - defer.returnValue([name for name, in rows]) - def add_users_who_share_private_room(self, room_id, user_id_tuples): """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. 
@@ -675,59 +678,6 @@ class UserDirectoryStore(BackgroundUpdateStore): desc="update_user_directory_stream_pos", ) - def get_current_state_deltas(self, prev_stream_id): - prev_stream_id = int(prev_stream_id) - if not self._curr_state_delta_stream_cache.has_any_entity_changed( - prev_stream_id - ): - return [] - - def get_current_state_deltas_txn(txn): - # First we calculate the max stream id that will give us less than - # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't - # select toooo many. - sql = """ - SELECT stream_id, count(*) - FROM current_state_delta_stream - WHERE stream_id > ? - GROUP BY stream_id - ORDER BY stream_id ASC - LIMIT 100 - """ - txn.execute(sql, (prev_stream_id,)) - - total = 0 - max_stream_id = prev_stream_id - for max_stream_id, count in txn: - total += count - if total > 100: - # We arbitarily limit to 100 entries to ensure we don't - # select toooo many. - break - - # Now actually get the deltas - sql = """ - SELECT stream_id, room_id, type, state_key, event_id, prev_event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - return self.cursor_to_dict(txn) - - return self.runInteraction( - "get_current_state_deltas", get_current_state_deltas_txn - ) - - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( - table="current_state_delta_stream", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", - ) - @defer.inlineCallbacks def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory