diff --git a/synapse/__init__.py b/synapse/__init__.py
index 25c10244d3..6bb5a8b24d 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.99.2"
+__version__ = "0.99.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index ab4d2b8f11..f4171da6e3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -614,13 +614,13 @@ class Auth(object):
Returns:
True if the the sender is allowed to redact the target event if the
- target event was created by them.
+ target event was created by them.
False if the sender is allowed to redact the target event with no
- further checks.
+ further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
- the target event.
+ the target event.
"""
return event_auth.check_redaction(room_version, event, auth_events)
@@ -736,9 +736,9 @@ class Auth(object):
Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
- the user in the room and the membership event ID of the user. If
- the user is not in the room and never has been, then
- `(Membership.JOIN, None)` is returned.
+ the user in the room and the membership event ID of the user. If
+ the user is not in the room and never has been, then
+ `(Membership.JOIN, None)` is returned.
"""
try:
@@ -770,13 +770,13 @@ class Auth(object):
Args:
user_id(str|None): If present, checks for presence against existing
- MAU cohort
+ MAU cohort
threepid(dict|None): If present, checks for presence against configured
- reserved threepid. Used in cases where the user is trying register
- with a MAU blocked server, normally they would be rejected but their
- threepid is on the reserved list. user_id and
- threepid should never be set at the same time.
+ reserved threepid. Used in cases where the user is trying register
+ with a MAU blocked server, normally they would be rejected but their
+ threepid is on the reserved list. user_id and
+ threepid should never be set at the same time.
"""
# Never fail an auth check for the server notices users or support user
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a219a83550..f7d7f153bb 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -137,7 +137,7 @@ class Config(object):
@staticmethod
def read_config_file(file_path):
with open(file_path) as file_stream:
- return yaml.load(file_stream)
+ return yaml.safe_load(file_stream)
def invoke_all(self, name, *args, **kargs):
results = []
@@ -318,7 +318,7 @@ class Config(object):
)
config_file.write(config_str)
- config = yaml.load(config_str)
+ config = yaml.safe_load(config_str)
obj.invoke_all("generate_files", config)
print(
@@ -390,7 +390,7 @@ class Config(object):
server_name=server_name,
generate_secrets=False,
)
- config = yaml.load(config_string)
+ config = yaml.safe_load(config_string)
config.pop("log_config")
config.update(specified_config)
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 9e64c76544..7e89d345d8 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -68,7 +68,7 @@ def load_appservices(hostname, config_files):
try:
with open(config_file, 'r') as f:
appservice = _load_appservice(
- hostname, yaml.load(f), config_file
+ hostname, yaml.safe_load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 464c28c2d9..c1febbe9d3 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -195,7 +195,7 @@ def setup_logging(config, use_worker_options=False):
else:
def load_log_config():
with open(log_config, 'r') as f:
- logging.config.dictConfig(yaml.load(f))
+ logging.config.dictConfig(yaml.safe_load(f))
def sighup(*args):
# it might be better to use a file watcher or something for this.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8e2be218e2..e424c40fdf 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -51,9 +51,10 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = _create_v1_path("/state/%s/", room_id)
+ path = _create_v1_path("/state/%s", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
+ try_trailing_slash_on_400=True,
)
@log_function
@@ -73,9 +74,10 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = _create_v1_path("/state_ids/%s/", room_id)
+ path = _create_v1_path("/state_ids/%s", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
+ try_trailing_slash_on_400=True,
)
@log_function
@@ -95,8 +97,11 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = _create_v1_path("/event/%s/", event_id)
- return self.client.get_json(destination, path=path, timeout=timeout)
+ path = _create_v1_path("/event/%s", event_id)
+ return self.client.get_json(
+ destination, path=path, timeout=timeout,
+ try_trailing_slash_on_400=True,
+ )
@log_function
def backfill(self, destination, room_id, event_tuples, limit):
@@ -121,7 +126,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = _create_v1_path("/backfill/%s/", room_id)
+ path = _create_v1_path("/backfill/%s", room_id)
args = {
"v": event_tuples,
@@ -132,6 +137,7 @@ class TransportLayerClient(object):
destination,
path=path,
args=args,
+ try_trailing_slash_on_400=True,
)
@defer.inlineCallbacks
@@ -167,7 +173,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
- path = _create_v1_path("/send/%s/", transaction.transaction_id)
+ path = _create_v1_path("/send/%s", transaction.transaction_id)
response = yield self.client.put_json(
transaction.destination,
@@ -176,6 +182,7 @@ class TransportLayerClient(object):
json_data_callback=json_data_callback,
long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
+ try_trailing_slash_on_400=True,
)
defer.returnValue(response)
@@ -959,7 +966,7 @@ def _create_v1_path(path, *args):
Example:
- _create_v1_path("/event/%s/", event_id)
+ _create_v1_path("/event/%s", event_id)
Args:
path (str): String template for the path
@@ -980,7 +987,7 @@ def _create_v2_path(path, *args):
Example:
- _create_v2_path("/event/%s/", event_id)
+ _create_v2_path("/event/%s", event_id)
Args:
path (str): String template for the path
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 96d680a5ad..efb6bdca48 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -312,7 +312,7 @@ class BaseFederationServlet(object):
class FederationSendServlet(BaseFederationServlet):
- PATH = "/send/(?P<transaction_id>[^/]*)/"
+ PATH = "/send/(?P<transaction_id>[^/]*)/?"
def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(
@@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet):
class FederationEventServlet(BaseFederationServlet):
- PATH = "/event/(?P<event_id>[^/]*)/"
+ PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
def on_GET(self, origin, content, query, event_id):
@@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateServlet(BaseFederationServlet):
- PATH = "/state/(?P<context>[^/]*)/"
+ PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context):
@@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet):
class FederationStateIdsServlet(BaseFederationServlet):
- PATH = "/state_ids/(?P<room_id>[^/]*)/"
+ PATH = "/state_ids/(?P<room_id>[^/]*)/?"
def on_GET(self, origin, content, query, room_id):
return self.handler.on_state_ids_request(
@@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet):
- PATH = "/backfill/(?P<context>[^/]*)/"
+ PATH = "/backfill/(?P<context>[^/]*)/?"
def on_GET(self, origin, content, query, context):
versions = [x.decode('ascii') for x in query[b"v"]]
@@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
PATH = (
- "/groups/(?P<group_id>[^/]*)/categories/"
+ "/groups/(?P<group_id>[^/]*)/categories/?"
)
@defer.inlineCallbacks
@@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
- "/groups/(?P<group_id>[^/]*)/roles/"
+ "/groups/(?P<group_id>[^/]*)/roles/?"
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index caad9ae2dd..4544de821d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -746,6 +746,42 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
+ def check_password_provider_3pid(self, medium, address, password):
+ """Check if a password provider is able to validate a thirdparty login
+
+ Args:
+ medium (str): The medium of the 3pid (ex. email).
+ address (str): The address of the 3pid (ex. jdoe@example.com).
+ password (str): The password of the user.
+
+ Returns:
+ Deferred[(str|None, func|None)]: A tuple of `(user_id,
+ callback)`. If authentication is successful, `user_id` is a `str`
+ containing the authenticated, canonical user ID. `callback` is
+ then either a function to be later run after the server has
+ completed login/registration, or `None`. If authentication was
+ unsuccessful, `user_id` and `callback` are both `None`.
+ """
+ for provider in self.password_providers:
+ if hasattr(provider, "check_3pid_auth"):
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = yield provider.check_3pid_auth(
+ medium, address, password,
+ )
+ if result:
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ result = (result, None)
+ defer.returnValue(result)
+
+ defer.returnValue((None, None))
+
+ @defer.inlineCallbacks
def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database.
@@ -756,7 +792,8 @@ class AuthHandler(BaseHandler):
user_id (unicode): complete @user:id
password (unicode): the provided password
Returns:
- (unicode) the canonical_user_id, or None if unknown user / bad password
+ Deferred[unicode] the canonical_user_id, or Deferred[None] if
+ unknown user/bad password
Raises:
LimitExceededError if the ratelimiter's login requests count for this
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f772e62c28..d883e98381 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -19,7 +19,7 @@ import random
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.types import UserID
@@ -61,6 +61,11 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
+ if room_id:
+ blocked = yield self.store.is_room_blocked(room_id)
+ if blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
# send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(auth_user_id)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 563bb3cea3..7dfae78db0 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
@@ -262,6 +262,10 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
+ blocked = yield self.store.is_room_blocked(room_id)
+ if blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 7b8c77ba4d..2df2eaf609 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -233,8 +233,14 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
- """target_user is the UserID whose displayname is to be changed;
- requester is the authenticated user attempting to make this change."""
+ """Set the displayname of a user
+
+ Args:
+ target_user (UserID): the user whose displayname is to be changed.
+ requester (Requester): The user attempting to make this change.
+ new_displayname (str): The displayname to give this user.
+ by_admin (bool): Whether this change was made by an administrator.
+ """
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b0468cd3d5..f4745614f1 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -172,7 +172,7 @@ class RegistrationHandler(BaseHandler):
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
- address (str|None): the IP address used to perform the regitration.
+ address (str|None): the IP address used to perform the registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -719,7 +719,7 @@ class RegistrationHandler(BaseHandler):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- address (str|None): the IP address used to perform the regitration.
+ address (str|None): the IP address used to perform the registration.
Returns:
Deferred
@@ -817,9 +817,9 @@ class RegistrationHandler(BaseHandler):
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled.
bind_email (bool): Whether to bind the email with the identity
- server
+ server.
bind_msisdn (bool): Whether to bind the msisdn with the identity
- server
+ server.
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
@@ -861,7 +861,7 @@ class RegistrationHandler(BaseHandler):
"""A user consented to the terms on registration
Args:
- user_id (str): The user ID that consented
+ user_id (str): The user ID that consented.
consent_version (str): version of the policy the user has
consented to.
"""
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
new file mode 100644
index 0000000000..b268bbcb2c
--- /dev/null
+++ b/synapse/handlers/state_deltas.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeltasHandler(object):
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ """Given two events check if the `key_name` field in content changed
+ from not matching `public_value` to doing so.
+
+ For example, check if `history_visibility` (`key_name`) changed from
+ `shared` to `world_readable` (`public_value`).
+
+ Returns:
+ None if the field in the events either both match `public_value`
+ or if neither do, i.e. there has been no change.
+ True if it didnt match `public_value` but now does
+ False if it did match `public_value` but now doesn't
+ """
+ prev_event = None
+ event = None
+ if prev_event_id:
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+
+ if event_id:
+ event = yield self.store.get_event(event_id, allow_none=True)
+
+ if not event and not prev_event:
+ logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
+ defer.returnValue(None)
+
+ prev_value = None
+ value = None
+
+ if prev_event:
+ prev_value = prev_event.content.get(key_name)
+
+ if event:
+ value = event.content.get(key_name)
+
+ logger.debug("prev_value: %r -> value: %r", prev_value, value)
+
+ if value == public_value and prev_value != public_value:
+ defer.returnValue(True)
+ elif value != public_value and prev_value == public_value:
+ defer.returnValue(False)
+ else:
+ defer.returnValue(None)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 7dc0e236e7..b689979b4b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import get_localpart_from_id
@@ -29,7 +30,7 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-class UserDirectoryHandler(object):
+class UserDirectoryHandler(StateDeltasHandler):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@@ -41,6 +42,8 @@ class UserDirectoryHandler(object):
"""
def __init__(self, hs):
+ super(UserDirectoryHandler, self).__init__(hs)
+
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
@@ -360,7 +363,7 @@ class UserDirectoryHandler(object):
@defer.inlineCallbacks
def _handle_remove_user(self, room_id, user_id):
- """Called when we might need to remove user to directory
+ """Called when we might need to remove user from directory
Args:
room_id (str): room_id that user left or stopped being public that
@@ -402,47 +405,3 @@ class UserDirectoryHandler(object):
if prev_name != new_name or prev_avatar != new_avatar:
yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
-
- @defer.inlineCallbacks
- def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
- """Given two events check if the `key_name` field in content changed
- from not matching `public_value` to doing so.
-
- For example, check if `history_visibility` (`key_name`) changed from
- `shared` to `world_readable` (`public_value`).
-
- Returns:
- None if the field in the events either both match `public_value`
- or if neither do, i.e. there has been no change.
- True if it didnt match `public_value` but now does
- False if it did match `public_value` but now doesn't
- """
- prev_event = None
- event = None
- if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
-
- if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
-
- if not event and not prev_event:
- logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
- defer.returnValue(None)
-
- prev_value = None
- value = None
-
- if prev_event:
- prev_value = prev_event.content.get(key_name)
-
- if event:
- value = event.content.get(key_name)
-
- logger.debug("prev_value: %r -> value: %r", prev_value, value)
-
- if value == public_value and prev_value != public_value:
- defer.returnValue(True)
- elif value != public_value and prev_value == public_value:
- defer.returnValue(False)
- else:
- defer.returnValue(None)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1682c9af13..ff63d0b2a8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -189,6 +189,58 @@ class MatrixFederationHttpClient(object):
self._cooperator = Cooperator(scheduler=schedule)
@defer.inlineCallbacks
+ def _send_request_with_optional_trailing_slash(
+ self,
+ request,
+ try_trailing_slash_on_400=False,
+ **send_request_args
+ ):
+ """Wrapper for _send_request which can optionally retry the request
+ upon receiving a combination of a 400 HTTP response code and a
+ 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
+ due to #3622.
+
+ Args:
+ request (MatrixFederationRequest): details of request to be sent
+ try_trailing_slash_on_400 (bool): Whether on receiving a 400
+ 'M_UNRECOGNIZED' from the server to retry the request with a
+ trailing slash appended to the request path.
+ send_request_args (Dict): A dictionary of arguments to pass to
+ `_send_request()`.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+
+ Returns:
+ Deferred[Dict]: Parsed JSON response body.
+ """
+ try:
+ response = yield self._send_request(
+ request, **send_request_args
+ )
+ except HttpResponseException as e:
+ # Received an HTTP error > 300. Check if it meets the requirements
+ # to retry with a trailing slash
+ if not try_trailing_slash_on_400:
+ raise
+
+ if e.code != 400 or e.to_synapse_error().errcode != "M_UNRECOGNIZED":
+ raise
+
+ # Retry with a trailing slash if we received a 400 with
+ # 'M_UNRECOGNIZED' which some endpoints can return when omitting a
+ # trailing slash on Synapse <= v0.99.3.
+ logger.info("Retrying request with trailing slash")
+ request.path += "/"
+
+ response = yield self._send_request(
+ request, **send_request_args
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
def _send_request(
self,
request,
@@ -196,7 +248,7 @@ class MatrixFederationHttpClient(object):
timeout=None,
long_retries=False,
ignore_backoff=False,
- backoff_on_404=False
+ backoff_on_404=False,
):
"""
Sends a request to the given server.
@@ -473,7 +525,8 @@ class MatrixFederationHttpClient(object):
json_data_callback=None,
long_retries=False, timeout=None,
ignore_backoff=False,
- backoff_on_404=False):
+ backoff_on_404=False,
+ try_trailing_slash_on_400=False):
""" Sends the specifed json data using PUT
Args:
@@ -493,7 +546,12 @@ class MatrixFederationHttpClient(object):
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
a failure of the server (and should therefore back off future
- requests)
+ requests).
+ try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ response we should try appending a trailing slash to the end
+ of the request. Workaround for #3622 in Synapse <= v0.99.3. This
+ will be attempted before backing off if backing off has been
+ enabled.
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@@ -509,7 +567,6 @@ class MatrixFederationHttpClient(object):
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
-
request = MatrixFederationRequest(
method="PUT",
destination=destination,
@@ -519,17 +576,19 @@ class MatrixFederationHttpClient(object):
json=data,
)
- response = yield self._send_request(
+ response = yield self._send_request_with_optional_trailing_slash(
request,
+ try_trailing_slash_on_400,
+ backoff_on_404=backoff_on_404,
+ ignore_backoff=ignore_backoff,
long_retries=long_retries,
timeout=timeout,
- ignore_backoff=ignore_backoff,
- backoff_on_404=backoff_on_404,
)
body = yield _handle_json_response(
self.hs.get_reactor(), self.default_timeout, request, response,
)
+
defer.returnValue(body)
@defer.inlineCallbacks
@@ -592,7 +651,8 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
- timeout=None, ignore_backoff=False):
+ timeout=None, ignore_backoff=False,
+ try_trailing_slash_on_400=False):
""" GETs some json from the given host homeserver and path
Args:
@@ -606,6 +666,9 @@ class MatrixFederationHttpClient(object):
be retried.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
+ try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ response we should try appending a trailing slash to the end of
+ the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -631,16 +694,19 @@ class MatrixFederationHttpClient(object):
query=args,
)
- response = yield self._send_request(
+ response = yield self._send_request_with_optional_trailing_slash(
request,
+ try_trailing_slash_on_400,
+ backoff_on_404=False,
+ ignore_backoff=ignore_backoff,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
- ignore_backoff=ignore_backoff,
)
body = yield _handle_json_response(
self.hs.get_reactor(), self.default_timeout, request, response,
)
+
defer.returnValue(body)
@defer.inlineCallbacks
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index fc9a20ff59..235ce8334e 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -73,14 +73,26 @@ class ModuleApi(object):
"""
return self._auth_handler.check_user_exists(user_id)
- def register(self, localpart):
- """Registers a new user with given localpart
+ @defer.inlineCallbacks
+ def register(self, localpart, displayname=None):
+ """Registers a new user with given localpart and optional
+ displayname.
+
+ Args:
+ localpart (str): The localpart of the new user.
+ displayname (str|None): The displayname of the new user. If None,
+ the user's displayname will default to `localpart`.
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
+ # Register the user
reg = self.hs.get_registration_handler()
- return reg.register(localpart=localpart)
+ user_id, access_token = yield reg.register(
+ localpart=localpart, default_display_name=displayname,
+ )
+
+ defer.returnValue((user_id, access_token))
@defer.inlineCallbacks
def invalidate_access_token(self, access_token):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 55630ba9a7..02e5bf6cc8 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -223,14 +223,25 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return
# Now lets try and call on_<CMD_NAME> function
- try:
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(),
- getattr(self, "on_%s" % (cmd_name,)),
- cmd,
- )
- except Exception:
- logger.exception("[%s] Failed to handle line: %r", self.id(), line)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ self.handle_command,
+ cmd,
+ )
+
+ def handle_command(self, cmd):
+ """Handle a command we have received over the replication stream.
+
+ By default delegates to on_<COMMAND>
+
+ Args:
+ cmd (synapse.replication.tcp.commands.Command): received command
+
+ Returns:
+ Deferred
+ """
+ handler = getattr(self, "on_%s" % (cmd.NAME,))
+ return handler(cmd)
def close(self):
logger.warn("[%s] Closing connection", self.id())
@@ -364,8 +375,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.transport.unregisterProducer()
def __str__(self):
+ addr = None
+ if self.transport:
+ addr = str(self.transport.getPeer())
return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
- self.name, self.conn_id, self.addr,
+ self.name, self.conn_id, addr,
)
def id(self):
@@ -381,12 +395,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
- def __init__(self, server_name, clock, streamer, addr):
+ def __init__(self, server_name, clock, streamer):
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
self.server_name = server_name
self.streamer = streamer
- self.addr = addr
# The streams the client has subscribed to and is up to date with
self.replication_streams = set()
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 47cdf30bd3..7fc346c7b6 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -57,7 +57,6 @@ class ReplicationStreamProtocolFactory(Factory):
self.server_name,
self.clock,
self.streamer,
- addr
)
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index c1e626be3f..e23084baae 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -23,7 +23,7 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
-
+import itertools
import logging
from collections import namedtuple
@@ -195,8 +195,8 @@ class Stream(object):
limit=MAX_EVENTS_BEHIND + 1,
)
- if len(rows) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
+ # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+ rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = yield self.update_function(
from_token, current_token,
@@ -204,6 +204,11 @@ class Stream(object):
updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows]
+ # check we didn't get more rows than the limit.
+ # doing it like this allows the update_function to be a generator.
+ if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+ raise Exception("stream %s has fallen behind" % (self.NAME))
+
defer.returnValue((updates, current_token))
def current_token(self):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8d56effbb8..5180e9eaf1 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -201,6 +201,24 @@ class LoginRestServlet(ClientV1RestServlet):
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
+
+ # Check for login providers that support 3pid login types
+ canonical_user_id, callback_3pid = (
+ yield self.auth_handler.check_password_provider_3pid(
+ medium,
+ address,
+ login_submission["password"],
+ )
+ )
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ result = yield self._register_device_with_callback(
+ canonical_user_id, login_submission, callback_3pid,
+ )
+ defer.returnValue(result)
+
+ # No password providers were able to handle this 3pid
+ # Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
medium, address,
)
@@ -223,20 +241,43 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- auth_handler = self.auth_handler
- canonical_user_id, callback = yield auth_handler.validate_login(
+ canonical_user_id, callback = yield self.auth_handler.validate_login(
identifier["user"],
login_submission,
)
+ result = yield self._register_device_with_callback(
+ canonical_user_id, login_submission, callback,
+ )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _register_device_with_callback(
+ self,
+ user_id,
+ login_submission,
+ callback=None,
+ ):
+ """ Registers a device with a given user_id. Optionally run a callback
+ function after registration has completed.
+
+ Args:
+ user_id (str): ID of the user to register.
+ login_submission (dict): Dictionary of login information.
+ callback (func|None): Callback function to run after registration.
+
+ Returns:
+ result (Dict[str,str]): Dictionary of account information after
+ successful registration.
+ """
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- canonical_user_id, device_id, initial_display_name,
+ user_id, device_id, initial_display_name,
)
result = {
- "user_id": canonical_user_id,
+ "user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 0fd1ccc40a..89a1f7e3d7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -301,7 +301,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return txn.fetchall()
+ return (
+ r[0:5] + (json.loads(r[5]), ) for r in txn
+ )
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
new file mode 100644
index 0000000000..57bc45cdb9
--- /dev/null
+++ b/synapse/storage/state_deltas.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeltasStore(SQLBaseStore):
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+ # We arbitarily limit to 100 stream_id entries to ensure we don't
+ # select toooo many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+ # We arbitarily limit to 100 entries to ensure we don't
+ # select toooo many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index d360e857d1..4d60a5726f 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.state import StateFilter
+from synapse.storage.state_deltas import StateDeltasStore
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
-class UserDirectoryStore(BackgroundUpdateStore):
+class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -134,7 +135,12 @@ class UserDirectoryStore(BackgroundUpdateStore):
@defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size):
-
+ """
+ Args:
+ progress (dict)
+ batch_size (int): Maximum number of state events to process
+ per cycle.
+ """
state = self.hs.get_state_handler()
# If we don't have progress filed, delete everything.
@@ -142,13 +148,14 @@ class UserDirectoryStore(BackgroundUpdateStore):
yield self.delete_all_from_user_dir()
def _get_next_batch(txn):
+ # Only fetch 250 rooms, so we don't fetch too many at once, even
+ # if those 250 rooms have less than batch_size state events.
sql = """
- SELECT room_id FROM %s
+ SELECT room_id, events FROM %s
ORDER BY events DESC
- LIMIT %s
+ LIMIT 250
""" % (
TEMP_TABLE + "_rooms",
- str(batch_size),
)
txn.execute(sql)
rooms_to_work_on = txn.fetchall()
@@ -156,8 +163,6 @@ class UserDirectoryStore(BackgroundUpdateStore):
if not rooms_to_work_on:
return None
- rooms_to_work_on = [x[0] for x in rooms_to_work_on]
-
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
@@ -179,7 +184,9 @@ class UserDirectoryStore(BackgroundUpdateStore):
% (len(rooms_to_work_on), progress["remaining"])
)
- for room_id in rooms_to_work_on:
+ processed_event_count = 0
+
+ for room_id, event_count in rooms_to_work_on:
is_in_room = yield self.is_host_joined(room_id, self.server_name)
if is_in_room:
@@ -246,7 +253,13 @@ class UserDirectoryStore(BackgroundUpdateStore):
progress,
)
- defer.returnValue(len(rooms_to_work_on))
+ processed_event_count += event_count
+
+ if processed_event_count > batch_size:
+ # Don't process any more rooms, we've hit our batch size.
+ defer.returnValue(processed_event_count)
+
+ defer.returnValue(processed_event_count)
@defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size):
@@ -488,16 +501,6 @@ class UserDirectoryStore(BackgroundUpdateStore):
defer.returnValue(user_ids)
- @defer.inlineCallbacks
- def get_all_local_users(self):
- """Get all local users
- """
- sql = """
- SELECT name FROM users
- """
- rows = yield self._execute("get_all_local_users", None, sql)
- defer.returnValue([name for name, in rows])
-
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
@@ -675,59 +678,6 @@ class UserDirectoryStore(BackgroundUpdateStore):
desc="update_user_directory_stream_pos",
)
- def get_current_state_deltas(self, prev_stream_id):
- prev_stream_id = int(prev_stream_id)
- if not self._curr_state_delta_stream_cache.has_any_entity_changed(
- prev_stream_id
- ):
- return []
-
- def get_current_state_deltas_txn(txn):
- # First we calculate the max stream id that will give us less than
- # N results.
- # We arbitarily limit to 100 stream_id entries to ensure we don't
- # select toooo many.
- sql = """
- SELECT stream_id, count(*)
- FROM current_state_delta_stream
- WHERE stream_id > ?
- GROUP BY stream_id
- ORDER BY stream_id ASC
- LIMIT 100
- """
- txn.execute(sql, (prev_stream_id,))
-
- total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
- total += count
- if total > 100:
- # We arbitarily limit to 100 entries to ensure we don't
- # select toooo many.
- break
-
- # Now actually get the deltas
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC
- """
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(
- "get_current_state_deltas", get_current_state_deltas_txn
- )
-
- def get_max_stream_id_in_current_state_deltas(self):
- return self._simple_select_one_onecol(
- table="current_state_delta_stream",
- keyvalues={},
- retcol="COALESCE(MAX(stream_id), -1)",
- desc="get_max_stream_id_in_current_state_deltas",
- )
-
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
|