diff --git a/synapse/__init__.py b/synapse/__init__.py
index d5f6dc2094..7819cfbcbb 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.34.0rc2"
+__version__ = "0.34.1rc1"
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 70cecde486..4c3abf06fe 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -35,6 +35,7 @@ def request_registration(
server_location,
shared_secret,
admin=False,
+ user_type=None,
requests=_requests,
_print=print,
exit=sys.exit,
@@ -65,6 +66,9 @@ def request_registration(
mac.update(password.encode('utf8'))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ mac.update(b"\x00")
+ mac.update(user_type.encode('utf8'))
mac = mac.hexdigest()
@@ -74,6 +78,7 @@ def request_registration(
"password": password,
"mac": mac,
"admin": admin,
+ "user_type": user_type,
}
_print("Sending registration request...")
@@ -91,7 +96,7 @@ def request_registration(
_print("Success!")
-def register_new_user(user, password, server_location, shared_secret, admin):
+def register_new_user(user, password, server_location, shared_secret, admin, user_type):
if not user:
try:
default_user = getpass.getuser()
@@ -129,7 +134,8 @@ def register_new_user(user, password, server_location, shared_secret, admin):
else:
admin = False
- request_registration(user, password, server_location, shared_secret, bool(admin))
+ request_registration(user, password, server_location, shared_secret,
+ bool(admin), user_type)
def main():
@@ -154,6 +160,12 @@ def main():
default=None,
help="New password for user. Will prompt if omitted.",
)
+ parser.add_argument(
+ "-t",
+ "--user_type",
+ default=None,
+ help="User type as specified in synapse.api.constants.UserTypes",
+ )
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a",
@@ -208,7 +220,8 @@ def main():
if args.admin or args.no_admin:
admin = args.admin
- register_new_user(args.user, args.password, args.server_url, secret, admin)
+ register_new_user(args.user, args.password, args.server_url, secret,
+ admin, args.user_type)
if __name__ == "__main__":
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5309899703..b8a9af7158 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -802,9 +802,10 @@ class Auth(object):
threepid should never be set at the same time.
"""
- # Never fail an auth check for the server notices users
+ # Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
- if user_id == self.hs.config.server_notices_mxid:
+ is_support = yield self.store.is_support_user(user_id)
+ if user_id == self.hs.config.server_notices_mxid or is_support:
return
if self.hs.config.hs_disabled:
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f20e0fcf0b..87bc1cb53d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -102,6 +102,7 @@ class ThirdPartyEntityKind(object):
class RoomVersions(object):
V1 = "1"
+ V2 = "2"
VDH_TEST = "vdh-test-version"
STATE_V2_TEST = "state-v2-test"
@@ -113,9 +114,18 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1
# until we have a working v2.
KNOWN_ROOM_VERSIONS = {
RoomVersions.V1,
+ RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}
ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
+
+
+class UserTypes(object):
+ """Allows for user type specific behaviour. With the benefit of hindsight
+ 'admin' and 'guest' users should also be UserTypes. Normal users are type None
+ """
+ SUPPORT = "support"
+ ALL_USER_TYPES = (SUPPORT)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 48b903374d..0b464834ce 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -348,6 +348,24 @@ class IncompatibleRoomVersionError(SynapseError):
)
+class RequestSendFailed(RuntimeError):
+ """Sending a HTTP request over federation failed due to not being able to
+ talk to the remote server for some reason.
+
+ This exception is used to differentiate "expected" errors that arise due to
+ networking (e.g. DNS failures, connection timeouts etc), versus unexpected
+ errors (like programming errors).
+ """
+ def __init__(self, inner_exception, can_retry):
+ super(RequestSendFailed, self).__init__(
+ "Failed to send request: %s: %s" % (
+ type(inner_exception).__name__, inner_exception,
+ )
+ )
+ self.inner_exception = inner_exception
+ self.can_retry = can_retry
+
+
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server
interactions.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 677c0bdd4c..16ad654864 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import text_type
+
import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
@@ -353,7 +355,7 @@ class Filter(object):
sender = event.user_id
room_id = None
ev_type = "m.presence"
- is_url = False
+ contains_url = False
else:
sender = event.get("sender", None)
if not sender:
@@ -368,13 +370,16 @@ class Filter(object):
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- is_url = "url" in event.get("content", {})
+
+ content = event.get("content", {})
+ # check if there is a string url field in the content for filtering purposes
+ contains_url = isinstance(content.get("url"), text_type)
return self.check_fields(
room_id,
sender,
ev_type,
- is_url,
+ contains_url,
)
def check_fields(self, room_id, sender, event_type, contains_url):
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index c3afcc573b..b45adafdd3 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -19,15 +19,8 @@ from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
-
try:
python_dependencies.check_requirements()
-except python_dependencies.MissingRequirementError as e:
- message = "\n".join([
- "Missing Requirement: %s" % (str(e),),
- "To install run:",
- " pip install --upgrade --force \"%s\"" % (e.dependency,),
- "",
- ])
- sys.stderr.writelines(message)
+except python_dependencies.DependencyException as e:
+ sys.stderr.writelines(e.message)
sys.exit(1)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6169bf09bc..f3ac3d19f0 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -60,6 +60,7 @@ from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
+from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore, are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
@@ -168,8 +169,13 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
+ "/.well-known/matrix/client": WellKnownResource(self),
})
+ if self.get_config().saml2_enabled:
+ from synapse.rest.saml2 import SAML2Resource
+ resources["/_matrix/saml2"] = SAML2Resource(self)
+
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
consent_resource = ConsentResource(self)
@@ -316,9 +322,6 @@ def setup(config_options):
synapse.config.logger.setup_logging(config, use_worker_options=False)
- # check any extra requirements we have now we have a config
- check_requirements(config)
-
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
@@ -531,7 +534,7 @@ def run(hs):
)
start_generate_monthly_active_users()
- if hs.config.limit_usage_by_mau:
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index 79fe9c3dac..fca35b008c 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -16,7 +16,7 @@ from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
- from homeserver import HomeServerConfig
+ from synapse.config.homeserver import HomeServerConfig
action = sys.argv[1]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 14dae65ea0..fd2d6d52ef 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -135,10 +135,6 @@ class Config(object):
return file_stream.read()
@staticmethod
- def default_path(name):
- return os.path.abspath(os.path.join(os.path.curdir, name))
-
- @staticmethod
def read_config_file(file_path):
with open(file_path) as file_stream:
return yaml.load(file_stream)
@@ -151,8 +147,39 @@ class Config(object):
return results
def generate_config(
- self, config_dir_path, server_name, is_generating_file, report_stats=None
+ self,
+ config_dir_path,
+ data_dir_path,
+ server_name,
+ generate_secrets=False,
+ report_stats=None,
):
+ """Build a default configuration file
+
+ This is used both when the user explicitly asks us to generate a config file
+ (eg with --generate_config), and before loading the config at runtime (to give
+ a base which the config files override)
+
+ Args:
+ config_dir_path (str): The path where the config files are kept. Used to
+ create filenames for things like the log config and the signing key.
+
+ data_dir_path (str): The path where the data files are kept. Used to create
+ filenames for things like the database and media store.
+
+ server_name (str): The server name. Used to initialise the server_name
+ config param, but also used in the names of some of the config files.
+
+ generate_secrets (bool): True if we should generate new secrets for things
+ like the macaroon_secret_key. If False, these parameters will be left
+ unset.
+
+ report_stats (bool|None): Initial setting for the report_stats setting.
+ If None, report_stats will be left unset.
+
+ Returns:
+ str: the yaml config file
+ """
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(
@@ -160,15 +187,14 @@ class Config(object):
for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
+ data_dir_path=data_dir_path,
server_name=server_name,
- is_generating_file=is_generating_file,
+ generate_secrets=generate_secrets,
report_stats=report_stats,
)
)
- config = yaml.load(default_config)
-
- return default_config, config
+ return default_config
@classmethod
def load_config(cls, description, argv):
@@ -274,12 +300,14 @@ class Config(object):
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
- config_str, config = obj.generate_config(
+ config_str = obj.generate_config(
config_dir_path=config_dir_path,
+ data_dir_path=os.getcwd(),
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
- is_generating_file=True,
+ generate_secrets=True,
)
+ config = yaml.load(config_str)
obj.invoke_all("generate_files", config)
config_file.write(config_str)
print(
@@ -350,11 +378,13 @@ class Config(object):
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
- _, config = self.generate_config(
+ config_string = self.generate_config(
config_dir_path=config_dir_path,
+ data_dir_path=os.getcwd(),
server_name=server_name,
- is_generating_file=False,
+ generate_secrets=False,
)
+ config = yaml.load(config_string)
config.pop("log_config")
config.update(specified_config)
diff --git a/synapse/config/database.py b/synapse/config/database.py
index e915d9d09b..c8890147a6 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from ._base import Config
@@ -45,8 +46,8 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
- def default_config(self, **kwargs):
- database_path = self.abspath("homeserver.db")
+ def default_config(self, data_dir_path, **kwargs):
+ database_path = os.path.join(data_dir_path, "homeserver.db")
return """\
# Database configuration
database:
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 10dd40159f..5aad062c36 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -32,7 +32,7 @@ from .ratelimiting import RatelimitConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
from .room_directory import RoomDirectoryConfig
-from .saml2 import SAML2Config
+from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
@@ -53,10 +53,3 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
ServerNoticesConfig, RoomDirectoryConfig,
):
pass
-
-
-if __name__ == '__main__':
- import sys
- sys.stdout.write(
- HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
- )
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 279c47bb48..53f48fe2dd 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -66,26 +66,35 @@ class KeyConfig(Config):
# falsification of values
self.form_secret = config.get("form_secret", None)
- def default_config(self, config_dir_path, server_name, is_generating_file=False,
+ def default_config(self, config_dir_path, server_name, generate_secrets=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
- if is_generating_file:
- macaroon_secret_key = random_string_with_symbols(50)
- form_secret = '"%s"' % random_string_with_symbols(50)
+ if generate_secrets:
+ macaroon_secret_key = 'macaroon_secret_key: "%s"' % (
+ random_string_with_symbols(50),
+ )
+ form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
else:
- macaroon_secret_key = None
- form_secret = 'null'
+ macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>"
+ form_secret = "# form_secret: <PRIVATE STRING>"
return """\
- macaroon_secret_key: "%(macaroon_secret_key)s"
+ # a secret which is used to sign access tokens. If none is specified,
+ # the registration_shared_secret is used, if one is given; otherwise,
+ # a secret key is derived from the signing key.
+ #
+ # Note that changing this will invalidate any active access tokens, so
+ # all clients will have to log back in.
+ %(macaroon_secret_key)s
# Used to enable access token expiration.
expire_access_token: False
# a secret which is used to calculate HMACs for form values, to stop
- # falsification of values
- form_secret: %(form_secret)s
+ # falsification of values. Must be specified for the User Consent
+ # forms to work.
+ %(form_secret)s
## Signing Keys ##
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 7081868963..f87efecbf8 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -80,9 +80,7 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
- log_config = self.abspath(
- os.path.join(config_dir_path, server_name + ".log.config")
- )
+ log_config = os.path.join(config_dir_path, server_name + ".log.config")
return """
# A yaml python logging config file
log_config: "%(log_config)s"
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 61155c99d0..718c43ae03 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -24,10 +24,16 @@ class MetricsConfig(Config):
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, report_stats=None, **kwargs):
- suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
- return ("""\
+ res = """\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
- """ + suffix) % locals()
+ """
+
+ if report_stats is None:
+ res += "# report_stats: true|false\n"
+ else:
+ res += "report_stats: %s\n" % ('true' if report_stats else 'false')
+
+ return res
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 717bbfec61..6c2b543b8c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -37,6 +37,7 @@ class RegistrationConfig(Config):
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
+ self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
self.invite_3pid_guest = (
@@ -49,8 +50,13 @@ class RegistrationConfig(Config):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
- def default_config(self, **kwargs):
- registration_shared_secret = random_string_with_symbols(50)
+ def default_config(self, generate_secrets=False, **kwargs):
+ if generate_secrets:
+ registration_shared_secret = 'registration_shared_secret: "%s"' % (
+ random_string_with_symbols(50),
+ )
+ else:
+ registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
return """\
## Registration ##
@@ -77,7 +83,7 @@ class RegistrationConfig(Config):
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
- registration_shared_secret: "%(registration_shared_secret)s"
+ %(registration_shared_secret)s
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
@@ -91,6 +97,14 @@ class RegistrationConfig(Config):
# accessible to anonymous users.
allow_guest_access: False
+ # The identity server which we suggest that clients should use when users log
+ # in on this server.
+ #
+ # (By default, no suggestion is made, so it is left up to the client.
+ # This setting is ignored unless public_baseurl is also set.)
+ #
+ # default_identity_server: https://matrix.org
+
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 06c62ab62c..76e3340a91 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import os
from collections import namedtuple
from synapse.util.module_loader import load_module
@@ -175,9 +175,9 @@ class ContentRepositoryConfig(Config):
"url_preview_url_blacklist", ()
)
- def default_config(self, **kwargs):
- media_store = self.default_path("media_store")
- uploads_path = self.default_path("uploads")
+ def default_config(self, data_dir_path, **kwargs):
+ media_store = os.path.join(data_dir_path, "media_store")
+ uploads_path = os.path.join(data_dir_path, "uploads")
return r"""
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
deleted file mode 100644
index 8d7f443021..0000000000
--- a/synapse/config/saml2.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 Ericsson
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import Config
-
-
-class SAML2Config(Config):
- """SAML2 Configuration
- Synapse uses pysaml2 libraries for providing SAML2 support
-
- config_path: Path to the sp_conf.py configuration file
- idp_redirect_url: Identity provider URL which will redirect
- the user back to /login/saml2 with proper info.
-
- sp_conf.py file is something like:
- https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
-
- More information: https://pythonhosted.org/pysaml2/howto/config.html
- """
-
- def read_config(self, config):
- saml2_config = config.get("saml2_config", None)
- if saml2_config:
- self.saml2_enabled = saml2_config.get("enabled", True)
- self.saml2_config_path = saml2_config["config_path"]
- self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
- else:
- self.saml2_enabled = False
- self.saml2_config_path = None
- self.saml2_idp_redirect_url = None
-
- def default_config(self, config_dir_path, server_name, **kwargs):
- return """
- # Enable SAML2 for registration and login. Uses pysaml2
- # config_path: Path to the sp_conf.py configuration file
- # idp_redirect_url: Identity provider URL which will redirect
- # the user back to /login/saml2 with proper info.
- # See pysaml2 docs for format of config.
- #saml2_config:
- # enabled: true
- # config_path: "%s/sp_conf.py"
- # idp_redirect_url: "http://%s/idp"
- """ % (config_dir_path, server_name)
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
new file mode 100644
index 0000000000..86ffe334f5
--- /dev/null
+++ b/synapse/config/saml2_config.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config, ConfigError
+
+
+class SAML2Config(Config):
+ def read_config(self, config):
+ self.saml2_enabled = False
+
+ saml2_config = config.get("saml2_config")
+
+ if not saml2_config or not saml2_config.get("enabled", True):
+ return
+
+ self.saml2_enabled = True
+
+ import saml2.config
+ self.saml2_sp_config = saml2.config.SPConfig()
+ self.saml2_sp_config.load(self._default_saml_config_dict())
+ self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+
+ config_path = saml2_config.get("config_path", None)
+ if config_path is not None:
+ self.saml2_sp_config.load_file(config_path)
+
+ def _default_saml_config_dict(self):
+ import saml2
+
+ public_baseurl = self.public_baseurl
+ if public_baseurl is None:
+ raise ConfigError(
+ "saml2_config requires a public_baseurl to be set"
+ )
+
+ metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
+ response_url = public_baseurl + "_matrix/saml2/authn_response"
+ return {
+ "entityid": metadata_url,
+
+ "service": {
+ "sp": {
+ "endpoints": {
+ "assertion_consumer_service": [
+ (response_url, saml2.BINDING_HTTP_POST),
+ ],
+ },
+ "required_attributes": ["uid"],
+ "optional_attributes": ["mail", "surname", "givenname"],
+ },
+ }
+ }
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Enable SAML2 for registration and login. Uses pysaml2.
+ #
+ # saml2_config:
+ #
+ # # The following is the configuration for the pysaml2 Service Provider.
+ # # See pysaml2 docs for format of config.
+ # #
+ # # Default values will be used for the 'entityid' and 'service' settings,
+ # # so it is not normally necessary to specify them unless you need to
+ # # override them.
+ #
+ # sp_config:
+ # # point this to the IdP's metadata. You can use either a local file or
+ # # (preferably) a URL.
+ # metadata:
+ # # local: ["saml2/idp.xml"]
+ # remote:
+ # - url: https://our_idp/metadata.xml
+ #
+ # # The following is just used to generate our metadata xml, and you
+ # # may well not need it, depending on your setup. Alternatively you
+ # # may need a whole lot more detail - see the pysaml2 docs!
+ #
+ # description: ["My awesome SP", "en"]
+ # name: ["Test SP", "en"]
+ #
+ # organization:
+ # name: Example com
+ # display_name:
+ # - ["Example co", "en"]
+ # url: "http://example.com"
+ #
+ # contact_person:
+ # - given_name: Bob
+ # sur_name: "the Sysadmin"
+ # email_address": ["admin@example.com"]
+ # contact_type": technical
+ #
+ # # Instead of putting the config inline as above, you can specify a
+ # # separate pysaml2 configuration file:
+ # #
+ # # config_path: "%(config_dir_path)s/sp_conf.py"
+ """ % {"config_dir_path": config_dir_path}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a9154ad462..fb57791098 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,10 @@
# limitations under the License.
import logging
+import os.path
from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -203,7 +205,9 @@ class ServerConfig(Config):
]
})
- def default_config(self, server_name, **kwargs):
+ _check_resource_config(self.listeners)
+
+ def default_config(self, server_name, data_dir_path, **kwargs):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -211,7 +215,7 @@ class ServerConfig(Config):
bind_port = 8448
unsecure_port = 8008
- pid_file = self.abspath("homeserver.pid")
+ pid_file = os.path.join(data_dir_path, "homeserver.pid")
return """\
## Server ##
@@ -356,41 +360,41 @@ class ServerConfig(Config):
# type: manhole
- # Homeserver blocking
- #
- # How to reach the server admin, used in ResourceLimitError
- # admin_contact: 'mailto:admin@server.com'
- #
- # Global block config
- #
- # hs_disabled: False
- # hs_disabled_message: 'Human readable reason for why the HS is blocked'
- # hs_disabled_limit_type: 'error code(str), to help clients decode reason'
- #
- # Monthly Active User Blocking
- #
- # Enables monthly active user checking
- # limit_usage_by_mau: False
- # max_mau_value: 50
- # mau_trial_days: 2
- #
- # If enabled, the metrics for the number of monthly active users will
- # be populated, however no one will be limited. If limit_usage_by_mau
- # is true, this is implied to be true.
- # mau_stats_only: False
- #
- # Sometimes the server admin will want to ensure certain accounts are
- # never blocked by mau checking. These accounts are specified here.
- #
- # mau_limit_reserved_threepids:
- # - medium: 'email'
- # address: 'reserved_user@example.com'
- #
- # Room searching
- #
- # If disabled, new messages will not be indexed for searching and users
- # will receive errors when searching for messages. Defaults to enabled.
- # enable_search: true
+ # Homeserver blocking
+ #
+ # How to reach the server admin, used in ResourceLimitError
+ # admin_contact: 'mailto:admin@server.com'
+ #
+ # Global block config
+ #
+ # hs_disabled: False
+ # hs_disabled_message: 'Human readable reason for why the HS is blocked'
+ # hs_disabled_limit_type: 'error code(str), to help clients decode reason'
+ #
+ # Monthly Active User Blocking
+ #
+ # Enables monthly active user checking
+ # limit_usage_by_mau: False
+ # max_mau_value: 50
+ # mau_trial_days: 2
+ #
+ # If enabled, the metrics for the number of monthly active users will
+ # be populated, however no one will be limited. If limit_usage_by_mau
+ # is true, this is implied to be true.
+ # mau_stats_only: False
+ #
+ # Sometimes the server admin will want to ensure certain accounts are
+ # never blocked by mau checking. These accounts are specified here.
+ #
+ # mau_limit_reserved_threepids:
+ # - medium: 'email'
+ # address: 'reserved_user@example.com'
+ #
+ # Room searching
+ #
+ # If disabled, new messages will not be indexed for searching and users
+ # will receive errors when searching for messages. Defaults to enabled.
+ # enable_search: true
""" % locals()
def read_arguments(self, args):
@@ -464,3 +468,36 @@ def _warn_if_webclient_configured(listeners):
if name == 'webclient':
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
+
+
+KNOWN_RESOURCES = (
+ 'client',
+ 'consent',
+ 'federation',
+ 'keys',
+ 'media',
+ 'metrics',
+ 'replication',
+ 'static',
+ 'webclient',
+)
+
+
+def _check_resource_config(listeners):
+ resource_names = set(
+ res_name
+ for listener in listeners
+ for res in listener.get("resources", [])
+ for res_name in res.get("names", [])
+ )
+
+ for resource in resource_names:
+ if resource not in KNOWN_RESOURCES:
+ raise ConfigError(
+ "Unknown listener resource '%s'" % (resource, )
+ )
+ if resource == "consent":
+ try:
+ check_requirements('resources.consent')
+ except DependencyException as e:
+ raise ConfigError(e.message)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 099ace28c1..4640513497 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -22,7 +22,11 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.errors import FederationDeniedError, HttpResponseException
+from synapse.api.errors import (
+ FederationDeniedError,
+ HttpResponseException,
+ RequestSendFailed,
+)
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import (
LaterGauge,
@@ -518,11 +522,16 @@ class TransactionQueue(object):
)
except FederationDeniedError as e:
logger.info(e)
- except Exception as e:
- logger.warn(
- "TX [%s] Failed to send transaction: %s",
+ except RequestSendFailed as e:
+ logger.warning("(TX [%s] Failed to send transaction: %s", destination, e)
+
+ for p, _ in pending_pdus:
+ logger.info("Failed to send event %s to %s", p.event_id,
+ destination)
+ except Exception:
+ logger.exception(
+ "TX [%s] Failed to send transaction",
destination,
- e,
)
for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c6e89db4bc..2abd9af94f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -563,10 +563,10 @@ class AuthHandler(BaseHandler):
insensitively, but return None if there are multiple inexact matches.
Args:
- (str) user_id: complete @user:id
+ (unicode|bytes) user_id: complete @user:id
Returns:
- defer.Deferred: (str) canonical_user_id, or None if zero or
+ defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
@@ -954,6 +954,15 @@ class MacaroonGenerator(object):
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
+ """
+
+ Args:
+ user_id (unicode):
+ duration_in_ms (int):
+
+ Returns:
+ unicode
+ """
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 43f81bd607..9d257ecf31 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -235,6 +235,17 @@ class PaginationHandler(object):
"room_key", next_key
)
+ if events:
+ if event_filter:
+ events = event_filter.filter(events)
+
+ events = yield filter_events_for_client(
+ self.store,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
if not events:
defer.returnValue({
"chunk": [],
@@ -242,16 +253,6 @@ class PaginationHandler(object):
"end": next_token.to_string(),
})
- if event_filter:
- events = event_filter.filter(events)
-
- events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
- )
-
state = None
if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 015909bb26..21c17c59a0 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -126,6 +126,8 @@ class RegistrationHandler(BaseHandler):
make_guest=False,
admin=False,
threepid=None,
+ user_type=None,
+ default_display_name=None,
):
"""Registers a new client on the server.
@@ -140,6 +142,10 @@ class RegistrationHandler(BaseHandler):
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
+ user_type (str|None): type of user. One of the values from
+ api.constants.UserTypes, or None for a normal user.
+ default_display_name (unicode|None): if set, the new user's displayname
+ will be set to this. Defaults to 'localpart'.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -169,6 +175,13 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
+ if was_guest:
+ # If the user was a guest then they already have a profile
+ default_display_name = None
+
+ elif default_display_name is None:
+ default_display_name = localpart
+
token = None
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
@@ -178,11 +191,9 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
- create_profile_with_localpart=(
- # If the user was a guest then they already have a profile
- None if was_guest else user.localpart
- ),
+ create_profile_with_displayname=default_display_name,
admin=admin,
+ user_type=user_type,
)
if self.hs.config.user_directory_search_all_users:
@@ -203,13 +214,15 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
+ if default_display_name is None:
+ default_display_name = localpart
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash,
make_guest=make_guest,
- create_profile_with_localpart=user.localpart,
+ create_profile_with_displayname=default_display_name,
)
except SynapseError:
# if user id is taken, just generate another
@@ -233,9 +246,16 @@ class RegistrationHandler(BaseHandler):
# auto-join the user to any rooms we're supposed to dump them into
fake_requester = create_requester(user_id)
- # try to create the room if we're the first user on the server
+ # try to create the room if we're the first real user on the server. Note
+ # that an auto-generated support user is not a real user and will never be
+ # the user to create the room
should_auto_create_rooms = False
- if self.hs.config.autocreate_auto_join_rooms:
+ is_support = yield self.store.is_support_user(user_id)
+ # There is an edge case where the first user is the support user, then
+ # the room is never created, though this seems unlikely and
+ # recoverable from given the support user being involved in the first
+ # place.
+ if self.hs.config.autocreate_auto_join_rooms and not is_support:
count = yield self.store.count_all_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
@@ -300,7 +320,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
password_hash="",
appservice_id=service_id,
- create_profile_with_localpart=user.localpart,
+ create_profile_with_displayname=user.localpart,
)
defer.returnValue(user_id)
@@ -328,35 +348,6 @@ class RegistrationHandler(BaseHandler):
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
- def register_saml2(self, localpart):
- """
- Registers email_id as SAML2 Based Auth.
- """
- if types.contains_invalid_mxid_characters(localpart):
- raise SynapseError(
- 400,
- "User ID can only contain characters a-z, 0-9, or '=_-./'",
- )
- yield self.auth.check_auth_blocking()
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
-
- yield self.check_user_id_not_appservice_exclusive(user_id)
- token = self.macaroon_gen.generate_access_token(user_id)
- try:
- yield self.store.register(
- user_id=user_id,
- token=token,
- password_hash=None,
- create_profile_with_localpart=user.localpart,
- )
- except Exception as e:
- yield self.store.add_access_token_to_user(user_id, token)
- # Ignore Registration errors
- logger.exception(e)
- defer.returnValue((user_id, token))
-
- @defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
Registers emails with an identity server.
@@ -507,7 +498,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
token=token,
password_hash=password_hash,
- create_profile_with_localpart=user.localpart,
+ create_profile_with_displayname=user.localpart,
)
else:
yield self._auth_handler.delete_access_tokens_for_user(user_id)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3928faa6e7..581e96c743 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -433,7 +433,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- self.auth.check_auth_blocking(user_id)
+ yield self.auth.check_auth_blocking(user_id)
if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 16aec5a530..9fd756df0b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1671,13 +1671,17 @@ class SyncHandler(object):
"content": content,
})
- account_data = sync_config.filter_collection.filter_room_account_data(
+ account_data_events = sync_config.filter_collection.filter_room_account_data(
account_data_events
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
- if not (always_include or batch or account_data or ephemeral or full_state):
+ if not (always_include
+ or batch
+ or account_data_events
+ or ephemeral
+ or full_state):
return
state = yield self.compute_state_delta(
@@ -1748,7 +1752,7 @@ class SyncHandler(object):
room_id=room_id,
timeline=batch,
state=state,
- account_data=account_data,
+ account_data=account_data_events,
)
if room_sync or always_include:
sync_result_builder.archived.append(room_sync)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index f11b430126..3c40999338 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -125,9 +125,12 @@ class UserDirectoryHandler(object):
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- yield self.store.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url, None,
- )
+ is_support = yield self.store.is_support_user(user_id)
+ # Support users are for diagnostics and should not appear in the user directory.
+ if not is_support:
+ yield self.store.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url, None,
+ )
@defer.inlineCallbacks
def handle_user_deactivated(self, user_id):
@@ -329,14 +332,7 @@ class UserDirectoryHandler(object):
public_value=Membership.JOIN,
)
- if change is None:
- # Handle any profile changes
- yield self._handle_profile_change(
- state_key, room_id, prev_event_id, event_id,
- )
- continue
-
- if not change:
+ if change is False:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = yield self.store.is_host_joined(
@@ -354,16 +350,25 @@ class UserDirectoryHandler(object):
else:
logger.debug("Server is still in room: %r", room_id)
- if change: # The user joined
- event = yield self.store.get_event(event_id, allow_none=True)
- profile = ProfileInfo(
- avatar_url=event.content.get("avatar_url"),
- display_name=event.content.get("displayname"),
- )
+ is_support = yield self.store.is_support_user(state_key)
+ if not is_support:
+ if change is None:
+ # Handle any profile changes
+ yield self._handle_profile_change(
+ state_key, room_id, prev_event_id, event_id,
+ )
+ continue
+
+ if change: # The user joined
+ event = yield self.store.get_event(event_id, allow_none=True)
+ profile = ProfileInfo(
+ avatar_url=event.content.get("avatar_url"),
+ display_name=event.content.get("displayname"),
+ )
- yield self._handle_new_user(room_id, state_key, profile)
- else: # The user left
- yield self._handle_remove_user(room_id, state_key)
+ yield self._handle_new_user(room_id, state_key, profile)
+ else: # The user left
+ yield self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3d05f83b8c..afcf698b29 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -21,28 +21,25 @@ from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json
+from netaddr import IPAddress
from prometheus_client import Counter
+from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
-from twisted.internet import defer, protocol, reactor, ssl
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.web._newclient import ResponseDone
-from twisted.web.client import (
- Agent,
- BrowserLikeRedirectAgent,
- ContentDecoderAgent,
- GzipDecoder,
- HTTPConnectionPool,
- PartialDownloadError,
- readBody,
+from twisted.internet import defer, protocol, ssl
+from twisted.internet.interfaces import (
+ IReactorPluggableNameResolver,
+ IResolutionReceiver,
)
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
-from synapse.http.endpoint import SpiderEndpoint
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
@@ -50,8 +47,125 @@ from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
-incoming_responses_counter = Counter("synapse_http_client_responses", "",
- ["method", "code"])
+incoming_responses_counter = Counter(
+ "synapse_http_client_responses", "", ["method", "code"]
+)
+
+
+def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+ """
+ Args:
+ ip_address (netaddr.IPAddress)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ if ip_address in ip_blacklist:
+ if ip_whitelist is None or ip_address not in ip_whitelist:
+ return True
+ return False
+
+
+class IPBlacklistingResolver(object):
+ """
+ A proxy for reactor.nameResolver which only produces non-blacklisted IP
+ addresses, preventing DNS rebinding attacks on URL preview.
+ """
+
+ def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ """
+ Args:
+ reactor (twisted.internet.reactor)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ self._reactor = reactor
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+
+ def resolveHostName(self, recv, hostname, portNumber=0):
+
+ r = recv()
+ d = defer.Deferred()
+ addresses = []
+
+ @provider(IResolutionReceiver)
+ class EndpointReceiver(object):
+ @staticmethod
+ def resolutionBegan(resolutionInProgress):
+ pass
+
+ @staticmethod
+ def addressResolved(address):
+ ip_address = IPAddress(address.host)
+
+ if check_against_blacklist(
+ ip_address, self._ip_whitelist, self._ip_blacklist
+ ):
+ logger.info(
+ "Dropped %s from DNS resolution to %s" % (ip_address, hostname)
+ )
+ raise SynapseError(403, "IP address blocked by IP blacklist entry")
+
+ addresses.append(address)
+
+ @staticmethod
+ def resolutionComplete():
+ d.callback(addresses)
+
+ self._reactor.nameResolver.resolveHostName(
+ EndpointReceiver, hostname, portNumber=portNumber
+ )
+
+ def _callback(addrs):
+ r.resolutionBegan(None)
+ for i in addrs:
+ r.addressResolved(i)
+ r.resolutionComplete()
+
+ d.addCallback(_callback)
+
+ return r
+
+
+class BlacklistingAgentWrapper(Agent):
+ """
+ An Agent wrapper which will prevent access to IP addresses being accessed
+ directly (without an IP address lookup).
+ """
+
+ def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ """
+ Args:
+ agent (twisted.web.client.Agent): The Agent to wrap.
+ reactor (twisted.internet.reactor)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ self._agent = agent
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ h = urllib.parse.urlparse(uri.decode('ascii'))
+
+ try:
+ ip_address = IPAddress(h.hostname)
+
+ if check_against_blacklist(
+ ip_address, self._ip_whitelist, self._ip_blacklist
+ ):
+ logger.info(
+ "Blocking access to %s because of blacklist" % (ip_address,)
+ )
+ e = SynapseError(403, "IP address blocked by IP blacklist entry")
+ return defer.fail(Failure(e))
+ except Exception:
+ # Not an IP
+ pass
+
+ return self._agent.request(
+ method, uri, headers=headers, bodyProducer=bodyProducer
+ )
class SimpleHttpClient(object):
@@ -59,14 +173,54 @@ class SimpleHttpClient(object):
A simple, no-frills HTTP client with methods that wrap up common ways of
using HTTP in Matrix
"""
- def __init__(self, hs):
+
+ def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ """
+ Args:
+ hs (synapse.server.HomeServer)
+ treq_args (dict): Extra keyword arguments to be given to treq.request.
+ ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ we may not request.
+ ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ request if it were otherwise caught in a blacklist.
+ """
self.hs = hs
- pool = HTTPConnectionPool(reactor)
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+ self._extra_treq_args = treq_args
+
+ self.user_agent = hs.version_string
+ self.clock = hs.get_clock()
+ if hs.config.user_agent_suffix:
+ self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+
+ self.user_agent = self.user_agent.encode('ascii')
+
+ if self._ip_blacklist:
+ real_reactor = hs.get_reactor()
+ # If we have an IP blacklist, we need to use a DNS resolver which
+ # filters out blacklisted IP addresses, to prevent DNS rebinding.
+ nameResolver = IPBlacklistingResolver(
+ real_reactor, self._ip_whitelist, self._ip_blacklist
+ )
+
+ @implementer(IReactorPluggableNameResolver)
+ class Reactor(object):
+ def __getattr__(_self, attr):
+ if attr == "nameResolver":
+ return nameResolver
+ else:
+ return getattr(real_reactor, attr)
+
+ self.reactor = Reactor()
+ else:
+ self.reactor = hs.get_reactor()
# the pusher makes lots of concurrent SSL connections to sygnal, and
- # tends to do so in batches, so we need to allow the pool to keep lots
- # of idle connections around.
+ # tends to do so in batches, so we need to allow the pool to keep
+ # lots of idle connections around.
+ pool = HTTPConnectionPool(self.reactor)
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
@@ -74,20 +228,35 @@ class SimpleHttpClient(object):
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
- reactor,
+ self.reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory(),
+ contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
)
- self.user_agent = hs.version_string
- self.clock = hs.get_clock()
- if hs.config.user_agent_suffix:
- self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
- self.user_agent = self.user_agent.encode('ascii')
+ if self._ip_blacklist:
+ # If we have an IP blacklist, we then install the blacklisting Agent
+ # which prevents direct access to IP addresses, that are not caught
+ # by the DNS resolution.
+ self.agent = BlacklistingAgentWrapper(
+ self.agent,
+ self.reactor,
+ ip_whitelist=self._ip_whitelist,
+ ip_blacklist=self._ip_blacklist,
+ )
@defer.inlineCallbacks
def request(self, method, uri, data=b'', headers=None):
+ """
+ Args:
+ method (str): HTTP method to use.
+ uri (str): URI to query.
+ data (bytes): Data to send in the request body, if applicable.
+ headers (t.w.http_headers.Headers): Request headers.
+
+ Raises:
+ SynapseError: If the IP is blacklisted.
+ """
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
@@ -97,25 +266,34 @@ class SimpleHttpClient(object):
try:
request_deferred = treq.request(
- method, uri, agent=self.agent, data=data, headers=headers
+ method,
+ uri,
+ agent=self.agent,
+ data=data,
+ headers=headers,
+ **self._extra_treq_args
)
request_deferred = timeout_deferred(
- request_deferred, 60, self.hs.get_reactor(),
+ request_deferred,
+ 60,
+ self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
- "Received response to %s %s: %s",
- method, redact_uri(uri), response.code
+ "Received response to %s %s: %s", method, redact_uri(uri), response.code
)
defer.returnValue(response)
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
- method, redact_uri(uri), type(e).__name__, e.args[0]
+ method,
+ redact_uri(uri),
+ type(e).__name__,
+ e.args[0],
)
raise
@@ -140,8 +318,9 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(
- encode_urlencode_args(args), True).encode("utf8")
+ query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
+ "utf8"
+ )
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -151,10 +330,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "POST",
- uri,
- headers=Headers(actual_headers),
- data=query_bytes
+ "POST", uri, headers=Headers(actual_headers), data=query_bytes
)
if 200 <= response.code < 300:
@@ -193,10 +369,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "POST",
- uri,
- headers=Headers(actual_headers),
- data=json_str
+ "POST", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -264,10 +437,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "PUT",
- uri,
- headers=Headers(actual_headers),
- data=json_str
+ "PUT", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -299,17 +469,11 @@ class SimpleHttpClient(object):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
- actual_headers = {
- b"User-Agent": [self.user_agent],
- }
+ actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
- response = yield self.request(
- "GET",
- uri,
- headers=Headers(actual_headers),
- )
+ response = yield self.request("GET", uri, headers=Headers(actual_headers))
body = yield make_deferred_yieldable(readBody(response))
@@ -334,22 +498,18 @@ class SimpleHttpClient(object):
headers, absolute URI of the response and HTTP response code.
"""
- actual_headers = {
- b"User-Agent": [self.user_agent],
- }
+ actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
- response = yield self.request(
- "GET",
- url,
- headers=Headers(actual_headers),
- )
+ response = yield self.request("GET", url, headers=Headers(actual_headers))
resp_headers = dict(response.headers.getAllRawHeaders())
- if (b'Content-Length' in resp_headers and
- int(resp_headers[b'Content-Length']) > max_size):
+ if (
+ b'Content-Length' in resp_headers
+ and int(resp_headers[b'Content-Length'][0]) > max_size
+ ):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -359,26 +519,20 @@ class SimpleHttpClient(object):
if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url))
- raise SynapseError(
- 502,
- "Got error %d" % (response.code,),
- Codes.UNKNOWN,
- )
+ raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
# straight back in again
try:
- length = yield make_deferred_yieldable(_readBodyToFile(
- response, output_stream, max_size,
- ))
+ length = yield make_deferred_yieldable(
+ _readBodyToFile(response, output_stream, max_size)
+ )
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
- 502,
- ("Failed to download remote body: %s" % e),
- Codes.UNKNOWN,
+ 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
)
defer.returnValue(
@@ -387,13 +541,14 @@ class SimpleHttpClient(object):
resp_headers,
response.request.absoluteURI.decode('ascii'),
response.code,
- ),
+ )
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
+
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
@@ -405,11 +560,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- ))
+ self.deferred.errback(
+ SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+ )
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -427,6 +584,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
+
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
@@ -449,10 +607,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
"POST",
url,
data=query_bytes,
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(
+ {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ ),
)
try:
@@ -463,57 +623,6 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response)
-class SpiderEndpointFactory(object):
- def __init__(self, hs):
- self.blacklist = hs.config.url_preview_ip_range_blacklist
- self.whitelist = hs.config.url_preview_ip_range_whitelist
- self.policyForHTTPS = hs.get_http_client_context_factory()
-
- def endpointForURI(self, uri):
- logger.info("Getting endpoint for %s", uri.toBytes())
-
- if uri.scheme == b"http":
- endpoint_factory = HostnameEndpoint
- elif uri.scheme == b"https":
- tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
-
- def endpoint_factory(reactor, host, port, **kw):
- return wrapClientTLS(
- tlsCreator,
- HostnameEndpoint(reactor, host, port, **kw))
- else:
- logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
- return None
- return SpiderEndpoint(
- reactor, uri.host, uri.port, self.blacklist, self.whitelist,
- endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
- )
-
-
-class SpiderHttpClient(SimpleHttpClient):
- """
- Separate HTTP client for spidering arbitrary URLs.
- Special in that it follows retries and has a UA that looks
- like a browser.
-
- used by the preview_url endpoint in the content repo.
- """
- def __init__(self, hs):
- SimpleHttpClient.__init__(self, hs)
- # clobber the base class's agent and UA:
- self.agent = ContentDecoderAgent(
- BrowserLikeRedirectAgent(
- Agent.usingEndpointFactory(
- reactor,
- SpiderEndpointFactory(hs)
- )
- ), [(b'gzip', GzipDecoder)]
- )
- # We could look like Chrome:
- # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
- # Chrome Safari" % hs.version_string)
-
-
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 91025037a3..f86a0b624e 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -218,41 +218,6 @@ class _WrappedConnection(object):
return d
-class SpiderEndpoint(object):
- """An endpoint which refuses to connect to blacklisted IP addresses
- Implements twisted.internet.interfaces.IStreamClientEndpoint.
- """
- def __init__(self, reactor, host, port, blacklist, whitelist,
- endpoint=HostnameEndpoint, endpoint_kw_args={}):
- self.reactor = reactor
- self.host = host
- self.port = port
- self.blacklist = blacklist
- self.whitelist = whitelist
- self.endpoint = endpoint
- self.endpoint_kw_args = endpoint_kw_args
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- address = yield self.reactor.resolve(self.host)
-
- from netaddr import IPAddress
- ip_address = IPAddress(address)
-
- if ip_address in self.blacklist:
- if self.whitelist is None or ip_address not in self.whitelist:
- raise ConnectError(
- "Refusing to spider blacklisted IP address %s" % address
- )
-
- logger.info("Connecting to %s:%s", address, self.port)
- endpoint = self.endpoint(
- self.reactor, address, self.port, **self.endpoint_kw_args
- )
- connection = yield endpoint.connect(protocolFactory)
- defer.returnValue(connection)
-
-
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 24b6110c20..7a2b4f0957 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random
import sys
from io import BytesIO
-from six import PY3, string_types
+from six import PY3, raise_from, string_types
from six.moves import urllib
import attr
@@ -41,6 +41,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ RequestSendFailed,
SynapseError,
)
from synapse.http.endpoint import matrix_federation_endpoint
@@ -231,7 +232,7 @@ class MatrixFederationHttpClient(object):
Deferred: resolves with the http response object on success.
Fails with ``HttpResponseException``: if we get an HTTP response
- code >= 300.
+ code >= 300 (except 429).
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
@@ -239,8 +240,8 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
- (May also fail with plenty of other Exceptions for things like DNS
- failures, connection failures, SSL failures.)
+ Fails with ``RequestSendFailed`` if there were problems connecting to
+ the remote, due to e.g. DNS failures, connection timeouts etc.
"""
if timeout:
_sec_timeout = timeout / 1000
@@ -335,23 +336,74 @@ class MatrixFederationHttpClient(object):
reactor=self.hs.get_reactor(),
)
- with Measure(self.clock, "outbound_request"):
- response = yield make_deferred_yieldable(
- request_deferred,
+ try:
+ with Measure(self.clock, "outbound_request"):
+ response = yield make_deferred_yieldable(
+ request_deferred,
+ )
+ except DNSLookupError as e:
+ raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
+ except Exception as e:
+ raise_from(RequestSendFailed(e, can_retry=True), e)
+
+ logger.info(
+ "{%s} [%s] Got response headers: %d %s",
+ request.txn_id,
+ request.destination,
+ response.code,
+ response.phrase.decode('ascii', errors='replace'),
+ )
+
+ if 200 <= response.code < 300:
+ pass
+ else:
+ # :'(
+ # Update transactions table?
+ d = treq.content(response)
+ d = timeout_deferred(
+ d,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
)
+ try:
+ body = yield make_deferred_yieldable(d)
+ except Exception as e:
+ # Eh, we're already going to raise an exception so lets
+ # ignore if this fails.
+ logger.warn(
+ "{%s} [%s] Failed to get error response: %s %s: %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _flatten_response_never_received(e),
+ )
+ body = None
+
+ e = HttpResponseException(
+ response.code, response.phrase, body
+ )
+
+ # Retry if the error is a 429 (Too Many Requests),
+ # otherwise just raise a standard HttpResponseException
+ if response.code == 429:
+ raise_from(RequestSendFailed(e, can_retry=True), e)
+ else:
+ raise e
+
break
- except Exception as e:
+ except RequestSendFailed as e:
logger.warn(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
- _flatten_response_never_received(e),
+ _flatten_response_never_received(e.inner_exception),
)
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ if not e.can_retry:
raise
if retries_left and not timeout:
@@ -376,29 +428,16 @@ class MatrixFederationHttpClient(object):
else:
raise
- logger.info(
- "{%s} [%s] Got response headers: %d %s",
- request.txn_id,
- request.destination,
- response.code,
- response.phrase.decode('ascii', errors='replace'),
- )
-
- if 200 <= response.code < 300:
- pass
- else:
- # :'(
- # Update transactions table?
- d = treq.content(response)
- d = timeout_deferred(
- d,
- timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
- )
- body = yield make_deferred_yieldable(d)
- raise HttpResponseException(
- response.code, response.phrase, body
- )
+ except Exception as e:
+ logger.warn(
+ "{%s} [%s] Request failed: %s %s: %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _flatten_response_never_received(e),
+ )
+ raise
defer.returnValue(response)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 92422c6ffc..69c5f9fe2e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -15,173 +15,138 @@
# limitations under the License.
import logging
-from distutils.version import LooseVersion
+
+from pkg_resources import DistributionNotFound, VersionConflict, get_distribution
logger = logging.getLogger(__name__)
-# this dict maps from python package name to a list of modules we expect it to
-# provide.
-#
-# the key is a "requirement specifier", as used as a parameter to `pip
-# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+
+# REQUIREMENTS is a simple list of requirement specifiers[1], and must be
+# installed. It is passed to setup() as install_requires in setup.py.
#
-# the value is a sequence of strings; each entry should be the name of the
-# python module, optionally followed by a version assertion which can be either
-# ">=<ver>" or "==<ver>".
+# CONDITIONAL_REQUIREMENTS is the optional dependencies, represented as a dict
+# of lists. The dict key is the optional dependency name and can be passed to
+# pip when installing. The list is a series of requirement specifiers[1] to be
+# installed when that optional dependency requirement is specified. It is passed
+# to setup() as extras_require in setup.py
#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
-# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
-REQUIREMENTS = {
- "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
- "frozendict>=1": ["frozendict"],
- "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
- "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
- "signedjson>=1.0.0": ["signedjson>=1.0.0"],
- "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
- "service_identity>=16.0.0": ["service_identity>=16.0.0"],
- "Twisted>=17.1.0": ["twisted>=17.1.0"],
- "treq>=15.1": ["treq>=15.1"],
+REQUIREMENTS = [
+ "jsonschema>=2.5.1",
+ "frozendict>=1",
+ "unpaddedbase64>=1.1.0",
+ "canonicaljson>=1.1.3",
+ "signedjson>=1.0.0",
+ "pynacl>=1.2.1",
+ "service_identity>=16.0.0",
+ "Twisted>=17.1.0",
+ "treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
- "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],
-
- "pyyaml>=3.11": ["yaml"],
- "pyasn1>=0.1.9": ["pyasn1"],
- "pyasn1-modules>=0.0.7": ["pyasn1_modules"],
- "daemonize>=2.3.1": ["daemonize"],
- "bcrypt>=3.1.0": ["bcrypt>=3.1.0"],
- "pillow>=3.1.2": ["PIL"],
- "sortedcontainers>=1.4.4": ["sortedcontainers"],
- "psutil>=2.0.0": ["psutil>=2.0.0"],
- "pysaml2>=3.0.0": ["saml2"],
- "pymacaroons-pynacl>=0.9.3": ["pymacaroons"],
- "msgpack-python>=0.4.2": ["msgpack"],
- "phonenumbers>=8.2.0": ["phonenumbers"],
- "six>=1.10": ["six"],
-
+ "pyopenssl>=16.0.0",
+ "pyyaml>=3.11",
+ "pyasn1>=0.1.9",
+ "pyasn1-modules>=0.0.7",
+ "daemonize>=2.3.1",
+ "bcrypt>=3.1.0",
+ "pillow>=3.1.2",
+ "sortedcontainers>=1.4.4",
+ "psutil>=2.0.0",
+ "pymacaroons-pynacl>=0.9.3",
+ "msgpack-python>=0.4.2",
+ "phonenumbers>=8.2.0",
+ "six>=1.10",
# prometheus_client 0.4.0 changed the format of counter metrics
# (cf https://github.com/matrix-org/synapse/issues/4001)
- "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"],
-
+ "prometheus_client>=0.0.18,<0.4.0",
# we use attr.s(slots), which arrived in 16.0.0
- "attrs>=16.0.0": ["attr>=16.0.0"],
- "netaddr>=0.7.18": ["netaddr"],
-}
+ "attrs>=16.0.0",
+ "netaddr>=0.7.18",
+]
CONDITIONAL_REQUIREMENTS = {
- "email.enable_notifs": {
- "Jinja2>=2.8": ["Jinja2>=2.8"],
- "bleach>=1.4.2": ["bleach>=1.4.2"],
- },
- "matrix-synapse-ldap3": {
- "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
- },
- "postgres": {
- "psycopg2>=2.6": ["psycopg2"]
- }
-}
-
-
-def requirements(config=None, include_conditional=False):
- reqs = REQUIREMENTS.copy()
- if include_conditional:
- for _, req in CONDITIONAL_REQUIREMENTS.items():
- reqs.update(req)
- return reqs
+ "email.enable_notifs": ["Jinja2>=2.9", "bleach>=1.4.2"],
+ "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
+ "postgres": ["psycopg2>=2.6"],
+ # ConsentResource uses select_autoescape, which arrived in jinja 2.9
+ "resources.consent": ["Jinja2>=2.9"],
-def github_link(project, version, egg):
- return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
-
-
-DEPENDENCY_LINKS = {
+ "saml2": ["pysaml2>=4.5.0"],
+ "url_preview": ["lxml>=3.5.0"],
+ "test": ["mock>=2.0"],
}
-class MissingRequirementError(Exception):
- def __init__(self, message, module_name, dependency):
- super(MissingRequirementError, self).__init__(message)
- self.module_name = module_name
- self.dependency = dependency
-
-
-def check_requirements(config=None):
- """Checks that all the modules needed by synapse have been correctly
- installed and are at the correct version"""
- for dependency, module_requirements in (
- requirements(config, include_conditional=False).items()):
- for module_requirement in module_requirements:
- if ">=" in module_requirement:
- module_name, required_version = module_requirement.split(">=")
- version_test = ">="
- elif "==" in module_requirement:
- module_name, required_version = module_requirement.split("==")
- version_test = "=="
- else:
- module_name = module_requirement
- version_test = None
-
- try:
- module = __import__(module_name)
- except ImportError:
- logging.exception(
- "Can't import %r which is part of %r",
- module_name, dependency
- )
- raise MissingRequirementError(
- "Can't import %r which is part of %r"
- % (module_name, dependency), module_name, dependency
- )
- version = getattr(module, "__version__", None)
- file_path = getattr(module, "__file__", None)
- logger.info(
- "Using %r version %r from %r to satisfy %r",
- module_name, version, file_path, dependency
+def list_requirements():
+ deps = set(REQUIREMENTS)
+ for opt in CONDITIONAL_REQUIREMENTS.values():
+ deps = set(opt) | deps
+
+ return list(deps)
+
+
+class DependencyException(Exception):
+ @property
+ def message(self):
+ return "\n".join([
+ "Missing Requirements: %s" % (", ".join(self.dependencies),),
+ "To install run:",
+ " pip install --upgrade --force %s" % (" ".join(self.dependencies),),
+ "",
+ ])
+
+ @property
+ def dependencies(self):
+ for i in self.args[0]:
+ yield '"' + i + '"'
+
+
+def check_requirements(for_feature=None, _get_distribution=get_distribution):
+ deps_needed = []
+ errors = []
+
+ if for_feature:
+ reqs = CONDITIONAL_REQUIREMENTS[for_feature]
+ else:
+ reqs = REQUIREMENTS
+
+ for dependency in reqs:
+ try:
+ _get_distribution(dependency)
+ except VersionConflict as e:
+ deps_needed.append(dependency)
+ errors.append(
+ "Needed %s, got %s==%s"
+ % (dependency, e.dist.project_name, e.dist.version)
)
+ except DistributionNotFound:
+ deps_needed.append(dependency)
+ errors.append("Needed %s but it was not installed" % (dependency,))
- if version_test == ">=":
- if version is None:
- raise MissingRequirementError(
- "Version of %r isn't set as __version__ of module %r"
- % (dependency, module_name), module_name, dependency
- )
- if LooseVersion(version) < LooseVersion(required_version):
- raise MissingRequirementError(
- "Version of %r in %r is too old. %r < %r"
- % (dependency, file_path, version, required_version),
- module_name, dependency
- )
- elif version_test == "==":
- if version is None:
- raise MissingRequirementError(
- "Version of %r isn't set as __version__ of module %r"
- % (dependency, module_name), module_name, dependency
- )
- if LooseVersion(version) != LooseVersion(required_version):
- raise MissingRequirementError(
- "Unexpected version of %r in %r. %r != %r"
- % (dependency, file_path, version, required_version),
- module_name, dependency
- )
+ if not for_feature:
+ # Check the optional dependencies are up to date. We allow them to not be
+ # installed.
+ OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
+ for dependency in OPTS:
+ try:
+ _get_distribution(dependency)
+ except VersionConflict:
+ deps_needed.append(dependency)
+ errors.append("Needed %s but it was not installed" % (dependency,))
+ except DistributionNotFound:
+ # If it's not found, we don't care
+ pass
-def list_requirements():
- result = []
- linked = []
- for link in DEPENDENCY_LINKS.values():
- egg = link.split("#egg=")[1]
- linked.append(egg.split('-')[0])
- result.append(link)
- for requirement in requirements(include_conditional=True):
- is_linked = False
- for link in linked:
- if requirement.replace('-', '_').startswith(link):
- is_linked = True
- if not is_linked:
- result.append(requirement)
- return result
+ if deps_needed:
+ for e in errors:
+ logging.error(e)
+
+ raise DependencyException(deps_needed)
if __name__ == "__main__":
import sys
+
sys.stdout.writelines(req + "\n" for req in list_requirements())
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 5f35c2d1be..66585c991f 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import PY3
-
from synapse.http.server import JsonResource
from synapse.rest.client import versions
from synapse.rest.client.v1 import (
@@ -56,11 +54,6 @@ from synapse.rest.client.v2_alpha import (
user_directory,
)
-if not PY3:
- from synapse.rest.client.v1_only import (
- register as v1_register,
- )
-
class ClientRestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
@@ -73,10 +66,6 @@ class ClientRestResource(JsonResource):
def register_servlets(client_resource, hs):
versions.register_servlets(client_resource)
- if not PY3:
- # "v1" (Python 2 only)
- v1_register.register_servlets(hs, client_resource)
-
# Deprecated in r0
initial_sync.register_servlets(hs, client_resource)
room.register_deprecated_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index b9c3bc4f9f..2e303264f6 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -23,7 +23,7 @@ from six.moves import http_client
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
assert_params_in_dict,
@@ -158,6 +158,11 @@ class UserRegisterServlet(ClientV1RestServlet):
raise SynapseError(400, "Invalid password")
admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
got_mac = body["mac"]
want_mac = hmac.new(
@@ -171,6 +176,9 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac.update(b"\x00")
+ want_mac.update(user_type.encode('utf8'))
want_mac = want_mac.hexdigest()
if not hmac.compare_digest(
@@ -189,6 +197,7 @@ class UserRegisterServlet(ClientV1RestServlet):
password=body["password"],
admin=bool(admin),
generate_token=False,
+ user_type=user_type,
)
result = yield register._create_registration_details(user_id, body)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6b4a85e40..942e4d3816 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,17 +18,18 @@ import xml.etree.ElementTree as ET
from six.moves import urllib
-from canonicaljson import json
-from saml2 import BINDING_HTTP_POST, config
-from saml2.client import Saml2Client
-
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.well_known import WellKnownBuilder
+from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
@@ -81,7 +82,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
@@ -89,8 +89,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
- self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
@@ -98,13 +96,12 @@ class LoginRestServlet(ClientV1RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
+ self._well_known_builder = WellKnownBuilder(hs)
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -134,29 +131,21 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
- relay_state = ""
- if "relay_state" in login_submission:
- relay_state = "&RelayState=" + urllib.parse.quote(
- login_submission["relay_state"])
- result = {
- "uri": "%s%s" % (self.idp_redirect_url, relay_state)
- }
- defer.returnValue((200, result))
- elif self.jwt_enabled and (login_submission["type"] ==
- LoginRestServlet.JWT_TYPE):
+ if self.jwt_enabled and (login_submission["type"] ==
+ LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
- defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
- defer.returnValue(result)
else:
result = yield self._do_other_login(login_submission)
- defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
+ well_known_data = self._well_known_builder.get_well_known()
+ if well_known_data:
+ result["well_known"] = well_known_data
+ defer.returnValue((200, result))
+
@defer.inlineCallbacks
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
@@ -165,7 +154,7 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission:
Returns:
- (int, object): HTTP code/response
+ dict: HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -248,7 +237,7 @@ class LoginRestServlet(ClientV1RestServlet):
if callback is not None:
yield callback(result)
- defer.returnValue((200, result))
+ defer.returnValue(result)
@defer.inlineCallbacks
def do_token_login(self, login_submission):
@@ -268,7 +257,7 @@ class LoginRestServlet(ClientV1RestServlet):
"device_id": device_id,
}
- defer.returnValue((200, result))
+ defer.returnValue(result)
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
@@ -322,7 +311,7 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
}
- defer.returnValue((200, result))
+ defer.returnValue(result)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
@@ -345,50 +334,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
-class SAML2RestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/saml2", releases=())
-
- def __init__(self, hs):
- super(SAML2RestServlet, self).__init__(hs)
- self.sp_config = hs.config.saml2_config_path
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- saml2_auth = None
- try:
- conf = config.SPConfig()
- conf.load_file(self.sp_config)
- SP = Saml2Client(conf)
- saml2_auth = SP.parse_authn_request_response(
- request.args['SAMLResponse'][0], BINDING_HTTP_POST)
- except Exception as e: # Not authenticated
- logger.exception(e)
- if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
- username = saml2_auth.name_id.text
- handler = self.handlers.registration_handler
- (user_id, token) = yield handler.register_saml2(username)
- # Forward to the RelayState callback along with ava
- if 'RelayState' in request.args:
- request.redirect(urllib.parse.unquote(
- request.args['RelayState'][0]) +
- '?status=authenticated&access_token=' +
- token + '&user_id=' + user_id + '&ava=' +
- urllib.quote(json.dumps(saml2_auth.ava)))
- finish_request(request)
- defer.returnValue(None)
- defer.returnValue((200, {"status": "authenticated",
- "user_id": user_id, "token": token,
- "ava": saml2_auth.ava}))
- elif 'RelayState' in request.args:
- request.redirect(urllib.parse.unquote(
- request.args['RelayState'][0]) +
- '?status=not_authenticated')
- finish_request(request)
- defer.returnValue(None)
- defer.returnValue((200, {"status": "not_authenticated"}))
-
-
class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
@@ -421,17 +366,15 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
- self.auth_handler = hs.get_auth_handler()
- self.handlers = hs.get_handlers()
- self.macaroon_gen = hs.get_macaroon_generator()
+ self._sso_auth_handler = SSOAuthHandler(hs)
@defer.inlineCallbacks
def on_GET(self, request):
- client_redirect_url = request.args[b"redirectUrl"][0]
+ client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
- "ticket": request.args[b"ticket"][0].decode('ascii'),
+ "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
@@ -443,7 +386,6 @@ class CasTicketServlet(ClientV1RestServlet):
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
- @defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
@@ -459,28 +401,9 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID(user, self.hs.hostname).to_string()
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id, _ = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
-
- login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ return self._sso_auth_handler.on_successful_auth(
+ user, request, client_redirect_url,
)
- redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
- login_token)
- request.redirect(redirect_url)
- finish_request(request)
-
- def add_login_token_to_redirect_url(self, url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
- return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None
@@ -515,10 +438,78 @@ class CasTicketServlet(ClientV1RestServlet):
return user, attributes
+class SSOAuthHandler(object):
+ """
+ Utility class for Resources and Servlets which handle the response from a SSO
+ service
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ def __init__(self, hs):
+ self._hostname = hs.hostname
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_handlers().registration_handler
+ self._macaroon_gen = hs.get_macaroon_generator()
+
+ @defer.inlineCallbacks
+ def on_successful_auth(
+ self, username, request, client_redirect_url,
+ user_display_name=None,
+ ):
+ """Called once the user has successfully authenticated with the SSO.
+
+ Registers the user if necessary, and then returns a redirect (with
+ a login token) to the client.
+
+ Args:
+ username (unicode|bytes): the remote user id. We'll map this onto
+ something sane for a MXID localpath.
+
+ request (SynapseRequest): the incoming request from the browser. We'll
+ respond to it with a redirect.
+
+ client_redirect_url (unicode): the redirect_url the client gave us when
+ it first started the process.
+
+ user_display_name (unicode|None): if set, and we have to register a new user,
+ we will set their displayname to this.
+
+ Returns:
+ Deferred[none]: Completes once we have handled the request.
+ """
+ localpart = map_username_to_mxid_localpart(username)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id, _ = (
+ yield self._registration_handler.register(
+ localpart=localpart,
+ generate_token=False,
+ default_display_name=user_display_name,
+ )
+ )
+
+ login_token = self._macaroon_gen.generate_short_term_login_token(
+ registered_user_id
+ )
+ redirect_url = self._add_login_token_to_redirect_url(
+ client_redirect_url, login_token
+ )
+ request.redirect(redirect_url)
+ finish_request(request)
+
+ @staticmethod
+ def _add_login_token_to_redirect_url(url, token):
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"loginToken": token})
+ url_parts[4] = urllib.parse.urlencode(query)
+ return urllib.parse.urlunparse(url_parts)
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
- if hs.config.saml2_enabled:
- SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py
deleted file mode 100644
index 936f902ace..0000000000
--- a/synapse/rest/client/v1_only/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""
-REST APIs that are only used in v1 (the legacy API).
-"""
diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py
deleted file mode 100644
index 9d4db7437c..0000000000
--- a/synapse/rest/client/v1_only/base.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains base REST classes for constructing client v1 servlets.
-"""
-
-import re
-
-from synapse.api.urls import CLIENT_PREFIX
-
-
-def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
- """Creates a regex compiled client path with the correct client path
- prefix.
-
- Args:
- path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
- Returns:
- list of SRE_Pattern
- """
- patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
- if include_in_unstable:
- unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
- patterns.append(re.compile("^" + unstable_prefix + path_regex))
- return patterns
diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py
deleted file mode 100644
index dadb376b02..0000000000
--- a/synapse/rest/client/v1_only/register.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains REST servlets to do with registration: /register"""
-import hmac
-import logging
-from hashlib import sha1
-
-from twisted.internet import defer
-
-import synapse.util.stringutils as stringutils
-from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError
-from synapse.config.server import is_threepid_reserved
-from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
-from synapse.rest.client.v1.base import ClientV1RestServlet
-from synapse.types import create_requester
-
-from .base import v1_only_client_path_patterns
-
-logger = logging.getLogger(__name__)
-
-
-# We ought to be using hmac.compare_digest() but on older pythons it doesn't
-# exist. It's a _really minor_ security flaw to use plain string comparison
-# because the timing attack is so obscured by all the other code here it's
-# unlikely to make much difference
-if hasattr(hmac, "compare_digest"):
- compare_digest = hmac.compare_digest
-else:
- def compare_digest(a, b):
- return a == b
-
-
-class RegisterRestServlet(ClientV1RestServlet):
- """Handles registration with the home server.
-
- This servlet is in control of the registration flow; the registration
- handler doesn't have a concept of multi-stages or sessions.
- """
-
- PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(RegisterRestServlet, self).__init__(hs)
- # sessions are stored as:
- # self.sessions = {
- # "session_id" : { __session_dict__ }
- # }
- # TODO: persistent storage
- self.sessions = {}
- self.enable_registration = hs.config.enable_registration
- self.auth = hs.get_auth()
- self.auth_handler = hs.get_auth_handler()
- self.handlers = hs.get_handlers()
-
- def on_GET(self, request):
-
- require_email = 'email' in self.hs.config.registrations_require_3pid
- require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
-
- flows = []
- if self.hs.config.enable_registration_captcha:
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([
- {
- "type": LoginType.RECAPTCHA,
- "stages": [
- LoginType.RECAPTCHA,
- LoginType.EMAIL_IDENTITY,
- LoginType.PASSWORD
- ]
- },
- ])
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([
- {
- "type": LoginType.RECAPTCHA,
- "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
- }
- ])
- else:
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if require_email or not require_msisdn:
- flows.extend([
- {
- "type": LoginType.EMAIL_IDENTITY,
- "stages": [
- LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
- ]
- }
- ])
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([
- {
- "type": LoginType.PASSWORD
- }
- ])
- return (200, {"flows": flows})
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- register_json = parse_json_object_from_request(request)
-
- session = (register_json["session"]
- if "session" in register_json else None)
- login_type = None
- assert_params_in_dict(register_json, ["type"])
-
- try:
- login_type = register_json["type"]
-
- is_application_server = login_type == LoginType.APPLICATION_SERVICE
- can_register = (
- self.enable_registration
- or is_application_server
- )
- if not can_register:
- raise SynapseError(403, "Registration has been disabled")
-
- stages = {
- LoginType.RECAPTCHA: self._do_recaptcha,
- LoginType.PASSWORD: self._do_password,
- LoginType.EMAIL_IDENTITY: self._do_email_identity,
- LoginType.APPLICATION_SERVICE: self._do_app_service,
- }
-
- session_info = self._get_session_info(request, session)
- logger.debug("%s : session info %s request info %s",
- login_type, session_info, register_json)
- response = yield stages[login_type](
- request,
- register_json,
- session_info
- )
-
- if "access_token" not in response:
- # isn't a final response
- response["session"] = session_info["id"]
-
- defer.returnValue((200, response))
- except KeyError as e:
- logger.exception(e)
- raise SynapseError(400, "Missing JSON keys for login type %s." % (
- login_type,
- ))
-
- def on_OPTIONS(self, request):
- return (200, {})
-
- def _get_session_info(self, request, session_id):
- if not session_id:
- # create a new session
- while session_id is None or session_id in self.sessions:
- session_id = stringutils.random_string(24)
- self.sessions[session_id] = {
- "id": session_id,
- LoginType.EMAIL_IDENTITY: False,
- LoginType.RECAPTCHA: False
- }
-
- return self.sessions[session_id]
-
- def _save_session(self, session):
- # TODO: Persistent storage
- logger.debug("Saving session %s", session)
- self.sessions[session["id"]] = session
-
- def _remove_session(self, session):
- logger.debug("Removing session %s", session)
- self.sessions.pop(session["id"])
-
- @defer.inlineCallbacks
- def _do_recaptcha(self, request, register_json, session):
- if not self.hs.config.enable_registration_captcha:
- raise SynapseError(400, "Captcha not required.")
-
- yield self._check_recaptcha(request, register_json, session)
-
- session[LoginType.RECAPTCHA] = True # mark captcha as done
- self._save_session(session)
- defer.returnValue({
- "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY]
- })
-
- @defer.inlineCallbacks
- def _check_recaptcha(self, request, register_json, session):
- if ("captcha_bypass_hmac" in register_json and
- self.hs.config.captcha_bypass_secret):
- if "user" not in register_json:
- raise SynapseError(400, "Captcha bypass needs 'user'")
-
- want = hmac.new(
- key=self.hs.config.captcha_bypass_secret,
- msg=register_json["user"],
- digestmod=sha1,
- ).hexdigest()
-
- # str() because otherwise hmac complains that 'unicode' does not
- # have the buffer interface
- got = str(register_json["captcha_bypass_hmac"])
-
- if compare_digest(want, got):
- session["user"] = register_json["user"]
- defer.returnValue(None)
- else:
- raise SynapseError(
- 400, "Captcha bypass HMAC incorrect",
- errcode=Codes.CAPTCHA_NEEDED
- )
-
- challenge = None
- user_response = None
- try:
- challenge = register_json["challenge"]
- user_response = register_json["response"]
- except KeyError:
- raise SynapseError(400, "Captcha response is required",
- errcode=Codes.CAPTCHA_NEEDED)
-
- ip_addr = self.hs.get_ip_from_request(request)
-
- handler = self.handlers.registration_handler
- yield handler.check_recaptcha(
- ip_addr,
- self.hs.config.recaptcha_private_key,
- challenge,
- user_response
- )
-
- @defer.inlineCallbacks
- def _do_email_identity(self, request, register_json, session):
- if (self.hs.config.enable_registration_captcha and
- not session[LoginType.RECAPTCHA]):
- raise SynapseError(400, "Captcha is required.")
-
- threepidCreds = register_json['threepidCreds']
- handler = self.handlers.registration_handler
- logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
- yield handler.register_email(threepidCreds)
- session["threepidCreds"] = threepidCreds # store creds for next stage
- session[LoginType.EMAIL_IDENTITY] = True # mark email as done
- self._save_session(session)
- defer.returnValue({
- "next": LoginType.PASSWORD
- })
-
- @defer.inlineCallbacks
- def _do_password(self, request, register_json, session):
- if (self.hs.config.enable_registration_captcha and
- not session[LoginType.RECAPTCHA]):
- # captcha should've been done by this stage!
- raise SynapseError(400, "Captcha is required.")
-
- if ("user" in session and "user" in register_json and
- session["user"] != register_json["user"]):
- raise SynapseError(
- 400, "Cannot change user ID during registration"
- )
-
- password = register_json["password"].encode("utf-8")
- desired_user_id = (
- register_json["user"].encode("utf-8")
- if "user" in register_json else None
- )
- threepid = None
- if session.get(LoginType.EMAIL_IDENTITY):
- threepid = session["threepidCreds"]
-
- handler = self.handlers.registration_handler
- (user_id, token) = yield handler.register(
- localpart=desired_user_id,
- password=password,
- threepid=threepid,
- )
- # Necessary due to auth checks prior to the threepid being
- # written to the db
- if is_threepid_reserved(self.hs.config, threepid):
- yield self.store.upsert_monthly_active_user(user_id)
-
- if session[LoginType.EMAIL_IDENTITY]:
- logger.debug("Binding emails %s to %s" % (
- session["threepidCreds"], user_id)
- )
- yield handler.bind_emails(user_id, session["threepidCreds"])
-
- result = {
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- }
- self._remove_session(session)
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _do_app_service(self, request, register_json, session):
- as_token = self.auth.get_access_token_from_request(request)
-
- assert_params_in_dict(register_json, ["user"])
- user_localpart = register_json["user"].encode("utf-8")
-
- handler = self.handlers.registration_handler
- user_id = yield handler.appservice_register(
- user_localpart, as_token
- )
- token = yield self.auth_handler.issue_access_token(user_id)
- self._remove_session(session)
- defer.returnValue({
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- })
-
-
-class CreateUserRestServlet(ClientV1RestServlet):
- """Handles user creation via a server-to-server interface
- """
-
- PATTERNS = v1_only_client_path_patterns("/createUser$")
-
- def __init__(self, hs):
- super(CreateUserRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- user_json = parse_json_object_from_request(request)
-
- access_token = self.auth.get_access_token_from_request(request)
- app_service = self.store.get_app_service_by_token(
- access_token
- )
- if not app_service:
- raise SynapseError(403, "Invalid application service token.")
-
- requester = create_requester(app_service.sender)
-
- logger.debug("creating user: %s", user_json)
- response = yield self._do_create(requester, user_json)
-
- defer.returnValue((200, response))
-
- def on_OPTIONS(self, request):
- return 403, {}
-
- @defer.inlineCallbacks
- def _do_create(self, requester, user_json):
- assert_params_in_dict(user_json, ["localpart", "displayname"])
-
- localpart = user_json["localpart"].encode("utf-8")
- displayname = user_json["displayname"].encode("utf-8")
- password_hash = user_json["password_hash"].encode("utf-8") \
- if user_json.get("password_hash") else None
-
- handler = self.handlers.registration_handler
- user_id, token = yield handler.get_or_create_user(
- requester=requester,
- localpart=localpart,
- displayname=displayname,
- password_hash=password_hash
- )
-
- defer.returnValue({
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- })
-
-
-def register_servlets(hs, http_server):
- RegisterRestServlet(hs).register(http_server)
- CreateUserRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 371e9aa354..f171b8d626 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class AccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
+ GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
@@ -57,10 +58,26 @@ class AccountDataServlet(RestServlet):
defer.returnValue((200, {}))
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, account_data_type):
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Cannot get account data for other users.")
+
+ event = yield self.store.get_global_account_data_by_type_for_user(
+ account_data_type, user_id,
+ )
+
+ if event is None:
+ raise NotFoundError("Account data not found")
+
+ defer.returnValue((200, event))
+
class RoomAccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
+ GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)"
@@ -99,6 +116,21 @@ class RoomAccountDataServlet(RestServlet):
defer.returnValue((200, {}))
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, room_id, account_data_type):
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Cannot get account data for other users.")
+
+ event = yield self.store.get_account_data_for_room_and_type(
+ user_id, room_id, account_data_type,
+ )
+
+ if event is None:
+ raise NotFoundError("Room account data not found")
+
+ defer.returnValue((200, event))
+
def register_servlets(hs, http_server):
AccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index d6605b6027..77316033f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -41,7 +41,7 @@ class MediaConfigResource(Resource):
@defer.inlineCallbacks
def _async_render_GET(self, request):
yield self.auth.get_user_by_req(request)
- respond_with_json(request, 200, self.limits_dict)
+ respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index f911b120b1..bdc5daecc1 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -48,7 +48,8 @@ class DownloadResource(Resource):
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
- b"default-src 'none';"
+ b"sandbox;"
+ b" default-src 'none';"
b" script-src 'none';"
b" plugin-types application/pdf;"
b" style-src 'unsafe-inline';"
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e117836e9a..bdffa97805 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -30,6 +30,7 @@ from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
NotFoundError,
+ RequestSendFailed,
SynapseError,
)
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -372,10 +373,10 @@ class MediaRepository(object):
"allow_remote": "false",
}
)
- except twisted.internet.error.DNSLookupError as e:
- logger.warn("HTTP error fetching remote media %s/%s: %r",
+ except RequestSendFailed as e:
+ logger.warn("Request failed fetching remote media %s/%s: %r",
server_name, media_id, e)
- raise NotFoundError()
+ raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index d0ecf241b6..ba3ab1d37d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -35,7 +35,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.http.client import SpiderHttpClient
+from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
respond_with_json,
respond_with_json_bytes,
@@ -69,7 +69,12 @@ class PreviewUrlResource(Resource):
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
- self.client = SpiderHttpClient(hs)
+ self.client = SimpleHttpClient(
+ hs,
+ treq_args={"browser_like_redirects": True},
+ ip_whitelist=hs.config.url_preview_ip_range_whitelist,
+ ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ )
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
@@ -318,6 +323,11 @@ class PreviewUrlResource(Resource):
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e)
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/saml2/__init__.py
new file mode 100644
index 0000000000..68da37ca6a
--- /dev/null
+++ b/synapse/rest/saml2/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.saml2.response_resource import SAML2ResponseResource
+
+logger = logging.getLogger(__name__)
+
+
+class SAML2Resource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
+ self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py
new file mode 100644
index 0000000000..e8c680aeb4
--- /dev/null
+++ b/synapse/rest/saml2/metadata_resource.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import saml2.metadata
+
+from twisted.web.resource import Resource
+
+
+class SAML2MetadataResource(Resource):
+ """A Twisted web resource which renders the SAML metadata"""
+
+ isLeaf = 1
+
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.sp_config = hs.config.saml2_sp_config
+
+ def render_GET(self, request):
+ metadata_xml = saml2.metadata.create_metadata_string(
+ configfile=None, config=self.sp_config,
+ )
+ request.setHeader(b"Content-Type", b"text/xml; charset=utf-8")
+ return metadata_xml
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
new file mode 100644
index 0000000000..69fb77b322
--- /dev/null
+++ b/synapse/rest/saml2/response_resource.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import saml2
+from saml2.client import Saml2Client
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.api.errors import CodeMessageException
+from synapse.http.server import wrap_html_request_handler
+from synapse.http.servlet import parse_string
+from synapse.rest.client.v1.login import SSOAuthHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SAML2ResponseResource(Resource):
+ """A Twisted web resource which handles the SAML response"""
+
+ isLeaf = 1
+
+ def __init__(self, hs):
+ Resource.__init__(self)
+
+ self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+ self._sso_auth_handler = SSOAuthHandler(hs)
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ def _async_render_POST(self, request):
+ resp_bytes = parse_string(request, 'SAMLResponse', required=True)
+ relay_state = parse_string(request, 'RelayState', required=True)
+
+ try:
+ saml2_auth = self._saml_client.parse_authn_request_response(
+ resp_bytes, saml2.BINDING_HTTP_POST,
+ )
+ except Exception as e:
+ logger.warning("Exception parsing SAML2 response", exc_info=1)
+ raise CodeMessageException(
+ 400, "Unable to parse SAML2 response: %s" % (e,),
+ )
+
+ if saml2_auth.not_signed:
+ raise CodeMessageException(400, "SAML2 response was not signed")
+
+ if "uid" not in saml2_auth.ava:
+ raise CodeMessageException(400, "uid not in SAML2 response")
+
+ username = saml2_auth.ava["uid"][0]
+
+ displayName = saml2_auth.ava.get("displayName", [None])[0]
+ return self._sso_auth_handler.on_successful_auth(
+ username, request, relay_state,
+ user_display_name=displayName,
+ )
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
new file mode 100644
index 0000000000..6e043d6162
--- /dev/null
+++ b/synapse/rest/well_known.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+
+from twisted.web.resource import Resource
+
+logger = logging.getLogger(__name__)
+
+
+class WellKnownBuilder(object):
+ """Utility to construct the well-known response
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ def __init__(self, hs):
+ self._config = hs.config
+
+ def get_well_known(self):
+ # if we don't have a public_base_url, we can't help much here.
+ if self._config.public_baseurl is None:
+ return None
+
+ result = {
+ "m.homeserver": {
+ "base_url": self._config.public_baseurl,
+ },
+ }
+
+ if self._config.default_identity_server:
+ result["m.identity_server"] = {
+ "base_url": self._config.default_identity_server,
+ }
+
+ return result
+
+
+class WellKnownResource(Resource):
+ """A Twisted web resource which renders the .well-known file"""
+
+ isLeaf = 1
+
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self._well_known_builder = WellKnownBuilder(hs)
+
+ def render_GET(self, request):
+ r = self._well_known_builder.get_well_known()
+ if not r:
+ request.setResponseCode(404)
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b'.well-known not available'
+
+ logger.error("returning: %s", r)
+ request.setHeader(b"Content-Type", b"application/json")
+ return json.dumps(r).encode("utf-8")
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 70048b0c09..e9ecb00277 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -607,7 +607,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
return v1.resolve_events_with_store(
state_sets, event_map, state_res_store.get_events,
)
- elif room_version in (RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST):
+ elif room_version in (
+ RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
+ ):
return v2.resolve_events_with_store(
state_sets, event_map, state_res_store,
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b23fb7e56c..24329879e5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,12 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import datetime
+import calendar
import logging
import time
-from dateutil import tz
-
from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
@@ -357,10 +355,11 @@ class DataStore(RoomMemberStore, RoomStore,
"""
Returns millisecond unixtime for start of UTC day.
"""
- now = datetime.datetime.utcnow()
- today_start = datetime.datetime(now.year, now.month,
- now.day, tzinfo=tz.tzutc())
- return int(time.mktime(today_start.timetuple())) * 1000
+ now = time.gmtime()
+ today_start = calendar.timegm((
+ now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
+ ))
+ return today_start * 1000
def generate_user_daily_visits(self):
"""
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 479e01ddc1..d6fc8edd4c 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -55,9 +55,12 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn,
tp["medium"], tp["address"]
)
+
if user_id:
- self.upsert_monthly_active_user_txn(txn, user_id)
- reserved_user_list.append(user_id)
+ is_support = self.is_support_user_txn(txn, user_id)
+ if not is_support:
+ self.upsert_monthly_active_user_txn(txn, user_id)
+ reserved_user_list.append(user_id)
else:
logger.warning(
"mau limit reserved threepid %s not found in db" % tp
@@ -182,6 +185,18 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id (str): user to add/update
"""
+ # Support user never to be included in MAU stats. Note I can't easily call this
+ # from upsert_monthly_active_user_txn because then I need a _txn form of
+ # is_support_user which is complicated because I want to cache the result.
+ # Therefore I call it here and ignore the case where
+ # upsert_monthly_active_user_txn is called directly from
+ # _initialise_reserved_users reasoning that it would be very strange to
+ # include a support user in this context.
+
+ is_support = yield self.is_support_user(user_id)
+ if is_support:
+ return
+
is_insert = yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id
@@ -200,6 +215,16 @@ class MonthlyActiveUsersStore(SQLBaseStore):
in a database thread rather than the main thread, and we can't call
txn.call_after because txn may not be a LoggingTransaction.
+ We consciously do not call is_support_txn from this method because it
+ is not possible to cache the response. is_support_txn will be false in
+ almost all cases, so it seems reasonable to call it only for
+ upsert_monthly_active_user and to call is_support_txn manually
+ for cases where upsert_monthly_active_user_txn is called directly,
+ like _initialise_reserved_users
+
+ In short, don't call this method with support users. (Support users
+ should not appear in the MAU stats).
+
Args:
txn (cursor):
user_id (str): user to add/update
@@ -208,6 +233,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
bool: True if a new entry was created, False if an
existing one was updated.
"""
+
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 80d76bf9d7..c9e11c3135 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,9 +19,11 @@ from six.moves import range
from twisted.internet import defer
+from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
+from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -112,6 +114,31 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
+ @cachedInlineCallbacks()
+ def is_support_user(self, user_id):
+ """Determines if the user is of type UserTypes.SUPPORT
+
+ Args:
+ user_id (str): user id to test
+
+ Returns:
+ Deferred[bool]: True if user is of type UserTypes.SUPPORT
+ """
+ res = yield self.runInteraction(
+ "is_support_user", self.is_support_user_txn, user_id
+ )
+ defer.returnValue(res)
+
+ def is_support_user_txn(self, txn, user_id):
+ res = self._simple_select_one_onecol_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="user_type",
+ allow_none=True,
+ )
+ return True if res == UserTypes.SUPPORT else False
+
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@@ -167,7 +194,7 @@ class RegistrationStore(RegistrationWorkerStore,
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_localpart=None, admin=False):
+ create_profile_with_displayname=None, admin=False, user_type=None):
"""Attempts to register an account.
Args:
@@ -181,8 +208,12 @@ class RegistrationStore(RegistrationWorkerStore,
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_localpart (str): Optionally create a profile for
- the given localpart.
+ create_profile_with_displayname (unicode): Optionally create a profile for
+ the user, setting their displayname to the given value
+ admin (boolean): is an admin user?
+ user_type (str|None): type of user. One of the values from
+ api.constants.UserTypes, or None for a normal user.
+
Raises:
StoreError if the user_id could not be registered.
"""
@@ -195,8 +226,9 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest,
make_guest,
appservice_id,
- create_profile_with_localpart,
- admin
+ create_profile_with_displayname,
+ admin,
+ user_type
)
def _register(
@@ -208,9 +240,12 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest,
make_guest,
appservice_id,
- create_profile_with_localpart,
+ create_profile_with_displayname,
admin,
+ user_type,
):
+ user_id_obj = UserID.from_string(user_id)
+
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next()
@@ -244,6 +279,7 @@ class RegistrationStore(RegistrationWorkerStore,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
+ "user_type": user_type,
}
)
else:
@@ -257,6 +293,7 @@ class RegistrationStore(RegistrationWorkerStore,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
+ "user_type": user_type,
}
)
except self.database_engine.module.IntegrityError:
@@ -273,12 +310,15 @@ class RegistrationStore(RegistrationWorkerStore,
(next_id, user_id, token,)
)
- if create_profile_with_localpart:
+ if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
+ #
+ # *obviously* the 'profiles' table uses localpart for user_id
+ # while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
- (create_profile_with_localpart, create_profile_with_localpart)
+ (user_id_obj.localpart, create_profile_with_displayname)
)
self._invalidate_cache_and_stream(
diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/schema/delta/53/add_user_type_to_users.sql
new file mode 100644
index 0000000000..88ec2f83e5
--- /dev/null
+++ b/synapse/storage/schema/delta/53/add_user_type_to_users.sql
@@ -0,0 +1,19 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The type of the user: NULL for a regular user, or one of the constants in
+ * synapse.api.constants.UserTypes
+ */
+ALTER TABLE users ADD COLUMN user_type TEXT DEFAULT NULL;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d737bd6778..a134e9b3e8 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -432,7 +432,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_id = state_ids.get((EventTypes.Create, ""))
if not create_id:
- raise NotFoundError("Unknown room")
+ raise NotFoundError("Unknown room %s" % (room_id))
create_event = yield self.get_event(create_id)
defer.returnValue(create_event.content.get("room_version", "1"))
diff --git a/synapse/types.py b/synapse/types.py
index 41afb27a74..d8cb64addb 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
import string
from collections import namedtuple
@@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
+
+# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
+# localpart.
+#
+# It works by:
+# * building a string containing the allowed characters (excluding '=')
+# * escaping every special character with a backslash (to stop '-' being interpreted as a
+# range operator)
+# * wrapping it in a '[^...]' regex
+# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
+# bytes rather than strings
+#
+NON_MXID_CHARACTER_PATTERN = re.compile(
+ ("[^%s]" % (
+ re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
+ )).encode("ascii"),
+)
+
+
+def map_username_to_mxid_localpart(username, case_sensitive=False):
+ """Map a username onto a string suitable for a MXID
+
+ This follows the algorithm laid out at
+ https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
+
+ Args:
+ username (unicode|bytes): username to be mapped
+ case_sensitive (bool): true if TEST and test should be mapped
+ onto different mxids
+
+ Returns:
+ unicode: string suitable for a mxid localpart
+ """
+ if not isinstance(username, bytes):
+ username = username.encode('utf-8')
+
+ # first we sort out upper-case characters
+ if case_sensitive:
+ def f1(m):
+ return b"_" + m.group().lower()
+
+ username = UPPER_CASE_PATTERN.sub(f1, username)
+ else:
+ username = username.lower()
+
+ # then we sort out non-ascii characters
+ def f2(m):
+ g = m.group()[0]
+ if isinstance(g, str):
+ # on python 2, we need to do a ord(). On python 3, the
+ # byte itself will do.
+ g = ord(g)
+ return b"=%02x" % (g,)
+
+ username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
+
+ # we also do the =-escaping to mxids starting with an underscore.
+ username = re.sub(b'^_', b'=5f', username)
+
+ # we should now only have ascii bytes left, so can decode back to a
+ # unicode.
+ return username.decode('ascii')
+
+
class StreamToken(
namedtuple("Token", (
"room_key",
|