diff --git a/.travis.yml b/.travis.yml
index e6ba6f4752..a98d547978 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,7 +4,12 @@ language: python
# tell travis to cache ~/.cache/pip
cache: pip
+before_script:
+ - git remote set-branches --add origin develop
+ - git fetch origin develop
+
matrix:
+ fast_finish: true
include:
- python: 2.7
env: TOX_ENV=packaging
@@ -14,10 +19,13 @@ matrix:
- python: 2.7
env: TOX_ENV=py27
-
+
- python: 3.6
env: TOX_ENV=py36
+ - python: 3.6
+ env: TOX_ENV=check-newsfragment
+
install:
- pip install tox
diff --git a/CHANGES.rst b/CHANGES.rst
index f2b7f04097..70fc5af4c1 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,15 +1,36 @@
+Changes in synapse v0.31.2 (2018-06-14)
+=======================================
+
+SECURITY UPDATE: Prevent unauthorised users from setting state events in a room
+when there is no ``m.room.power_levels`` event in force in the room. (PR #3397)
+
+Discussion around the Matrix Spec change proposal for this change can be
+followed at https://github.com/matrix-org/matrix-doc/issues/1304.
+
+Changes in synapse v0.31.1 (2018-06-08)
+=======================================
+
+v0.31.1 fixes a security bug in the ``get_missing_events`` federation API
+where event visibility rules were not applied correctly.
+
+We are not aware of this being actively exploited, but please upgrade as soon
+as possible.
+
+Bug Fixes:
+
+* Fix event filtering in get_missing_events handler (PR #3371)
+
Changes in synapse v0.31.0 (2018-06-06)
=======================================
-Most notable change from v0.30.0 is to switch to python prometheus library to improve system
-stats reporting. WARNING this changes a number of prometheus metrics in a
+The most notable change from v0.30.0 is the switch to the python prometheus
+library to improve system stats reporting. WARNING: this changes a number of
+prometheus metrics in a
backwards-incompatible manner. For more details, see
`docs/metrics-howto.rst <docs/metrics-howto.rst#removal-of-deprecated-metrics--time-based-counters-becoming-histograms-in-0310>`_.
Bug Fixes:
* Fix metric documentation tables (PR #3341)
-* Fix LaterGuage error handling (694968f)
+* Fix LaterGauge error handling (694968f)
* Fix replication metrics (b7e7fd2)
Changes in synapse v0.31.0-rc1 (2018-06-04)
@@ -29,7 +50,6 @@ Changes:
* Remove users from user directory on deactivate (PR #3277)
* Avoid sending consent notice to guest users (PR #3288)
* disable CPUMetrics if no /proc/self/stat (PR #3299)
-* Add local and loopback IPv6 addresses to url_preview_ip_range_blacklist (PR #3312) Thanks to @thegcat!
* Consistently use six's iteritems and wrap lazy keys/values in list() if they're not meant to be lazy (PR #3307)
* Add private IPv6 addresses to example config for url preview blacklist (PR #3317) Thanks to @thegcat!
* Reduce stuck read-receipts: ignore depth when updating (PR #3318)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index c6ee16efc7..6c295cfbfe 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -48,6 +48,26 @@ Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
+Changelog
+~~~~~~~~~
+
+All changes, even minor ones, need a corresponding changelog
+entry. These are managed by Towncrier
+(https://github.com/hawkowl/towncrier).
+
+To create a changelog entry, make a new file in the ``changelog.d``
+directory, named in the format of ``issuenumberOrPR.type``. The type
+can be one of ``feature``, ``bugfix``, ``removal`` (also used for
+deprecations), or ``misc`` (for internal-only changes). The content of
+the file is your changelog entry, which can contain reStructuredText
+formatting. A note crediting contributors is welcome in changelogs for
+non-misc changes (the content of misc changes is not displayed).
+
+For example, a fix for a bug reported in #1234 would have its
+changelog entry in ``changelog.d/1234.bugfix``, and contain content
+like "The security levels of Florbs are now validated when
+received over federation. Contributed by Jane Matrix".
+
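As a concrete illustration of the workflow above, a changelog entry is just a plain file; a minimal sketch, reusing the hypothetical PR number and wording from the example in the text:

```python
# A minimal sketch of creating the changelog entry described above.
# The PR number (1234) and the entry text are hypothetical examples.
entry = (
    "The security levels of Florbs are now validated when "
    "received over federation. Contributed by Jane Matrix"
)
with open("changelog.d/1234.bugfix", "w") as f:
    f.write(entry)
```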
Attribution
~~~~~~~~~~~
@@ -110,11 +130,15 @@ If you agree to this for your contribution, then all that's needed is to
include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org>
-
-...using your real name; unfortunately pseudonyms and anonymous contributions
-can't be accepted. Git makes this trivial - just use the -s flag when you do
-``git commit``, having first set ``user.name`` and ``user.email`` git configs
-(which you should have done anyway :)
+
+We accept contributions under a legally identifiable name, such as
+your name on government documentation or common-law names (names
+claimed by legitimate usage or repute). Unfortunately, we cannot
+accept anonymous contributions at this time.
+
+Git allows you to add this signoff automatically when using the ``-s``
+flag to ``git commit``, which uses the name and email set in your
+``user.name`` and ``user.email`` git configs.
Conclusion
~~~~~~~~~~
diff --git a/MANIFEST.in b/MANIFEST.in
index e2a6623a63..97f57f443f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -29,5 +29,8 @@ exclude Dockerfile
exclude .dockerignore
recursive-exclude jenkins *.sh
+include pyproject.toml
+recursive-include changelog.d *
+
prune .github
prune demo/etc
diff --git a/changelog.d/3324.removal b/changelog.d/3324.removal
new file mode 100644
index 0000000000..11dc6a3d74
--- /dev/null
+++ b/changelog.d/3324.removal
@@ -0,0 +1 @@
+Remove was_forgotten_at
diff --git a/changelog.d/3327.bugfix b/changelog.d/3327.bugfix
new file mode 100644
index 0000000000..97e8c0a990
--- /dev/null
+++ b/changelog.d/3327.bugfix
@@ -0,0 +1 @@
+Strip access_token from outgoing requests
diff --git a/changelog.d/3332.misc b/changelog.d/3332.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3332.misc
diff --git a/changelog.d/3334.feature b/changelog.d/3334.feature
new file mode 100644
index 0000000000..71c98f7262
--- /dev/null
+++ b/changelog.d/3334.feature
@@ -0,0 +1 @@
+Cache factor override system for specific caches
\ No newline at end of file
diff --git a/changelog.d/3340.doc b/changelog.d/3340.doc
new file mode 100644
index 0000000000..8395564ec7
--- /dev/null
+++ b/changelog.d/3340.doc
@@ -0,0 +1 @@
+``doc/postgres.rst``: fix display of the last command block. Thanks to @ArchangeGabriel!
diff --git a/changelog.d/3341.misc b/changelog.d/3341.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3341.misc
diff --git a/changelog.d/3344.feature b/changelog.d/3344.feature
new file mode 100644
index 0000000000..ab2e4fcef4
--- /dev/null
+++ b/changelog.d/3344.feature
@@ -0,0 +1 @@
+Add metrics to track appservice transactions
diff --git a/changelog.d/3347.misc b/changelog.d/3347.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3347.misc
diff --git a/changelog.d/3348.misc b/changelog.d/3348.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3348.misc
diff --git a/changelog.d/3349.bugfix b/changelog.d/3349.bugfix
new file mode 100644
index 0000000000..aa45bab3ba
--- /dev/null
+++ b/changelog.d/3349.bugfix
@@ -0,0 +1 @@
+Redact AS tokens in logs
diff --git a/changelog.d/3355.bugfix b/changelog.d/3355.bugfix
new file mode 100644
index 0000000000..80105a0e95
--- /dev/null
+++ b/changelog.d/3355.bugfix
@@ -0,0 +1 @@
+Fix federation backfill from SQLite servers
diff --git a/changelog.d/3356.misc b/changelog.d/3356.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3356.misc
diff --git a/changelog.d/3363.bugfix b/changelog.d/3363.bugfix
new file mode 100644
index 0000000000..d8895195c2
--- /dev/null
+++ b/changelog.d/3363.bugfix
@@ -0,0 +1 @@
+Fix event-purge-by-ts admin API
diff --git a/changelog.d/3371.bugfix b/changelog.d/3371.bugfix
new file mode 100644
index 0000000000..553f2b126e
--- /dev/null
+++ b/changelog.d/3371.bugfix
@@ -0,0 +1 @@
+Fix event filtering in get_missing_events handler
diff --git a/changelog.d/3372.feature b/changelog.d/3372.feature
new file mode 100644
index 0000000000..7f58f3ccac
--- /dev/null
+++ b/changelog.d/3372.feature
@@ -0,0 +1 @@
+Try to log more helpful info when a sig verification fails
diff --git a/changelog.d/3446.misc b/changelog.d/3446.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3446.misc
diff --git a/changelog.d/3447.misc b/changelog.d/3447.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3447.misc
diff --git a/changelog.d/3456.bugfix b/changelog.d/3456.bugfix
new file mode 100644
index 0000000000..3310dcb3ff
--- /dev/null
+++ b/changelog.d/3456.bugfix
@@ -0,0 +1 @@
+Synapse is now stricter about accepting events whose prev_events it cannot retrieve.
diff --git a/changelog.d/3462.feature b/changelog.d/3462.feature
new file mode 100644
index 0000000000..305dbbeddd
--- /dev/null
+++ b/changelog.d/3462.feature
@@ -0,0 +1 @@
+Synapse now uses the best performing JSON encoder/decoder according to your runtime (simplejson on CPython, stdlib json on PyPy).
\ No newline at end of file
diff --git a/changelog.d/3465.feature b/changelog.d/3465.feature
new file mode 100644
index 0000000000..1a0b5abfb7
--- /dev/null
+++ b/changelog.d/3465.feature
@@ -0,0 +1 @@
+Add optional ip_range_whitelist param to AS registration files to lock AS IP access
diff --git a/changelog.d/3467.misc b/changelog.d/3467.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3467.misc
diff --git a/changelog.d/3470.bugfix b/changelog.d/3470.bugfix
new file mode 100644
index 0000000000..1308931191
--- /dev/null
+++ b/changelog.d/3470.bugfix
@@ -0,0 +1 @@
+Fix bug where synapse would explode when receiving unicode in HTTP User-Agent header
diff --git a/changelog.d/3480.feature b/changelog.d/3480.feature
new file mode 100644
index 0000000000..a21580943d
--- /dev/null
+++ b/changelog.d/3480.feature
@@ -0,0 +1 @@
+Reject invalid server names in federation requests
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 1c9c5a6bde..d17121a188 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -44,13 +44,26 @@ Deactivate Account
This API deactivates an account. It removes active access tokens, resets the
password, and deletes third-party IDs (to prevent the user requesting a
-password reset).
+password reset). It can also mark the user as GDPR-erased (stopping their data
+from being distributed further, and deleting it entirely if there are no other
+references to it).
The api is::
POST /_matrix/client/r0/admin/deactivate/<user_id>
-including an ``access_token`` of a server admin, and an empty request body.
+with a body of:
+
+.. code:: json
+
+ {
+ "erase": true
+ }
+
+including an ``access_token`` of a server admin.
+
+The ``erase`` parameter is optional and defaults to ``false``.
+An empty body may be passed for backwards compatibility.
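As a rough illustration of the documented API, the call might look like the following; the homeserver URL, user ID and admin access token are placeholder values:

```python
# A sketch of calling the deactivate API with GDPR-erasure enabled.
# The homeserver URL, user ID and access token are placeholder values.
import requests

resp = requests.post(
    "https://homeserver.example.org/_matrix/client/r0/admin"
    "/deactivate/@user:example.org",
    params={"access_token": "ADMIN_ACCESS_TOKEN"},
    json={"erase": True},
)
resp.raise_for_status()
```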
Reset password
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..d1603b5d8b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,5 @@
+[tool.towncrier]
+ package = "synapse"
+ filename = "CHANGES.rst"
+ directory = "changelog.d"
+ issue_format = "`#{issue} <https://github.com/matrix-org/synapse/issues/{issue}>`_"
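With this configuration in place, the assembled changelog can be previewed without touching ``CHANGES.rst``; a minimal sketch, assuming ``towncrier`` is installed in the environment:

```python
# Preview the changelog that towncrier would build from changelog.d,
# without writing anything out (assumes towncrier is installed).
import subprocess

subprocess.check_call(["towncrier", "--draft"])
```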
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 3b28417376..d2acc7654d 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -18,14 +18,22 @@
from __future__ import print_function
import argparse
+from urlparse import urlparse, urlunparse
+
import nacl.signing
import json
import base64
import requests
import sys
+
+from requests.adapters import HTTPAdapter
import srvlookup
import yaml
+# uncomment the following to enable debug logging of http requests
+#from httplib import HTTPConnection
+#HTTPConnection.debuglevel = 1
+
def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding."""
@@ -113,17 +121,6 @@ def read_signing_keys(stream):
return keys
-def lookup(destination, path):
- if ":" in destination:
- return "https://%s%s" % (destination, path)
- else:
- try:
- srv = srvlookup.lookup("matrix", "tcp", destination)[0]
- return "https://%s:%d%s" % (srv.host, srv.port, path)
- except:
- return "https://%s:%d%s" % (destination, 8448, path)
-
-
def request_json(method, origin_name, origin_key, destination, path, content):
if method is None:
if content is None:
@@ -152,13 +149,19 @@ def request_json(method, origin_name, origin_key, destination, path, content):
authorization_headers.append(bytes(header))
print ("Authorization: %s" % header, file=sys.stderr)
- dest = lookup(destination, path)
+ dest = "matrix://%s%s" % (destination, path)
print ("Requesting %s" % dest, file=sys.stderr)
- result = requests.request(
+ s = requests.Session()
+ s.mount("matrix://", MatrixConnectionAdapter())
+
+ result = s.request(
method=method,
url=dest,
- headers={"Authorization": authorization_headers[0]},
+ headers={
+ "Host": destination,
+ "Authorization": authorization_headers[0]
+ },
verify=False,
data=content,
)
@@ -242,5 +245,39 @@ def read_args_from_config(args):
args.signing_key_path = config['signing_key_path']
+class MatrixConnectionAdapter(HTTPAdapter):
+ @staticmethod
+ def lookup(s):
+ if s[-1] == ']':
+ # ipv6 literal (with no port)
+ return s, 8448
+
+ if ":" in s:
+            out = s.rsplit(":", 1)
+ try:
+ port = int(out[1])
+ except ValueError:
+ raise ValueError("Invalid host:port '%s'" % s)
+ return out[0], port
+
+ try:
+ srv = srvlookup.lookup("matrix", "tcp", s)[0]
+ return srv.host, srv.port
+        except Exception:
+            # no SRV record (or the lookup failed): fall back to port 8448
+            return s, 8448
+
+ def get_connection(self, url, proxies=None):
+ parsed = urlparse(url)
+
+ (host, port) = self.lookup(parsed.netloc)
+ netloc = "%s:%d" % (host, port)
+ print("Connecting to %s" % (netloc,), file=sys.stderr)
+ url = urlunparse((
+ "https", netloc, parsed.path, parsed.params, parsed.query,
+ parsed.fragment,
+ ))
+ return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
+
+
if __name__ == "__main__":
main()
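A rough usage sketch of the new adapter outside ``request_json``; the destination ``example.org`` is a placeholder, and ``verify=False`` mirrors the script's existing behaviour:

```python
# A sketch of resolving and querying a homeserver via the adapter above.
# "example.org" is a placeholder destination.
import requests

s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter())
resp = s.get(
    "matrix://example.org/_matrix/federation/v1/version",
    headers={"Host": "example.org"},
    verify=False,
)
print(resp.json())
```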
diff --git a/synapse/__init__.py b/synapse/__init__.py
index ca113db434..faa183a99e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.31.0"
+__version__ = "0.31.2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 06fa38366d..088b4e8b6d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -19,6 +19,7 @@ from six import itervalues
import pymacaroons
from twisted.internet import defer
+from netaddr import IPAddress
import synapse.types
from synapse import event_auth
@@ -244,6 +245,11 @@ class Auth(object):
if app_service is None:
defer.returnValue((None, None))
+ if app_service.ip_range_whitelist:
+ ip_address = IPAddress(self.hs.get_ip_from_request(request))
+ if ip_address not in app_service.ip_range_whitelist:
+ defer.returnValue((None, None))
+
if "user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))
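The whitelist test above leans on ``netaddr`` membership semantics; a self-contained sketch with placeholder ranges:

```python
# A self-contained sketch of the whitelist check, with placeholder ranges.
from netaddr import IPAddress, IPSet

ip_range_whitelist = IPSet(["192.168.0.0/16", "127.0.0.1"])

assert IPAddress("192.168.1.5") in ip_range_whitelist
assert IPAddress("8.8.8.8") not in ip_range_whitelist
```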
@@ -488,7 +494,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
- logger.warn("Unrecognised access token - not in store: %s" % (token,))
+ logger.warn("Unrecognised access token - not in store.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
@@ -511,7 +517,7 @@ class Auth(object):
)
service = self.store.get_app_service_by_token(token)
if not service:
- logger.warn("Unrecognised appservice access token: %s" % (token,))
+ logger.warn("Unrecognised appservice access token.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
@@ -655,7 +661,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", auth_events
+ EventTypes.Aliases, "", power_level_event,
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index e6ad3768f0..227a0713b2 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,7 +17,8 @@
import logging
-import simplejson as json
+from canonicaljson import json
+
from six import iteritems
from six.moves import http_client
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index dbc0e7e445..aae25e7a47 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -17,7 +17,8 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID
from twisted.internet import defer
-import simplejson as json
+from canonicaljson import json
+
import jsonschema
from jsonschema import FormatChecker
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index dd114dee07..4319ddce03 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -23,6 +23,7 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
@@ -62,7 +63,7 @@ class AppserviceServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
root_resource = create_resource_tree(resources, NoResource())
@@ -97,7 +98,7 @@ class AppserviceServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 85dada7f9f..654ddb8414 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -122,7 +122,7 @@ class ClientReaderServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 5ca77c0f1a..441467093a 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -138,7 +138,7 @@ class EventCreatorServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 2a1995d0cd..b2415cc671 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -111,7 +111,7 @@ class FederationReaderServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 81ad574043..13d2b70053 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -125,7 +125,7 @@ class FederationSenderServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 5a164a7a95..d2bae4ad03 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -176,7 +176,7 @@ class FrontendProxyServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 714f98a3e0..ae5fc751d5 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -266,7 +266,7 @@ class SynapseHomeServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
@@ -318,11 +318,6 @@ def setup(config_options):
# check any extra requirements we have now we have a config
check_requirements(config)
- version_string = "Synapse/" + get_version_string(synapse)
-
- logger.info("Server hostname: %s", config.server_name)
- logger.info("Server version: %s", version_string)
-
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
@@ -335,7 +330,7 @@ def setup(config_options):
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
- version_string=version_string,
+ version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 006bba80a8..19a682cce3 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -118,7 +118,7 @@ class MediaRepositoryServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 64df47f9cc..13cfbd08b0 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -128,7 +128,7 @@ class PusherServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 6808d6d3e0..82f06ea185 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -305,7 +305,7 @@ class SynchrotronServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index ada1c13cec..f5726e3df6 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer):
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
- "collect_metrics is not enabled!"))
+ "enable_metrics is not True!"))
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index d1c598622a..328cbfa284 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -85,7 +85,8 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
- sender=None, id=None, protocols=None, rate_limited=True):
+ sender=None, id=None, protocols=None, rate_limited=True,
+ ip_range_whitelist=None):
self.token = token
self.url = url
self.hs_token = hs_token
@@ -93,6 +94,7 @@ class ApplicationService(object):
self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ self.ip_range_whitelist = ip_range_whitelist
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 277305e184..0c27bb2fa7 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -17,6 +17,8 @@ from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
+from netaddr import IPSet
+
import yaml
import logging
@@ -154,6 +156,13 @@ def _load_appservice(hostname, as_info, config_filename):
" will not receive events or queries.",
config_filename,
)
+
+ ip_range_whitelist = None
+ if as_info.get('ip_range_whitelist'):
+ ip_range_whitelist = IPSet(
+ as_info.get('ip_range_whitelist')
+ )
+
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@@ -163,5 +172,6 @@ def _load_appservice(hostname, as_info, config_filename):
sender=user_id,
id=as_info["id"],
protocols=protocols,
- rate_limited=rate_limited
+ rate_limited=rate_limited,
+ ip_range_whitelist=ip_range_whitelist,
)
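For reference, a sketch of a registration file using the new option, parsed the same way ``_load_appservice`` does (all values are placeholders):

```python
# A sketch of parsing an AS registration file that sets the new
# ip_range_whitelist option. All values are placeholders.
import yaml
from netaddr import IPSet

as_info = yaml.safe_load("""
id: example_bridge
as_token: AS_TOKEN
hs_token: HS_TOKEN
sender_localpart: bridgebot
namespaces: {}
ip_range_whitelist: ['192.168.0.0/16', '127.0.0.1']
""")

ip_range_whitelist = None
if as_info.get('ip_range_whitelist'):
    ip_range_whitelist = IPSet(as_info.get('ip_range_whitelist'))
```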
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 6a7228dc2f..557c270fbe 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -12,17 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from ._base import Config
-from synapse.util.logcontext import LoggingContextFilter
-from twisted.logger import globalLogBeginner, STDLibLogObserver
import logging
import logging.config
-import yaml
-from string import Template
import os
import signal
+from string import Template
+import sys
+from twisted.logger import STDLibLogObserver, globalLogBeginner
+import yaml
+
+import synapse
+from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.versionstring import get_version_string
+from ._base import Config
DEFAULT_LOG_CONFIG = Template("""
version: 1
@@ -202,6 +205,15 @@ def setup_logging(config, use_worker_options=False):
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
+ # make sure that the first thing we log is a thing we can grep backwards
+ # for
+ logging.warn("***** STARTING SERVER *****")
+ logging.warn(
+ "Server %s version %s",
+ sys.argv[0], get_version_string(synapse),
+ )
+ logging.info("Server hostname: %s", config.server_name)
+
# It's critical to point twisted's internal logging somewhere, otherwise it
    # stacks up and leaks up to 64K objects;
# see: https://twistedmatrix.com/trac/ticket/8164
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index f1fd488b90..2a0eddbea1 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,7 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
-import simplejson as json
+from canonicaljson import json
import logging
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 22ee0fc93f..9b17ef0a08 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -27,10 +27,12 @@ from synapse.util.metrics import Measure
from twisted.internet import defer
from signedjson.sign import (
- verify_signed_json, signature_ids, sign_json, encode_canonical_json
+ verify_signed_json, signature_ids, sign_json, encode_canonical_json,
+ SignatureVerifyException,
)
from signedjson.key import (
- is_signing_algorithm_supported, decode_verify_key_bytes
+ is_signing_algorithm_supported, decode_verify_key_bytes,
+ encode_verify_key_base64,
)
from unpaddedbase64 import decode_base64, encode_base64
@@ -56,7 +58,7 @@ Attributes:
    key_ids(set(str)): The set of key_ids that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
- deferred(twisted.internet.defer.Deferred):
+ deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
logcontext.
@@ -736,6 +738,17 @@ class Keyring(object):
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
+ """Waits for the key to become available, and then performs a verification
+
+ Args:
+ verify_request (VerifyKeyRequest):
+
+ Returns:
+ Deferred[None]
+
+ Raises:
+ SynapseError if there was a problem performing the verification
+ """
server_name = verify_request.server_name
try:
with PreserveLoggingContext():
@@ -768,11 +781,17 @@ def _handle_key_deferred(verify_request):
))
try:
verify_signed_json(json_object, server_name, verify_key)
- except Exception:
+ except SignatureVerifyException as e:
+ logger.debug(
+ "Error verifying signature for %s:%s:%s with key %s: %s",
+ server_name, verify_key.alg, verify_key.version,
+ encode_verify_key_base64(verify_key),
+ str(e),
+ )
raise SynapseError(
401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
+ "Invalid signature for server %s with key %s:%s: %s" % (
+ server_name, verify_key.alg, verify_key.version, str(e),
),
Codes.UNAUTHORIZED,
)
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index eaf9cecde6..f512d88145 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -34,9 +34,11 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
+ Raises:
+ AuthError if the checks fail
Returns:
- True if the auth checks pass.
+        None if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
@@ -71,7 +73,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
- return True
+ return
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
@@ -81,7 +83,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
"Creation event's room_id domain does not match sender's"
)
# FIXME
- return True
+ logger.debug("Allowing! %s", event)
+ return
creation_event = auth_events.get((EventTypes.Create, ""), None)
@@ -118,7 +121,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
403,
"Alias event's state_key does not match sender's domain"
)
- return True
+ logger.debug("Allowing! %s", event)
+ return
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -127,14 +131,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
if event.type == EventTypes.Member:
- allowed = _is_membership_change_allowed(
- event, auth_events
- )
- if allowed:
- logger.debug("Allowing! %s", event)
- else:
- logger.debug("Denying! %s", event)
- return allowed
+ _is_membership_change_allowed(event, auth_events)
+ logger.debug("Allowing! %s", event)
+ return
_check_event_sender_in_room(event, auth_events)
@@ -153,7 +152,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
)
else:
- return True
+ logger.debug("Allowing! %s", event)
+ return
_can_send_event(event, auth_events)
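Since ``check`` now signals failure by raising ``AuthError`` instead of returning ``False``, callers wanting the old boolean behaviour need a wrapper; a hypothetical sketch:

```python
# A hypothetical wrapper restoring the old boolean interface of
# event_auth.check(), which now raises AuthError on failure.
import logging

from synapse import event_auth
from synapse.api.errors import AuthError

logger = logging.getLogger(__name__)


def is_allowed(event, auth_events):
    try:
        event_auth.check(event, auth_events)
        return True
    except AuthError as e:
        logger.warn("Denying event %s: %s", event.event_id, e)
        return False
```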
@@ -200,7 +200,7 @@ def _is_membership_change_allowed(event, auth_events):
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
- return True
+ return
target_user_id = event.state_key
@@ -265,13 +265,13 @@ def _is_membership_change_allowed(event, auth_events):
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
- return True
+ return
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
- return True
+ return
if not caller_in_room: # caller isn't joined
raise AuthError(
@@ -334,8 +334,6 @@ def _is_membership_change_allowed(event, auth_events):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- return True
-
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
@@ -355,35 +353,46 @@ def _check_joined_room(member, user_id, room_id):
))
-def get_send_level(etype, state_key, auth_events):
- key = (EventTypes.PowerLevels, "", )
- send_level_event = auth_events.get(key)
- send_level = None
- if send_level_event:
- send_level = send_level_event.content.get("events", {}).get(
- etype
- )
- if send_level is None:
- if state_key is not None:
- send_level = send_level_event.content.get(
- "state_default", 50
- )
- else:
- send_level = send_level_event.content.get(
- "events_default", 0
- )
+def get_send_level(etype, state_key, power_levels_event):
+ """Get the power level required to send an event of a given type
+
+ The federation spec [1] refers to this as "Required Power Level".
+
+    [1] https://matrix.org/docs/spec/server_server/unstable.html#definitions
- if send_level:
- send_level = int(send_level)
+ Args:
+ etype (str): type of event
+ state_key (str|None): state_key of state event, or None if it is not
+ a state event.
+ power_levels_event (synapse.events.EventBase|None): power levels event
+ in force at this point in the room
+ Returns:
+ int: power level required to send this event.
+ """
+
+ if power_levels_event:
+ power_levels_content = power_levels_event.content
else:
- send_level = 0
+ power_levels_content = {}
+
+ # see if we have a custom level for this event type
+ send_level = power_levels_content.get("events", {}).get(etype)
+
+ # otherwise, fall back to the state_default/events_default.
+ if send_level is None:
+ if state_key is not None:
+ send_level = power_levels_content.get("state_default", 50)
+ else:
+ send_level = power_levels_content.get("events_default", 0)
- return send_level
+ return int(send_level)
def _can_send_event(event, auth_events):
+ power_levels_event = _get_power_level_event(auth_events)
+
send_level = get_send_level(
- event.type, event.get("state_key", None), auth_events
+ event.type, event.get("state_key"), power_levels_event,
)
user_level = get_user_power_level(event.user_id, auth_events)
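A worked sketch of the fallback order in the new ``get_send_level``; the power-levels content below is a placeholder:

```python
# A worked sketch of get_send_level's fallback order, using placeholder
# power-levels content.
from synapse.event_auth import get_send_level


class FakePowerLevelsEvent(object):
    content = {"events": {"m.room.name": 75}, "state_default": 50}


pl = FakePowerLevelsEvent()

assert get_send_level("m.room.name", "", pl) == 75        # explicit per-type level
assert get_send_level("m.room.topic", "", pl) == 50       # state_default
assert get_send_level("m.room.message", None, pl) == 0    # events_default (absent)
assert get_send_level("m.room.message", None, None) == 0  # no power_levels event
```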
@@ -524,13 +533,22 @@ def _check_power_levels(event, auth_events):
def _get_power_level_event(auth_events):
- key = (EventTypes.PowerLevels, "", )
- return auth_events.get(key)
+ return auth_events.get((EventTypes.PowerLevels, ""))
def get_user_power_level(user_id, auth_events):
- power_level_event = _get_power_level_event(auth_events)
+ """Get a user's power level
+
+ Args:
+ user_id (str): user's id to look up in power_levels
+ auth_events (dict[(str, str), synapse.events.EventBase]):
+ state in force at this point in the room (or rather, a subset of
+            it including at least the create event and power levels event).
+ Returns:
+ int: the user's power level in this room.
+ """
+ power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
@@ -541,6 +559,11 @@ def get_user_power_level(user_id, auth_events):
else:
return int(level)
else:
+ # if there is no power levels event, the creator gets 100 and everyone
+ # else gets 0.
+
+ # some things which call this don't pass the create event: hack around
+ # that.
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2d420a58a2..fe51ba6806 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-import simplejson as json
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
@@ -277,7 +277,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
+ pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -470,17 +470,6 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
@@ -560,7 +549,9 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+ yield self.handler.on_receive_pdu(
+ origin, pdu, get_missing=True, sent_to_us_directly=True,
+ )
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f0aeb5a0d3..d72b057e28 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,7 +21,6 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext
-from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@@ -42,8 +41,11 @@ import logging
logger = logging.getLogger(__name__)
-sent_pdus_destination_dist = Counter(
- "synapse_federation_transaction_queue_sent_pdu_destinations", ""
+sent_pdus_destination_dist_count = Counter(
+ "synapse_federation_client_sent_pdu_destinations:count", ""
+)
+sent_pdus_destination_dist_total = Counter(
+ "synapse_federation_client_sent_pdu_destinations:total", ""
)
@@ -280,7 +282,8 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc(len(destinations))
+ sent_pdus_destination_dist_total.inc(len(destinations))
+ sent_pdus_destination_dist_count.inc()
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
@@ -451,9 +454,6 @@ class TransactionQueue(object):
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
- # XXX: what's this for?
- yield run_on_reactor()
-
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..1180d4b69d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -99,26 +100,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +108,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,6 +146,47 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            # `value` is a str here (the header was decoded above), so
+            # compare against a str prefix rather than bytes
+            if value.startswith("\""):
+                return value[1:-1]
+            else:
+                return value
+
+ origin = strip_quotes(param_dict["origin"])
+ # ensure that the origin is a valid server name
+ parse_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
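A sketch of the header format ``_parse_auth_header`` accepts; the origin, key id and signature below are placeholder values:

```python
# A sketch of parsing an X-Matrix Authorization header. The origin,
# key id and signature are placeholders (a real sig is unpadded base64).
header = b'X-Matrix origin=example.org,key="ed25519:key1",sig="SIGNATURE"'

origin, key, sig = _parse_auth_header(header)
assert (origin, key, sig) == ("example.org", "ed25519:key1", "SIGNATURE")
```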
class BaseFederationServlet(object):
REQUIRE_AUTH = True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3c0051586d..cbef1f2770 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,8 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from twisted.internet import defer, threads
+from canonicaljson import json
+
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -23,7 +26,6 @@ from synapse.api.errors import (
)
from synapse.module_api import ModuleApi
from synapse.types import UserID
-from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable
@@ -32,7 +34,7 @@ from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import pymacaroons
-import simplejson
+import attr
import synapse.util.stringutils as stringutils
@@ -402,7 +404,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = simplejson.loads(data)
+ resp_body = json.loads(data)
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
@@ -423,15 +425,11 @@ class AuthHandler(BaseHandler):
def _check_msisdn(self, authdict, _):
return self._check_threepid('msisdn', authdict)
- @defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
- yield run_on_reactor()
- defer.returnValue(True)
+ return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
- yield run_on_reactor()
-
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -825,6 +823,15 @@ class AuthHandler(BaseHandler):
if medium == 'email':
address = address.lower()
+ identity_handler = self.hs.get_handlers().identity_handler
+ yield identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': medium,
+ 'address': address,
+ },
+ )
+
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
@@ -849,7 +856,11 @@ class AuthHandler(BaseHandler):
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
- return make_deferred_yieldable(threads.deferToThread(_do_hash))
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
+ ),
+ )
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -869,16 +880,21 @@ class AuthHandler(BaseHandler):
)
if stored_hash:
- return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(),
+ self.hs.get_reactor().getThreadPool(),
+ _do_validate_hash,
+ ),
+ )
else:
return defer.succeed(False)
-class MacaroonGeneartor(object):
- def __init__(self, hs):
- self.clock = hs.get_clock()
- self.server_name = hs.config.server_name
- self.macaroon_secret_key = hs.config.macaroon_secret_key
+@attr.s
+class MacaroonGenerator(object):
+
+ hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
@@ -896,7 +912,7 @@ class MacaroonGeneartor(object):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
- now = self.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
@@ -908,9 +924,9 @@ class MacaroonGeneartor(object):
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
- location=self.server_name,
+ location=self.hs.config.server_name,
identifier="key",
- key=self.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
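The caveat structure these tokens carry can be sketched directly with ``pymacaroons``; the server name, secret key, user id and expiry below are placeholders:

```python
# A sketch of the macaroon behind a short-term login token.
# Location, key, user id and expiry timestamp are placeholder values.
import pymacaroons

macaroon = pymacaroons.Macaroon(
    location="example.org",        # hs.config.server_name
    identifier="key",
    key="MACAROON_SECRET_KEY",     # hs.config.macaroon_secret_key
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = @alice:example.org")
macaroon.add_first_party_caveat("type = login")
macaroon.add_first_party_caveat("time < 1528000000000")  # now + duration_in_ms

token = macaroon.serialize()
```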
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c5e92f6214..a84b7b8b80 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, create_requester
from synapse.util.logcontext import run_in_background
+from synapse.api.errors import SynapseError
import logging
@@ -30,6 +31,7 @@ class DeactivateAccountHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
+ self._identity_handler = hs.get_handlers().identity_handler
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -37,14 +39,15 @@ class DeactivateAccountHandler(BaseHandler):
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
- reactor.callWhenRunning(self._start_user_parting)
+ hs.get_reactor().callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks
- def deactivate_account(self, user_id):
+ def deactivate_account(self, user_id, erase_data):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
+ erase_data (bool): whether to GDPR-erase the user's data
Returns:
Deferred
@@ -52,14 +55,35 @@ class DeactivateAccountHandler(BaseHandler):
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
- # first delete any devices belonging to the user, which will also
+        # delete threepids first. We remove these from the identity server
+        # (IS); if this fails, we leave the user active so they can try again.
+ # Ideally we would prevent password resets and then do this in the
+ # background thread.
+ threepids = yield self.store.user_get_threepids(user_id)
+ for threepid in threepids:
+ try:
+ yield self._identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': threepid['medium'],
+ 'address': threepid['address'],
+ },
+ )
+ except Exception:
+ # Do we want this to be a fatal error or should we carry on?
+ logger.exception("Failed to remove threepid from ID server")
+ raise SynapseError(400, "Failed to remove threepid from ID server")
+ yield self.store.user_delete_threepid(
+ user_id, threepid['medium'], threepid['address'],
+ )
+
+ # delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
- yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
# Add the user to a table of users pending deactivation (ie.
@@ -69,6 +93,11 @@ class DeactivateAccountHandler(BaseHandler):
# delete from user directory
yield self.user_directory_handler.handle_user_deactivated(user_id)
+ # Mark the user as erased, if they asked for that
+ if erase_data:
+ logger.info("Marking %s as erased", user_id)
+ yield self.store.mark_user_erased(user_id)
+
# Now start the process that goes through that list and
# parts users from rooms (if it isn't already running)
self._start_user_parting()
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8a2d177539..62b4892a4e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -14,10 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import simplejson as json
import logging
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from six import iteritems
@@ -80,7 +79,7 @@ class E2eKeysHandler(object):
else:
remote_queries[user_id] = device_ids
- # Firt get local devices.
+ # First get local devices.
failures = {}
results = {}
if local_query:
@@ -357,7 +356,7 @@ def _exception_to_failure(e):
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
- # give a string for e.message, which simplejson then fails to serialize.
+ # give a string for e.message, which json then fails to serialize.
return {
"status": 503, "message": str(e.message),
}
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 60b97b140e..6350a4c204 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -39,11 +39,12 @@ from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError, logcontext
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, Linearizer
+from synapse.util.async import Linearizer
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
)
+from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id
from synapse.events.utils import prune_event
@@ -89,7 +90,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_receive_pdu(self, origin, pdu, get_missing=True):
+ def on_receive_pdu(
+ self, origin, pdu, get_missing=True, sent_to_us_directly=False,
+ ):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -103,8 +106,10 @@ class FederationHandler(BaseHandler):
"""
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
+ existing = yield self.store.get_event(
+ pdu.event_id,
+ allow_none=True,
+ allow_rejected=True,
)
# FIXME: Currently we fetch an event again when we already have it
@@ -161,14 +166,11 @@ class FederationHandler(BaseHandler):
"Ignoring PDU %s for room %s from %s as we've left the room!",
pdu.event_id, pdu.room_id, origin,
)
- return
+ defer.returnValue(None)
state = None
-
auth_chain = []
- fetch_state = False
-
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
@@ -223,26 +225,60 @@ class FederationHandler(BaseHandler):
list(prevs - seen)[:5],
)
- if prevs - seen:
- logger.info(
- "Still missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ if sent_to_us_directly and prevs - seen:
+                # If they have sent it to us directly, and the server
+                # isn't telling us about the prev_events that the new
+                # event references, we reject it
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
)
- fetch_state = True
+ elif prevs - seen:
+ # Calculate the state of the previous events, and
+ # de-conflict them to find the current state.
+ state_groups = []
+ auth_chains = set()
+ try:
+ # Get the state of the events we know about
+ ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
+ state_groups.append(ours)
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in prevs - seen:
+ state, got_auth_chain = (
+ yield self.replication_layer.get_state_for_room(
+ origin, pdu.room_id, p
+ )
+ )
+ auth_chains.update(got_auth_chain)
+ state_group = {(x.type, x.state_key): x.event_id for x in state}
+ state_groups.append(state_group)
+
+ # Resolve any conflicting state
+ def fetch(ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False
+ )
- if fetch_state:
- # We need to get the state at this event, since we haven't
- # processed all the prev events.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- try:
- state, auth_chain = yield self.replication_layer.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
- except Exception:
- logger.exception("Failed to get state for event: %s", pdu.event_id)
+ state_map = yield resolve_events_with_factory(
+ state_groups, {pdu.event_id: pdu}, fetch
+ )
+
+ state = (yield self.store.get_events(state_map.values())).values()
+ auth_chain = list(auth_chains)
+ except Exception:
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=pdu.event_id,
+ )
yield self._process_received_pdu(
origin,
@@ -320,11 +356,17 @@ class FederationHandler(BaseHandler):
for e in missing_events:
logger.info("Handling found event %s", e.event_id)
- yield self.on_receive_pdu(
- origin,
- e,
- get_missing=False
- )
+ try:
+ yield self.on_receive_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+ except FederationError as e:
+ if e.code == 403:
+                    logger.warn("Event %s failed history check.", e.affected)
+ else:
+ raise
@log_function
@defer.inlineCallbacks
@@ -458,6 +500,47 @@ class FederationHandler(BaseHandler):
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
+ """Filter the given events for the given server, redacting those the
+ server can't see.
+
+ Assumes the server is currently in the room.
+
+        Returns:
+ list[FrozenEvent]
+ """
+        # First, let's check whether all the events have a history visibility
+        # of "shared" or "world_readable". If that's the case then we don't
+        # need to check membership (as we know the server is in the room).
+ event_to_state_ids = yield self.store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
+ )
+
+ visibility_ids = set()
+ for sids in event_to_state_ids.itervalues():
+ hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist:
+ visibility_ids.add(hist)
+
+ # If we failed to find any history visibility events then the default
+        # is "shared" visibility.
+ if not visibility_ids:
+ defer.returnValue(events)
+
+ event_map = yield self.store.get_events(visibility_ids)
+ all_open = all(
+ e.content.get("history_visibility") in (None, "shared", "world_readable")
+ for e in event_map.itervalues()
+ )
+
+ if all_open:
+ defer.returnValue(events)
+
+ # Ok, so we're dealing with events that have non-trivial visibility
+ # rules, so we need to also get the memberships of the room.
+
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
@@ -493,7 +576,20 @@ class FederationHandler(BaseHandler):
for e_id, key_to_eid in event_to_state_ids.iteritems()
}
+ erased_senders = yield self.store.are_users_erased(
+ e.sender for e in events,
+ )
+
def redact_disallowed(event, state):
+ # if the sender has been erased (GDPR Art. 17 erasure), always
+ # return a redacted copy of the event.
+ if erased_senders[event.sender]:
+ logger.info(
+ "Sender of %s has been erased, redacting",
+ event.event_id,
+ )
+ return prune_event(event)
+
if not state:
return event
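For context on the erasure redaction above: the return shape of are_users_erased is inferred from the erased_senders[event.sender] lookup, so treat this sketch as illustrative rather than the storage API's contract:

    # erased = yield store.are_users_erased(["@alice:hs", "@bob:hs"])
    # erased == {"@alice:hs": True, "@bob:hs": False}
    # Events whose sender maps to True are only ever served to other
    # homeservers as prune_event(...) copies, whatever the room's history
    # visibility would otherwise allow.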
@@ -1371,8 +1467,6 @@ class FederationHandler(BaseHandler):
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
@@ -1403,8 +1497,6 @@ class FederationHandler(BaseHandler):
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id]
)
@@ -1446,11 +1538,20 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
+ def get_persisted_pdu(self, origin, event_id):
+ """Get an event from the database for the given server.
+
+ Args:
+ origin (str): hostname of the server which is requesting the event; we
+ will check that the server is allowed to see it.
+ event_id (str): id of the event being requested
Returns:
- Deferred: Results in a `Pdu`.
+ Deferred[EventBase|None]: None if we know nothing about the event;
+ otherwise the (possibly-redacted) event.
+
+ Raises:
+ AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event_id,
@@ -1459,20 +1560,17 @@ class FederationHandler(BaseHandler):
)
if event:
- if do_auth:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- events = yield self._filter_events_for_server(
- origin, event.room_id, [event]
- )
-
- event = events[0]
+ in_room = yield self.auth.check_host_in_room(
+ event.room_id,
+ origin
+ )
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+ events = yield self._filter_events_for_server(
+ origin, event.room_id, [event]
+ )
+ event = events[0]
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1751,6 +1849,10 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
)
+ missing_events = yield self._filter_events_for_server(
+ origin, room_id, missing_events,
+ )
+
defer.returnValue(missing_events)
@defer.inlineCallbacks
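The prev_events handling at the top of this file's changes reduces to a three-way decision; a minimal sketch, where decide is an illustrative name rather than a Synapse API:

    def decide(prevs, seen, sent_to_us_directly):
        """Summarise the branching above: process, reject, or resolve."""
        missing = prevs - seen
        if not missing:
            return "process"        # we already know every prev_event
        if sent_to_us_directly:
            return "reject"         # origin won't divulge its prev_events
        return "resolve_state"      # fetch remote state and de-conflict it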
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 91a0898860..277c2b7760 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +19,7 @@
import logging
-import simplejson as json
+from canonicaljson import json
from twisted.internet import defer
@@ -26,7 +27,6 @@ from synapse.api.errors import (
MatrixCodeMessageException, CodeMessageException
)
from ._base import BaseHandler
-from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError, Codes
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ class IdentityHandler(BaseHandler):
super(IdentityHandler, self).__init__(hs)
self.http_client = hs.get_simple_http_client()
+ self.federation_http_client = hs.get_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
@@ -60,8 +61,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
- yield run_on_reactor()
-
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@@ -104,7 +103,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
- yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
@@ -139,9 +137,53 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
- yield run_on_reactor()
+ def unbind_threepid(self, mxid, threepid):
+ """
+ Removes a binding from an identity server
+ Args:
+ mxid (str): Matrix user ID of binding to be removed
+ threepid (dict): Dict with medium & address of binding to be removed
+
+ Returns:
+ Deferred[bool]: True on success, otherwise False
+ """
+ logger.debug("unbinding threepid %r from %s", threepid, mxid)
+ if not self.trusted_id_servers:
+ logger.warn("Can't unbind threepid: no trusted ID servers set in config")
+ defer.returnValue(False)
+
+ # We don't track what ID server we added 3pids on (perhaps we ought to)
+ # but we assume that any of the servers in the trusted list are in the
+ # same ID server federation, so we can pick any one of them to send the
+ # deletion request to.
+ id_server = next(iter(self.trusted_id_servers))
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+ content = {
+ "mxid": mxid,
+ "threepid": threepid,
+ }
+ headers = {}
+ # we abuse the federation http client to sign the request, but we have to send it
+ # using the normal http client since we don't want the SRV lookup and want normal
+ # 'browser-like' HTTPS.
+ self.federation_http_client.sign_request(
+ destination=None,
+ method='POST',
+ url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ headers_dict=headers,
+ content=content,
+ destination_is=id_server,
+ )
+ yield self.http_client.post_json_get_json(
+ url,
+ content,
+ headers,
+ )
+ defer.returnValue(True)
+ @defer.inlineCallbacks
+ def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
@@ -176,8 +218,6 @@ class IdentityHandler(BaseHandler):
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
):
- yield run_on_reactor()
-
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
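A hedged usage sketch for the new unbind_threepid helper; the argument shapes come from its docstring, while the way a caller obtains the handler is elided as an assumption:

    # success = yield identity_handler.unbind_threepid(
    #     "@alice:example.com",
    #     {"medium": "email", "address": "alice@example.com"},
    # )
    # success is False when no trusted ID servers are configured, and True
    # once the signed unbind request has been accepted by the ID server.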
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1cb81b6cf8..cbadf3c88e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import simplejson
import sys
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
import six
from six import string_types, itervalues, iteritems
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.defer import succeed
from twisted.python.failure import Failure
@@ -36,7 +35,7 @@ from synapse.events.validator import EventValidator
from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
-from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
+from synapse.util.async import ReadWriteLock, Limiter
from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.util.frozenutils import frozendict_json_encoder
@@ -157,7 +156,7 @@ class MessageHandler(BaseHandler):
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
- reactor.callLater(24 * 3600, clear_purge)
+ self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
"""Get the current status of an active purge
@@ -491,7 +490,7 @@ class EventCreationHandler(object):
target, e
)
- is_exempt = yield self._is_exempt_from_privacy_policy(builder)
+ is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
if not is_exempt:
yield self.assert_accepted_privacy_policy(requester)
@@ -509,12 +508,13 @@ class EventCreationHandler(object):
defer.returnValue((event, context))
- def _is_exempt_from_privacy_policy(self, builder):
+ def _is_exempt_from_privacy_policy(self, builder, requester):
""""Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
builder (synapse.events.builder.EventBuilder): event being created
+ requester (Requester): user requesting this event
Returns:
Deferred[bool]: true if the event can be sent without the user
@@ -525,6 +525,9 @@ class EventCreationHandler(object):
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
return self._is_server_notices_room(builder.room_id)
+ elif membership == Membership.LEAVE:
+ # the user is always allowed to leave (but not kick people)
+ return builder.state_key == requester.user.to_string()
return succeed(False)
@defer.inlineCallbacks
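The consent exemption above reduces to a small truth table (an editorial summary of the patch, not code from it):

    # membership      condition                           exempt?
    # JOIN            room is the server-notices room     yes
    # LEAVE           state_key == requester's user id    yes (self-leave)
    # LEAVE           state_key != requester's user id    no  (that's a kick)
    # anything else   --                                  no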
@@ -793,7 +796,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
- simplejson.loads(dump)
+ json.loads(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -806,6 +809,7 @@ class EventCreationHandler(object):
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield send_event_to_master(
+ self.hs.get_clock(),
self.http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
@@ -959,9 +963,7 @@ class EventCreationHandler(object):
event_stream_id, max_stream_id
)
- @defer.inlineCallbacks
def _notify():
- yield run_on_reactor()
try:
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7fe568132f..7db59fba00 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -22,7 +22,7 @@ The methods that define policy are:
- should_notify
"""
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from contextlib import contextmanager
from six import itervalues, iteritems
@@ -179,7 +179,7 @@ class PresenceHandler(object):
# have not yet been persisted
self.unpersisted_users_changes = set()
- reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown)
self.serial_to_user = {}
self._next_serial = 1
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7e52adda3c..e76ef5426d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -24,7 +24,7 @@ from synapse.api.errors import (
from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID, create_requester, RoomID, RoomAlias
-from synapse.util.async import run_on_reactor, Linearizer
+from synapse.util.async import Linearizer
from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
@@ -139,7 +139,6 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
@@ -431,8 +430,6 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
-
if localpart is None:
raise SynapseError(400, "Request must include user id")
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2abd63ad05..ab72963d87 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -23,7 +23,7 @@ from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, JoinRules, RoomCreationPreset
)
-from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.util import stringutils
from synapse.visibility import filter_events_for_client
@@ -115,7 +115,11 @@ class RoomCreationHandler(BaseHandler):
)
if mapping:
- raise SynapseError(400, "Room alias already taken")
+ raise SynapseError(
+ 400,
+ "Room alias already taken",
+ Codes.ROOM_IN_USE
+ )
else:
room_alias = None
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 1eca26aa1e..2e3a77ca4b 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -64,6 +64,13 @@ class SearchHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid batch")
+ logger.info(
+ "Search batch properties: %r, %r, %r",
+ batch_group, batch_group_key, batch_token,
+ )
+
+ logger.info("Search content: %s", content)
+
try:
room_cat = content["search_categories"]["room_events"]
@@ -271,6 +278,8 @@ class SearchHandler(BaseHandler):
# We should never get here due to the guard earlier.
raise NotImplementedError()
+ logger.info("Found %d events to return", len(allowed_events))
+
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
@@ -282,6 +291,11 @@ class SearchHandler(BaseHandler):
event.room_id, event.event_id, before_limit, after_limit
)
+ logger.info(
+ "Context for search returned %d and %d events",
+ len(res["events_before"]), len(res["events_after"]),
+ )
+
res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"]
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 51ec727df0..7f486e48e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -145,7 +145,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have chanegd
+ "device_lists", # List of user_ids whose devices have changed
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
"groups",
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a39f0f7343..7e4a114d4f 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure
-from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
from six import iteritems
@@ -174,7 +173,7 @@ class UserDirectoryHandler(object):
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id)
num_processed_rooms += 1
- yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
@@ -188,7 +187,7 @@ class UserDirectoryHandler(object):
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id)
num_processed_users += 1
- yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users")
@@ -236,7 +235,7 @@ class UserDirectoryHandler(object):
count = 0
for user_id in user_ids:
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
@@ -251,7 +250,7 @@ class UserDirectoryHandler(object):
continue
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8064a84c5c..5bdc484c15 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -42,7 +42,7 @@ from twisted.web._newclient import ResponseDone
from six import StringIO
from prometheus_client import Counter
-import simplejson as json
+from canonicaljson import json
import logging
import urllib
@@ -98,8 +98,8 @@ class SimpleHttpClient(object):
method, uri, *args, **kwargs
)
add_timeout_to_deferred(
- request_deferred,
- 60, cancelled_to_request_timed_out_error,
+ request_deferred, 60, self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
@@ -115,7 +115,7 @@ class SimpleHttpClient(object):
"Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message
)
- raise e
+ raise
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None):
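Passing the reactor into add_timeout_to_deferred (here and at the other call sites below) is what lets tests drive timeouts with a fake clock. A minimal sketch of the idea using Twisted's stock test reactor, not Synapse code:

    from twisted.internet import defer, task

    fake_reactor = task.Clock()
    d = defer.Deferred()
    d.addErrback(lambda f: None)         # swallow the CancelledError
    # roughly what the timeout helper does: schedule a cancel on the
    # reactor it was given, rather than on the global reactor
    fake_reactor.callLater(60, d.cancel)
    fake_reactor.advance(60)             # sixty "seconds" pass instantly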
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 87a482650d..5a9cbb3324 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@@ -38,6 +38,36 @@ _Server = collections.namedtuple(
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Does some basic sanity checking of the server name.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ if server_name[0] != '[':
+ raise Exception()
+
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
@@ -74,21 +102,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ))
+ ), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ))
+ ), reactor)
class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac):
+ def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
+ self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn)
+ conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
@@ -98,9 +127,10 @@ class _WrappedConnection(object):
"""
__slots__ = ["conn", "last_request"]
- def __init__(self, conn):
+ def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
+ self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
@@ -131,14 +161,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't sent a request in the last
# N minutes
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
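Traced from the implementation of parse_server_name above, its expected behaviour (illustrative; results shown as comments):

    from synapse.http.endpoint import parse_server_name

    parse_server_name("matrix.org")       # ("matrix.org", None)
    parse_server_name("matrix.org:8448")  # ("matrix.org", 8448)
    parse_server_name("[::1]")            # ("[::1]", None)  ipv6 literal
    parse_server_name("[::1]:8448")       # ("[::1]", 8448)
    parse_server_name("matrix.org:http")  # raises ValueError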
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 821aed362b..2cb9e3e231 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -22,12 +22,12 @@ from twisted.web._newclient import ResponseDone
from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
import synapse.metrics
-from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util.async import add_timeout_to_deferred
from synapse.util import logcontext
from synapse.util.logcontext import make_deferred_yieldable
import synapse.util.retryutils
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from synapse.api.errors import (
SynapseError, Codes, HttpResponseException, FederationDeniedError,
@@ -36,7 +36,6 @@ from synapse.api.errors import (
from signedjson.sign import sign_json
import cgi
-import simplejson as json
import logging
import random
import sys
@@ -193,6 +192,7 @@ class MatrixFederationHttpClient(object):
add_timeout_to_deferred(
request_deferred,
timeout / 1000. if timeout else 60,
+ self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(
@@ -234,7 +234,7 @@ class MatrixFederationHttpClient(object):
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
- yield sleep(delay)
+ yield self.clock.sleep(delay)
retries_left -= 1
else:
raise
@@ -260,14 +260,35 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
- content=None):
+ content=None, destination_is=None):
+ """
+ Signs a request by adding an Authorization header to headers_dict
+ Args:
+ destination (bytes|None): The destination homeserver of the request.
+ May be None if the destination is an identity server, in which case
+ destination_is must be non-None.
+ method (bytes): The HTTP method of the request
+ url_bytes (bytes): The URI path of the request
+ headers_dict (dict): Dictionary of request headers to append to
+ content (bytes): The body of the request
+ destination_is (bytes): As 'destination', but if the destination is an
+ identity server
+
+ Returns:
+ None
+ """
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
- "destination": destination,
}
+ if destination is not None:
+ request["destination"] = destination
+
+ if destination_is is not None:
+ request["destination_is"] = destination_is
+
if content is not None:
request["content"] = content
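The expanded sign_request now has two calling modes, per its docstring (an illustrative sketch; the hostnames are made up):

    # 1. federation request to another homeserver:
    #    self.sign_request(
    #        destination=b"remote.example.com", method=b"GET",
    #        url_bytes=b"/_matrix/federation/v1/...", headers_dict=headers,
    #    )
    # 2. signed request to an identity server (no homeserver destination):
    #    self.sign_request(
    #        destination=None, method=b"POST",
    #        url_bytes=b"/_matrix/identity/api/v1/3pid/unbind",
    #        headers_dict=headers, content=content,
    #        destination_is=b"id.example.com",
    #    )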
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index dc06f6c443..1b711ca2de 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -117,13 +117,17 @@ def _get_in_flight_counts():
Returns:
dict[tuple[str, str], int]
"""
- for rm in _in_flight_requests:
+ # Cast to a list to prevent it changing while the Prometheus
+ # thread is collecting metrics
+ reqs = list(_in_flight_requests)
+
+ for rm in reqs:
rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that
# type
counts = {}
- for rm in _in_flight_requests:
+ for rm in reqs:
key = (rm.method, rm.name,)
counts[key] = counts.get(key, 0) + 1
@@ -131,7 +135,7 @@ def _get_in_flight_counts():
LaterGauge(
- "synapse_http_request_metrics_in_flight_requests_count",
+ "synapse_http_server_in_flight_requests_count",
"",
["method", "servlet"],
_get_in_flight_counts,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index bc09b8b2be..517aaf7b5a 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -29,7 +29,7 @@ import synapse.metrics
import synapse.events
from canonicaljson import (
- encode_canonical_json, encode_pretty_printed_json
+ encode_canonical_json, encode_pretty_printed_json, json
)
from twisted.internet import defer
@@ -41,7 +41,6 @@ from twisted.web.util import redirectTo
import collections
import logging
import urllib
-import simplejson
logger = logging.getLogger(__name__)
@@ -410,7 +409,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- json_bytes = simplejson.dumps(json_object)
+ json_bytes = json.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index ef8e62901b..ef3a01ddc7 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -18,7 +18,9 @@
from synapse.api.errors import SynapseError, Codes
import logging
-import simplejson
+
+from canonicaljson import json
+
logger = logging.getLogger(__name__)
@@ -171,7 +173,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None
try:
- content = simplejson.loads(content_bytes)
+ content = json.loads(content_bytes)
except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 2664006f8c..fe93643b1e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -99,19 +99,36 @@ class SynapseRequest(Request):
db_txn_count = context.db_txn_count
db_txn_duration_sec = context.db_txn_duration_sec
db_sched_duration_sec = context.db_sched_duration_sec
+ evt_db_fetch_count = context.evt_db_fetch_count
except Exception:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration_sec = (0, 0)
+ evt_db_fetch_count = 0
end_time = time.time()
+ # need to decode as it could be raw utf-8 bytes
+ # from an IDN server name in an auth header
+ authenticated_entity = self.authenticated_entity
+ if authenticated_entity is not None:
+ authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+
+ # ...or could be raw utf-8 bytes in the User-Agent header.
+ # N.B. if you don't do this, the logger explodes cryptically
+ # with maximum recursion trying to log errors about
+ # the charset problem.
+ # c.f. https://github.com/matrix-org/synapse/issues/3471
+ user_agent = self.get_user_agent()
+ if user_agent is not None:
+ user_agent = user_agent.decode("utf-8", "replace")
+
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
- " %sB %s \"%s %s %s\" \"%s\"",
+ " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
- self.authenticated_entity,
+ authenticated_entity,
end_time - self.start_time,
ru_utime,
ru_stime,
@@ -123,7 +140,8 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri(),
self.clientproto,
- self.get_user_agent(),
+ user_agent,
+ evt_db_fetch_count,
)
try:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 429e79c472..2d2397caae 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -62,7 +62,7 @@ class LaterGauge(object):
calls = self.caller()
except Exception:
logger.exception(
- "Exception running callback for LaterGuage(%s)",
+ "Exception running callback for LaterGauge(%s)",
self.name,
)
yield g
@@ -140,14 +140,15 @@ gc_time = Histogram(
class GCCounts(object):
def collect(self):
- cm = GaugeMetricFamily("python_gc_counts", "GC cycle counts", labels=["gen"])
+ cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
cm.add_metric([str(n)], m)
yield cm
-REGISTRY.register(GCCounts())
+if not running_on_pypy:
+ REGISTRY.register(GCCounts())
#
# Twisted reactor metrics
@@ -190,6 +191,22 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"
# finished being processed.
event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
+last_ticked = time.time()
+
+
+class ReactorLastSeenMetric(object):
+
+ def collect(self):
+ cm = GaugeMetricFamily(
+ "python_twisted_reactor_last_seen",
+ "Seconds since the Twisted reactor was last seen",
+ )
+ cm.add_metric([], time.time() - last_ticked)
+ yield cm
+
+
+REGISTRY.register(ReactorLastSeenMetric())
+
def runUntilCurrentTimer(func):
@@ -222,6 +239,11 @@ def runUntilCurrentTimer(func):
tick_time.observe(end - start)
pending_calls_metric.observe(num_pending)
+ # Update the time we last ticked, for the metric to test whether
+ # Synapse's reactor has frozen
+ global last_ticked
+ last_ticked = end
+
if running_on_pypy:
return ret
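Because runUntilCurrentTimer refreshes last_ticked on every reactor tick, the new gauge sits near zero on a healthy process and grows monotonically if the reactor wedges. A plausible alerting rule (an assumption; no such rule ships with this patch):

    # PromQL, for illustration only:
    #   python_twisted_reactor_last_seen > 60
    # fires if the reactor has not completed a tick for over a minute.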
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6dce20a284..3c0622a294 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -161,6 +161,7 @@ class Notifier(object):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
+ self.hs = hs
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
@@ -340,6 +341,7 @@ class Notifier(object):
add_timeout_to_deferred(
listener.deferred,
(end_time - now) / 1000.,
+ self.hs.get_reactor(),
)
with PreserveLoggingContext():
yield listener.deferred
@@ -561,6 +563,7 @@ class Notifier(object):
add_timeout_to_deferred(
listener.deferred,
(end_time - now) / 1000.,
+ self.hs.get_reactor(),
)
try:
with PreserveLoggingContext():
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index ba7286cb72..52d4f087ee 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
@@ -199,7 +199,7 @@ class EmailPusher(object):
self.timed_call = None
if soonest_due_at is not None:
- self.timed_call = reactor.callLater(
+ self.timed_call = self.hs.get_reactor().callLater(
self.seconds_until(soonest_due_at), self.on_timer
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index bf7ff74a1a..7a481b5a1e 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from . import push_rule_evaluator
@@ -220,7 +220,9 @@ class HttpPusher(object):
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
- self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer)
+ self.timed_call = self.hs.get_reactor().callLater(
+ self.backoff_delay, self.on_timer
+ )
self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
break
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 750d11ca38..36bb5bbc65 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -19,7 +19,6 @@ import logging
from twisted.internet import defer
from synapse.push.pusher import PusherFactory
-from synapse.util.async import run_on_reactor
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -125,7 +124,6 @@ class PusherPool:
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
- yield run_on_reactor()
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
@@ -151,7 +149,6 @@ class PusherPool:
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
- yield run_on_reactor()
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index a9baa2c1c3..f080f96cc1 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -21,7 +21,6 @@ from synapse.api.errors import (
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure
from synapse.types import Requester, UserID
@@ -33,11 +32,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks
-def send_event_to_master(client, host, port, requester, event, context,
+def send_event_to_master(clock, client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
+ clock (synapse.util.Clock)
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
@@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context,
# If we timed out we probably don't need to worry about backing
# off too much, but let's just wait a little anyway.
- yield sleep(1)
+ yield clock.sleep(1)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b1f64ef0d8..97d3196633 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage.user_erasure_store import UserErasureWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -45,6 +46,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore,
StateGroupWorkerStore,
SignatureWorkerStore,
+ UserErasureWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..bb852b00af 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,7 +15,7 @@
"""A replication client for use by synapse workers.
"""
-from twisted.internet import reactor, defer
+from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
@@ -44,7 +44,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
- reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +95,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- reactor.connectTCP(host, port, factory)
+ hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 12aac3cc6b..f3908df642 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,13 +19,17 @@ allowed to be sent by which side.
"""
import logging
-import simplejson
+import platform
+if platform.python_implementation() == "PyPy":
+ import json
+ _json_encoder = json.JSONEncoder()
+else:
+ import simplejson as json
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
-_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
-
class Command(object):
"""The base command class.
@@ -102,7 +106,7 @@ class RdataCommand(Command):
return cls(
stream_name,
None if token == "batch" else int(token),
- simplejson.loads(row_json)
+ json.loads(row_json)
)
def to_line(self):
@@ -300,7 +304,7 @@ class InvalidateCacheCommand(Command):
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
- return cls(cache_func, simplejson.loads(keys_json))
+ return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((
@@ -329,7 +333,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
return cls(
user_id, access_token, ip, user_agent, device_id, last_seen
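The PyPy/CPython split above exists because simplejson, unlike the stdlib json, encodes namedtuples as JSON objects by default, which would change the replication wire format for row tuples. A quick, runnable illustration:

    import collections
    import json

    import simplejson

    Row = collections.namedtuple("Row", ["a", "b"])

    print(json.dumps(Row(1, 2)))        # [1, 2] -- tuples become arrays
    print(simplejson.dumps(Row(1, 2)))  # {"a": 1, "b": 2} -- objects by default
    print(simplejson.JSONEncoder(namedtuple_as_object=False).encode(Row(1, 2)))
                                        # [1, 2] -- matches the stdlib again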
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index c870475cd1..171a698e14 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -564,11 +564,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
- "pending_commands", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_pending_commands",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands)
- for p in connected_connections
- })
+ (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+ },
+)
def transport_buffer_size(protocol):
@@ -579,11 +581,13 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p)
- for p in connected_connections
- })
+ (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+ },
+)
def transport_kernel_read_buffer_size(protocol, read=True):
@@ -602,37 +606,50 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
- })
+ },
+)
tcp_transport_kernel_read_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
- })
+ },
+)
tcp_inbound_commands = LaterGauge(
- "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
- })
+ },
+)
tcp_outbound_commands = LaterGauge(
- "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
- })
+ },
+)
# number of updates received for each RDATA stream
-inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
- ["stream_name"])
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 63bd6d2652..95ad8c1b4c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,7 +15,7 @@
"""The server side of the replication stream.
"""
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.protocol import Factory
from .streams import STREAMS_MAP, FederationStream
@@ -109,7 +109,7 @@ class ReplicationStreamer(object):
self.is_looping = False
self.pending_updates = False
- reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self):
# close all connections on shutdown
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index b8665a45eb..8fb08dc526 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -16,6 +16,8 @@
from twisted.internet import defer
+from six.moves import http_client
+
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester
@@ -247,6 +249,15 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
@@ -254,7 +265,9 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self._deactivate_account_handler.deactivate_account(target_user_id)
+ yield self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase,
+ )
defer.returnValue((200, {}))
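Both deactivation endpoints (this admin one and the client-server one later in the patch) now take the same optional erase flag; the admin variant allows an empty body. Expected request/response shapes (editorial illustration):

    # POST .../deactivate  (admin variant; body may be empty)
    # {"erase": true}   -> 200 {}
    # {"erase": "yes"}  -> 400, errcode M_BAD_JSON:
    #                      "Param 'erase' must be a boolean, if given"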
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 34df5be4e9..88ca5184cd 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -23,7 +23,8 @@ from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
-import simplejson as json
+from canonicaljson import json
+
import urllib
from six.moves.urllib import parse as urlparse
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 9b3022e0b0..c10320dedf 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -24,8 +24,6 @@ import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
from synapse.types import create_requester
-from synapse.util.async import run_on_reactor
-
from hashlib import sha1
import hmac
import logging
@@ -272,7 +270,6 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
- yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
@@ -333,8 +330,6 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
- yield run_on_reactor()
-
if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
@@ -423,8 +418,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
- yield run_on_reactor()
-
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0b984987ed..e6ae5db79b 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -31,7 +31,7 @@ from synapse.http.servlet import (
from six.moves.urllib import parse as urlparse
import logging
-import simplejson as json
+from canonicaljson import json
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 30523995af..80dbc3c92e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
# limitations under the License.
import logging
+from six.moves import http_client
from twisted.internet import defer
from synapse.api.auth import has_access_token
@@ -24,7 +26,6 @@ from synapse.http.servlet import (
RestServlet, assert_params_in_request,
parse_json_object_from_request,
)
-from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
@@ -187,13 +188,20 @@ class DeactivateAccountRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
requester = yield self.auth.get_user_by_req(request)
# allow ASes to deactivate their own users
if requester.app_service:
yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string()
+ requester.user.to_string(), erase,
)
defer.returnValue((200, {}))
@@ -201,7 +209,7 @@ class DeactivateAccountRestServlet(RestServlet):
requester, body, self.hs.get_ip_from_request(request),
)
yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(),
+ requester.user.to_string(), erase,
)
defer.returnValue((200, {}))
@@ -300,8 +308,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- yield run_on_reactor()
-
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids(
@@ -312,8 +318,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds')
@@ -365,8 +369,6 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
required = ['medium', 'address']
@@ -381,9 +383,16 @@ class ThreepidDeleteRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address']
- )
+ try:
+ yield self.auth_handler.delete_threepid(
+ user_id, body['medium'], body['address']
+ )
+ except Exception:
+ # NB. This endpoint should succeed if there is nothing to
+ # delete, so it should only throw if something is wrong
+ # that we ought to care about.
+ logger.exception("Failed to remove threepid")
+ raise SynapseError(500, "Failed to remove threepid")
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5cab00aea9..97e7c0f7c6 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -32,7 +32,6 @@ from ._base import client_v2_patterns, interactive_auth_handler
import logging
import hmac
from hashlib import sha1
-from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter
from six import string_types
@@ -191,8 +190,6 @@ class RegisterRestServlet(RestServlet):
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
kind = "user"
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a291cffbf1..d2aa47b326 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -33,7 +33,7 @@ from ._base import set_timeline_upper_limit
import itertools
import logging
-import simplejson as json
+from canonicaljson import json
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index 956bd5da75..e44d4276d2 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -22,8 +22,9 @@ from synapse.api.errors import (
from twisted.protocols.basic import FileSender
from twisted.web import server, resource
+from canonicaljson import json
+
import base64
-import simplejson as json
import logging
import os
import re
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 2ac767d2dc..218ba7a083 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
def __init__(self, hs):
+ self.hs = hs
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@@ -94,7 +95,7 @@ class MediaRepository(object):
storage_providers.append(provider)
self.media_storage = MediaStorage(
- self.primary_base_path, self.filepaths, storage_providers,
+ self.hs, self.primary_base_path, self.filepaths, storage_providers,
)
self.clock.looping_call(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index d23fe10b07..d6b8ebbedb 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -37,13 +37,15 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
+ hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
"""
- def __init__(self, local_media_directory, filepaths, storage_providers):
+ def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@@ -175,7 +177,8 @@ class MediaStorage(object):
res = yield provider.fetch(path, file_info)
if res:
with res:
- consumer = BackgroundFileConsumer(open(local_path, "w"))
+ consumer = BackgroundFileConsumer(
+ open(local_path, "w"), self.hs.get_reactor())
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 565cef2b8d..adca490640 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -23,7 +23,8 @@ import re
import shutil
import sys
import traceback
-import simplejson as json
+
+from canonicaljson import json
from six.moves import urllib_parse as urlparse
from six import string_types
diff --git a/synapse/server.py b/synapse/server.py
index 58dbf78437..c29c19289a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
+from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
@@ -165,15 +165,19 @@ class HomeServer(object):
'server_notices_sender',
]
- def __init__(self, hostname, **kwargs):
+ def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
"""
+ if not reactor:
+ from twisted.internet import reactor
+
+ self._reactor = reactor
self.hostname = hostname
self._building = {}
- self.clock = Clock()
+ self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
@@ -186,6 +190,12 @@ class HomeServer(object):
self.datastore = DataStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
+ def get_reactor(self):
+ """
+ Fetch the Twisted reactor in use by this HomeServer.
+ """
+ return self._reactor
+
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
@@ -261,7 +271,7 @@ class HomeServer(object):
return AuthHandler(self)
def build_macaroon_generator(self):
- return MacaroonGeneartor(self)
+ return MacaroonGenerator(self)
def build_device_handler(self):
return DeviceHandler(self)
@@ -328,6 +338,7 @@ class HomeServer(object):
return adbapi.ConnectionPool(
name,
+ cp_reactor=self.get_reactor(),
**self.db_config.get("args", {})
)
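A hedged sketch of what the new reactor argument enables in tests; config and database setup are elided as assumptions:

    from twisted.internet import task
    from synapse.server import HomeServer

    fake_reactor = task.Clock()
    hs = HomeServer("test", reactor=fake_reactor)   # config/DB setup elided
    assert hs.get_reactor() is fake_reactor
    # Anything scheduled via hs.get_reactor().callLater(...) can now be
    # fired deterministically with fake_reactor.advance(seconds).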
diff --git a/synapse/state.py b/synapse/state.py
index 216418f58d..8098db94b4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -694,10 +694,10 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
return auth_events
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ds):
+ for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 979fa22438..e843b702b9 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -20,6 +20,7 @@ import time
import logging
from synapse.storage.devices import DeviceStore
+from synapse.storage.user_erasure_store import UserErasureStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -88,6 +89,7 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
+ UserErasureStore,
):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index f83ff0454a..7034a61399 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -20,10 +20,11 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+from canonicaljson import json
import abc
-import simplejson as json
import logging
logger = logging.getLogger(__name__)
@@ -114,25 +115,6 @@ class AccountDataWorkerStore(SQLBaseStore):
else:
defer.returnValue(None)
- @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
- num_args=2, list_name="user_ids", inlineCallbacks=True)
- def get_global_account_data_by_type_for_users(self, data_type, user_ids):
- rows = yield self._simple_select_many_batch(
- table="account_data",
- column="user_id",
- iterable=user_ids,
- keyvalues={
- "account_data_type": data_type,
- },
- retcols=("user_id", "content",),
- desc="get_global_account_data_by_type_for_users",
- )
-
- defer.returnValue({
- row["user_id"]: json.loads(row["content"]) if row["content"] else None
- for row in rows
- })
-
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 12ea8a158c..4d32d0bdf6 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,8 +15,8 @@
# limitations under the License.
import logging
import re
-import simplejson as json
from twisted.internet import defer
+from canonicaljson import json
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 8af325a9f5..af18964510 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.async
from ._base import SQLBaseStore
from . import engines
from twisted.internet import defer
-import simplejson as json
+from canonicaljson import json
+
import logging
logger = logging.getLogger(__name__)
@@ -92,7 +92,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Starting background schema updates")
while True:
- yield synapse.util.async.sleep(
+ yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index ce338514e8..968d2fed22 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,7 +15,7 @@
import logging
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from ._base import Cache
from . import background_updates
@@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
- reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index a879e5bfc1..38addbf9c0 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,7 +14,8 @@
# limitations under the License.
import logging
-import simplejson
+
+from canonicaljson import json
from twisted.internet import defer
@@ -85,7 +86,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = simplejson.dumps(edu)
+ edu_json = json.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -177,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
- message_json = simplejson.dumps(messages_by_device["*"])
+ message_json = json.dumps(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -199,7 +200,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = simplejson.dumps(messages_by_device[device])
+ message_json = json.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
@@ -253,7 +254,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(simplejson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
@@ -389,7 +390,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(simplejson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d149d8392e..2ed9ada783 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import simplejson as json
from twisted.internet import defer
@@ -21,6 +20,8 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from canonicaljson import json
+
from six import itervalues, iteritems
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b146487943..181047c8b7 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -16,8 +16,7 @@ from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from canonicaljson import encode_canonical_json
-import simplejson as json
+from canonicaljson import encode_canonical_json, json
from ._base import SQLBaseStore
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d0350ee5fe..05cb3f61ce 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -16,11 +16,11 @@
from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
-from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging
-import simplejson as json
+
+from canonicaljson import json
from six import iteritems
@@ -84,6 +84,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
+ self._rotate_delay = 3
+ self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@@ -800,7 +802,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
if caught_up:
break
- yield sleep(5)
+ yield self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
@@ -821,8 +823,8 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
- ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
- """, (old_rotate_stream_ordering,))
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+ """, (old_rotate_stream_ordering, self._rotate_count))
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
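
Hoisting the hard-coded sleep and 50000-row OFFSET into `_rotate_delay` and `_rotate_count` makes the rotation loop tunable from tests. A hypothetical test override, not part of the patch:

    store._rotate_delay = 0   # no pause between rotation batches
    store._rotate_count = 10  # tiny batches, so the loop runs several times
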
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index cb1082e864..d816d4883c 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,8 @@ from functools import wraps
import itertools
import logging
-import simplejson as json
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.storage.events_worker import EventsWorkerStore
@@ -1044,7 +1045,6 @@ class EventsStore(EventsWorkerStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
@@ -1064,6 +1064,14 @@ class EventsStore(EventsWorkerStore):
[(ev.event_id,) for ev, _ in events_and_contexts]
)
+ for table in (
+ "event_push_actions",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
+ [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]

+ )
+
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 32d9d00ffb..896225aab9 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -14,13 +14,14 @@
# limitations under the License.
from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logcontext import (
PreserveLoggingContext, make_deferred_yieldable, run_in_background,
+ LoggingContext,
)
from synapse.util.metrics import Measure
from synapse.api.errors import SynapseError
@@ -28,7 +29,8 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
import logging
-import simplejson as json
+
+from canonicaljson import json
# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
@@ -145,6 +147,9 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
@@ -265,7 +270,7 @@ class EventsWorkerStore(SQLBaseStore):
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
+ self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@@ -278,7 +283,7 @@ class EventsWorkerStore(SQLBaseStore):
if event_list:
with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
+ self.hs.get_reactor().callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
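
The new `record_event_fetch` call feeds the `evt_db_fetch_count` counter added to LoggingContext later in this patch, so a request's log line can report how many events it pulled from the database. A rough sketch of reading it back:

    from synapse.util.logcontext import LoggingContext

    with LoggingContext("request-1") as ctx:
        # ... anything that reaches _get_events for uncached events ...
        print(ctx.evt_db_fetch_count)  # events fetched from the db so far
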
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2e2763126d..eae6027cee 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -19,8 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from canonicaljson import encode_canonical_json
-import simplejson as json
+from canonicaljson import encode_canonical_json, json
class FilteringStore(SQLBaseStore):
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index da05ccb027..b77402d295 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError
from ._base import SQLBaseStore
-import simplejson as json
+from canonicaljson import json
# The category ID for the "default" category. We don't store as null in the
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 04a0b59a39..9e52e992b3 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -25,9 +25,10 @@ from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer
+from canonicaljson import json
+
import abc
import logging
-import simplejson as json
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 307660b99a..c6def861cf 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -17,12 +17,11 @@
from ._base import SQLBaseStore
from twisted.internet import defer
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
import logging
-import simplejson as json
import types
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c93c228f6e..f230a3bab7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -21,9 +21,10 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
+from canonicaljson import json
+
import abc
import logging
-import simplejson as json
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index c241167fbe..0d18f6d869 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -460,15 +460,6 @@ class RegistrationStore(RegistrationWorkerStore,
defer.returnValue(ret['user_id'])
defer.returnValue(None)
- def user_delete_threepids(self, user_id):
- return self._simple_delete(
- "user_threepids",
- keyvalues={
- "user_id": user_id,
- },
- desc="user_delete_threepids",
- )
-
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
@@ -632,7 +623,9 @@ class RegistrationStore(RegistrationWorkerStore,
Removes the given user from the table of users who need to be parted from all the
rooms they're in, effectively marking that user as fully deactivated.
"""
- return self._simple_delete_one(
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ return self._simple_delete(
"users_pending_deactivation",
keyvalues={
"user_id": user_id,
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index ea6a189185..ca0eb187e5 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -20,9 +20,10 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from canonicaljson import json
+
import collections
import logging
-import simplejson as json
import re
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 48a88f755e..8fc9549a75 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -28,7 +28,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
-import simplejson as json
+from canonicaljson import json
from six import itervalues, iteritems
@@ -455,7 +455,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
defer.returnValue(joined_hosts)
- @cached(max_entries=10000, iterable=True)
+ @cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql
new file mode 100644
index 0000000000..5d8641a9ab
--- /dev/null
+++ b/synapse/storage/schema/delta/50/erasure_store.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of users who have requested that their details be erased
+CREATE TABLE erased_users (
+ user_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f0fa5d7631..9b77c45318 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -16,7 +16,7 @@
from collections import namedtuple
import logging
import re
-import simplejson as json
+from canonicaljson import json
from six import string_types
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85b8ec2b8f..cd9821c270 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -526,10 +526,23 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
- """Given list of groups returns dict of group -> list of state events
- with matching types. `types` is a list of `(type, state_key)`, where
- a `state_key` of None matches all state_keys. If `types` is None then
- all events are returned.
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
"""
if types:
types = frozenset(types)
@@ -538,7 +551,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache(
- group, types
+ group, types,
)
results[group] = state_dict_ids
@@ -559,26 +572,40 @@ class StateGroupWorkerStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
+ # the DictionaryCache knows if it has *all* the state, but
+ # does not know if it has all of the keys of a particular type,
+ # which makes wildcard lookups expensive unless we have a complete
+ # cache. Hence, if we are doing a wildcard lookup, populate the
+ # cache fully so that we can do an efficient lookup next time.
+
+ if types and any(k is None for (t, k) in types):
+ types_to_fetch = None
+ else:
+ types_to_fetch = types
+
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types
+ missing_groups, types_to_fetch,
)
- # Now we want to update the cache with all the things we fetched
- # from the database.
for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
- state_dict.update(
- ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
- for k, v in iteritems(group_state_dict)
- )
-
+ # update the result, filtering by `types`.
+ if types:
+ for k, v in iteritems(group_state_dict):
+ (typ, _) = k
+ if k in types or (typ, None) in types:
+ state_dict[k] = v
+ else:
+ state_dict.update(group_state_dict)
+
+ # update the cache with all the things we fetched from the
+ # database.
self._state_group_cache.update(
cache_seq_num,
key=group,
- value=state_dict,
- full=(types is None),
- known_absent=types,
+ value=group_state_dict,
+ fetched_keys=types_to_fetch,
)
defer.returnValue(results)
@@ -685,7 +712,6 @@ class StateGroupWorkerStore(SQLBaseStore):
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
- full=True,
)
return state_group
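
The wildcard handling above is worth restating: a `(type, None)` entry means "all state_keys of this type", which the DictionaryCache cannot answer from a partial entry. Illustratively:

    types = [("m.room.member", None), ("m.room.name", "")]
    if any(state_key is None for (_, state_key) in types):
        types_to_fetch = None   # wildcard: fetch the full state and cache it
    else:
        types_to_fetch = types  # plain keys: fetch just those

The full fetch is then filtered back down to `types` before being returned, while the unfiltered `group_state_dict` is what goes into the cache.
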
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 6671d3cfca..04d123ed95 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -19,7 +19,8 @@ from synapse.storage.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-import simplejson as json
+from canonicaljson import json
+
import logging
from six.moves import range
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index e485d19b84..acbc03446e 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -19,12 +19,11 @@ from synapse.util.caches.descriptors import cached
from twisted.internet import defer
import six
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from collections import namedtuple
import logging
-import simplejson as json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
new file mode 100644
index 0000000000..47bfc01e84
--- /dev/null
+++ b/synapse/storage/user_erasure_store.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import operator
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedList, cached
+
+
+class UserErasureWorkerStore(SQLBaseStore):
+ @cached()
+ def is_user_erased(self, user_id):
+ """
+ Check if the given user id has requested erasure
+
+ Args:
+ user_id (str): full user id to check
+
+ Returns:
+ Deferred[bool]: True if the user has requested erasure
+ """
+ return self._simple_select_onecol(
+ table="erased_users",
+ keyvalues={"user_id": user_id},
+ retcol="1",
+ desc="is_user_erased",
+ ).addCallback(operator.truth)
+
+ @cachedList(
+ cached_method_name="is_user_erased",
+ list_name="user_ids",
+ inlineCallbacks=True,
+ )
+ def are_users_erased(self, user_ids):
+ """
+ Checks which users in a list have requested erasure
+
+ Args:
+ user_ids (iterable[str]): full user ids to check
+
+ Returns:
+ Deferred[dict[str, bool]]:
+ for each user, whether the user has requested erasure.
+ """
+ # this serves the dual purpose of (a) making sure we can do len and
+ # iterate it multiple times, and (b) avoiding duplicates.
+ user_ids = tuple(set(user_ids))
+
+ def _get_erased_users(txn):
+ txn.execute(
+ "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
+ ",".join("?" * len(user_ids))
+ ),
+ user_ids,
+ )
+ return set(r[0] for r in txn)
+
+ erased_users = yield self.runInteraction(
+ "are_users_erased", _get_erased_users,
+ )
+ res = dict((u, u in erased_users) for u in user_ids)
+ defer.returnValue(res)
+
+
+class UserErasureStore(UserErasureWorkerStore):
+ def mark_user_erased(self, user_id):
+ """Indicate that user_id wishes their message history to be erased.
+
+ Args:
+ user_id (str): full user_id to be erased
+ """
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute(
+ "SELECT 1 FROM erased_users WHERE user_id = ?",
+ (user_id, )
+ )
+ if txn.fetchone():
+ return
+
+ # they are not already there: do the insert.
+ txn.execute(
+ "INSERT INTO erased_users (user_id) VALUES (?)",
+ (user_id, )
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.is_user_erased, (user_id,)
+ )
+ return self.runInteraction("mark_user_erased", f)
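
Taken together, a sketch of the expected usage of the new store (Deferred-returning, as with the other stores):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def example(store):
        yield store.mark_user_erased("@alice:example.com")
        erased = yield store.is_user_erased("@alice:example.com")  # True
        result = yield store.are_users_erased(
            ["@alice:example.com", "@bob:example.com"],
        )
        # {"@alice:example.com": True, "@bob:example.com": False}
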
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index fc11e26623..e9886ef299 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import PreserveLoggingContext
-
-from twisted.internet import defer, reactor, task
-
-import time
import logging
-
from itertools import islice
+import attr
+from twisted.internet import defer, task
+
+from synapse.util.logcontext import PreserveLoggingContext
+
logger = logging.getLogger(__name__)
@@ -31,16 +30,27 @@ def unwrapFirstError(failure):
return failure.value.subFailure
+@attr.s
class Clock(object):
- """A small utility that obtains current time-of-day so that time may be
- mocked during unit-tests.
+ """
+ A Clock wraps a Twisted reactor and provides utilities on top of it.
- TODO(paul): Also move the sleep() functionality into it
+ Args:
+ reactor: The Twisted reactor to use.
"""
+ _reactor = attr.ib()
+
+ @defer.inlineCallbacks
+ def sleep(self, seconds):
+ d = defer.Deferred()
+ with PreserveLoggingContext():
+ self._reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def time(self):
"""Returns the current system time in seconds since epoch."""
- return time.time()
+ return self._reactor.seconds()
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
@@ -56,6 +66,7 @@ class Clock(object):
msec(float): How long to wait between calls in milliseconds.
"""
call = task.LoopingCall(f)
+ call.clock = self._reactor
call.start(msec / 1000.0, now=False)
return call
@@ -73,7 +84,7 @@ class Clock(object):
callback(*args, **kwargs)
with PreserveLoggingContext():
- return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
+ return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
try:
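
Since Clock now wraps an explicit reactor, tests can substitute a fake one. A sketch; task.Clock is Twisted's deterministic reactor stand-in:

    from twisted.internet import task
    from synapse.util import Clock

    fake_reactor = task.Clock()
    clock = Clock(fake_reactor)
    d = clock.sleep(10)        # a Deferred that fires after 10 "seconds"
    fake_reactor.advance(10)   # fires it immediately, no real waiting
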
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 9dd4e6b5bc..1668df4ce6 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
from .logcontext import (
PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
-from synapse.util import logcontext, unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError, Clock
from contextlib import contextmanager
@@ -32,22 +31,6 @@ from six.moves import range
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def sleep(seconds):
- d = defer.Deferred()
- with PreserveLoggingContext():
- reactor.callLater(seconds, d.callback, seconds)
- res = yield d
- defer.returnValue(res)
-
-
-def run_on_reactor():
- """ This will cause the rest of the function to be invoked upon the next
- iteration of the main loop
- """
- return sleep(0)
-
-
class ObservableDeferred(object):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
@@ -180,13 +163,18 @@ class Linearizer(object):
# do some work.
"""
- def __init__(self, name=None):
+ def __init__(self, name=None, clock=None):
if name is None:
self.name = id(self)
else:
self.name = name
self.key_to_defer = {}
+ if not clock:
+ from twisted.internet import reactor
+ clock = Clock(reactor)
+ self._clock = clock
+
@defer.inlineCallbacks
def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that
@@ -227,7 +215,7 @@ class Linearizer(object):
# the context manager, but it needs to happen while we hold the
# lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place.
- yield run_on_reactor()
+ yield self._clock.sleep(0)
else:
logger.info("Acquired uncontended linearizer lock %r for key %r",
@@ -404,7 +392,7 @@ class DeferredTimeoutError(Exception):
"""
-def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
"""
Add a timeout to a deferred by scheduling it to be cancelled after
timeout seconds.
@@ -419,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
Args:
deferred (defer.Deferred): deferred to be timed out
timeout (Number): seconds to time out after
+ reactor (twisted.internet.reactor): the Twisted reactor to use
on_timeout_cancel (callable): A callable which is called immediately
after the deferred times out, and not if this deferred is
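
With `sleep` and `run_on_reactor` gone from this module, callers either go through `Clock.sleep` or, as in Linearizer, inject a clock. A sketch of the new construction (the no-clock fallback to the global reactor still works, per the hunk above):

    from twisted.internet import defer, reactor
    from synapse.util import Clock
    from synapse.util.async import Linearizer

    linearizer = Linearizer(name="example", clock=Clock(reactor))

    @defer.inlineCallbacks
    def do_work(key):
        with (yield linearizer.queue(key)):
            pass  # one holder at a time per key
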
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index bdc21e348f..95793d466d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -107,29 +107,28 @@ class DictionaryCache(object):
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, full=False, known_absent=None):
+ def update(self, sequence, key, value, fetched_keys=None):
"""Updates the entry in the cache
Args:
sequence
- key
- value (dict): The value to update the cache with.
- full (bool): Whether the given value is the full dict, or just a
- partial subset there of. If not full then any existing entries
- for the key will be updated.
- known_absent (set): Set of keys that we know don't exist in the full
- dict.
+ key (K)
+ value (dict[X,Y]): The value to update the cache with.
+ fetched_keys (None|set[X]): All of the dictionary keys which were
+ fetched from the database.
+
+ If None, this is the complete value for key K. Otherwise, it
+ is used to infer a list of keys which we know don't exist in
+ the full dict.
"""
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- if known_absent is None:
- known_absent = set()
- if full:
- self._insert(key, value, known_absent)
+ if fetched_keys is None:
+ self._insert(key, value, set())
else:
- self._update_or_insert(key, value, known_absent)
+ self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent):
# We pop and reinsert as we need to tell the cache the size may have
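
The old `full`/`known_absent` pair collapses into a single `fetched_keys` argument. Illustratively:

    seq = cache.sequence

    # complete entry: `value` is everything there is for this key
    cache.update(seq, key=group, value=state_dict, fetched_keys=None)

    # partial entry: only these dict keys were fetched, so anything else
    # remains unknown (rather than being marked "known absent" wholesale)
    cache.update(seq, key=group, value=state_dict, fetched_keys=types_to_fetch)
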
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 817118e30f..0fb8620001 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -78,7 +78,8 @@ class StreamChangeCache(object):
not_known_entities = set(entities) - set(self._entity_to_key)
result = (
- set(self._cache.values()[self._cache.bisect_right(stream_pos) :])
+ {self._cache[k] for k in self._cache.islice(
+ start=self._cache.bisect_right(stream_pos))}
.intersection(entities)
.union(not_known_entities)
)
@@ -113,7 +114,8 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- return self._cache.values()[self._cache.bisect_right(stream_pos) :]
+ return [self._cache[k] for k in self._cache.islice(
+ start=self._cache.bisect_right(stream_pos))]
else:
return None
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 3380970e4e..c78801015b 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import threads, reactor
+from twisted.internet import threads
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
Args:
file_obj (file): The file like object to write to. Closed when
finished.
+ reactor (twisted.internet.reactor): the Twisted reactor to use
"""
# For PushProducers pause if we have this many unwritten slices
@@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
# And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2
- def __init__(self, file_obj):
+ def __init__(self, file_obj, reactor):
self._file_obj = file_obj
+ self._reactor = reactor
+
# Producer we're registered with
self._producer = None
@@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
self._producer = producer
self.streaming = streaming
self._finished_deferred = run_in_background(
- threads.deferToThread, self._writer
+ threads.deferToThreadPool,
+ self._reactor,
+ self._reactor.getThreadPool(),
+ self._writer,
)
if not streaming:
self._producer.resumeProducing()
@@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
# producer.
if self._producer and self._paused_producer:
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
- reactor.callFromThread(self._resume_paused_producer)
+ self._reactor.callFromThread(self._resume_paused_producer)
bytes = self._bytes_queue.get()
@@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
# If its a pull producer then we need to explicitly ask for
# more stuff.
if not self.streaming and self._producer:
- reactor.callFromThread(self._producer.resumeProducing)
+ self._reactor.callFromThread(self._producer.resumeProducing)
except Exception as e:
self._write_exception = e
raise
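
The consumer now takes its reactor explicitly, so writes are deferred to that reactor's thread pool rather than the global one. Construction becomes (sketch):

    from twisted.internet import reactor

    consumer = BackgroundFileConsumer(open("out.bin", "wb"), reactor)
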
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 15f0a7ba9e..535e7d0e7a 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,7 +14,7 @@
# limitations under the License.
from frozendict import frozendict
-import simplejson as json
+from canonicaljson import json
from six import string_types
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a58c723403..df2b71b791 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -60,6 +60,7 @@ class LoggingContext(object):
__slots__ = [
"previous_context", "name", "ru_stime", "ru_utime",
"db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "evt_db_fetch_count",
"usage_start",
"main_thread", "alive",
"request", "tag",
@@ -90,6 +91,9 @@ class LoggingContext(object):
def add_database_scheduled(self, sched_sec):
pass
+ def record_event_fetch(self, event_count):
+ pass
+
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
@@ -109,6 +113,9 @@ class LoggingContext(object):
# sec spent waiting for db txns to be scheduled
self.db_sched_duration_sec = 0
+ # number of events this thread has fetched from the db
+ self.evt_db_fetch_count = 0
+
# If alive, contains the thread resource usage from when the logcontext
# last became active.
self.usage_start = None
@@ -243,6 +250,14 @@ class LoggingContext(object):
"""
self.db_sched_duration_sec += sched_sec
+ def record_event_fetch(self, event_count):
+ """Record a number of events being fetched from the db
+
+ Args:
+ event_count (int): number of events being fetched
+ """
+ self.evt_db_fetch_count += event_count
+
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 0ab63c3d7d..c5a45cef7c 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.util.async import sleep
from synapse.util.logcontext import (
run_in_background, make_deferred_yieldable,
PreserveLoggingContext,
@@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index aaca2c584c..65d79cf0d0 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -12,15 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
+import logging
+import operator
from twisted.internet import defer
-from synapse.api.constants import Membership, EventTypes
-
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-
-import logging
-
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.utils import prune_event
+from synapse.util.logcontext import (
+ make_deferred_yieldable, preserve_fn,
+)
logger = logging.getLogger(__name__)
@@ -43,21 +45,35 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
-def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
- always_include_ids=frozenset()):
- """ Returns dict of user_id -> list of events that user is allowed to
- see.
+def filter_events_for_client(store, user_id, events, is_peeking=False,
+ always_include_ids=frozenset()):
+ """
+ Check which events a user is allowed to see
Args:
- user_tuples (str, bool): (user id, is_peeking) for each user to be
- checked. is_peeking should be true if:
- * the user is not currently a member of the room, and:
- * the user has not been a member of the room since the
- given events
- events ([synapse.events.EventBase]): list of events to filter
+ store (synapse.storage.DataStore): our datastore (can also be a worker
+ store)
+ user_id(str): user id to be checked
+ events(list[synapse.events.EventBase]): sequence of events to be checked
+ is_peeking(bool): should be True if:
+ * the user is not currently a member of the room, and:
+ * the user has not been a member of the room since the given
+ events
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
+
+ Returns:
+ Deferred[list[synapse.events.EventBase]]
"""
+ types = (
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, user_id),
+ )
+ event_id_to_state = yield store.get_state_for_events(
+ frozenset(e.event_id for e in events),
+ types=types,
+ )
+
forgotten = yield make_deferred_yieldable(defer.gatherResults([
defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
@@ -71,31 +87,37 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
row["event_id"] for rows in forgotten for row in rows
)
- ignore_dict_content = yield store.get_global_account_data_by_type_for_users(
- "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples]
+ ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", user_id,
)
# FIXME: This will explode if people upload something incorrect.
- ignore_dict = {
- user_id: frozenset(
- content.get("ignored_users", {}).keys() if content else []
- )
- for user_id, content in ignore_dict_content.items()
- }
+ ignore_list = frozenset(
+ ignore_dict_content.get("ignored_users", {}).keys()
+ if ignore_dict_content else []
+ )
+
+ erased_senders = yield store.are_users_erased((e.sender for e in events))
- def allowed(event, user_id, is_peeking, ignore_list):
+ def allowed(event):
"""
Args:
event (synapse.events.EventBase): event to check
- user_id (str)
- is_peeking (bool)
- ignore_list (list): list of users to ignore
+
+ Returns:
+ None|EventBase:
+ None if the user cannot see this event at all
+
+ a redacted copy of the event if they can only see a redacted
+ version
+
+ the original event if they can see it as normal.
"""
if not event.is_state() and event.sender in ignore_list:
- return False
+ return None
if event.event_id in always_include_ids:
- return True
+ return event
state = event_id_to_state[event.event_id]
@@ -109,10 +131,6 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
- # if it was world_readable, it's easy: everyone can read it
- if visibility == "world_readable":
- return True
-
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
@@ -146,7 +164,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
- return True
+ return event
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
@@ -157,70 +175,55 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
+ # XXX why do we do this?
+ # https://github.com/matrix-org/synapse/issues/3350
if membership_event.event_id not in event_id_forgotten:
membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN:
- return True
+ return event
+
+ # otherwise, it depends on the room visibility.
if visibility == "joined":
# we weren't a member at the time of the event, so we can't
# see this event.
- return False
+ return None
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
- return membership == Membership.INVITE
-
- else:
- # visibility is shared: user can also see the event if they have
- # become a member since the event
+ return (
+ event if membership == Membership.INVITE else None
+ )
+
+ elif visibility == "shared" and is_peeking:
+ # if the visibility is shared, users cannot see the event unless
+ # they have *subsequently* joined the room (or were members at the
+ # time, of course)
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
- # we don't know when they left.
- return not is_peeking
+ # we don't know when they left. We just treat it as though they
+ # never joined, and restrict access.
+ return None
- defer.returnValue({
- user_id: [
- event
- for event in events
- if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, []))
- ]
- for user_id, is_peeking in user_tuples
- })
+ # the visibility is either shared or world_readable, and the user was
+ # not a member at the time. We allow it, provided the original sender
+ # has not requested their data to be erased, in which case, we return
+ # a redacted version.
+ if erased_senders[event.sender]:
+ return prune_event(event)
+ return event
-@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False,
- always_include_ids=frozenset()):
- """
- Check which events a user is allowed to see
+ # check each event: gives an iterable[None|EventBase]
+ filtered_events = itertools.imap(allowed, events)
- Args:
- user_id(str): user id to be checked
- events([synapse.events.EventBase]): list of events to be checked
- is_peeking(bool): should be True if:
- * the user is not currently a member of the room, and:
- * the user has not been a member of the room since the given
- events
+ # remove the None entries
+ filtered_events = filter(operator.truth, filtered_events)
- Returns:
- [synapse.events.EventBase]
- """
- types = (
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
- )
- event_id_to_state = yield store.get_state_for_events(
- frozenset(e.event_id for e in events),
- types=types
- )
- res = yield filter_events_for_clients(
- store, [(user_id, is_peeking)], events, event_id_to_state,
- always_include_ids=always_include_ids,
- )
- defer.returnValue(res.get(user_id, []))
+ # we turn it into a list before returning it.
+ defer.returnValue(list(filtered_events))
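
The single-user path now owns the whole pipeline: fetch the relevant state, consult the erasure table, and map each event to None (hidden), a pruned copy (erased sender), or itself. A usage sketch:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def visible_events(store, user_id, events):
        visible = yield filter_events_for_client(
            store, user_id, events, is_peeking=False,
        )
        # hidden events are dropped; events from erased senders come
        # back redacted via prune_event
        defer.returnValue(visible)
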
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4575dd9834..aec3b62897 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -86,16 +86,53 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ self.store.get_user_by_access_token = Mock(return_value=None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args["access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = yield self.auth.get_user_by_req(request)
+ self.assertEquals(requester.user.to_string(), self.test_user)
+
+ @defer.inlineCallbacks
+ def test_get_user_by_req_appservice_valid_token_good_ip(self):
+ from netaddr import IPSet
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=IPSet(["192.168/16"]),
+ )
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
+ def test_get_user_by_req_appservice_valid_token_bad_ip(self):
+ from netaddr import IPSet
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=IPSet(["192.168/16"]),
+ )
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ self.store.get_user_by_access_token = Mock(return_value=None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "131.111.8.42"
+ request.args["access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ d = self.auth.get_user_by_req(request)
+ self.failureResultOf(d, AuthError)
+
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None)
@@ -119,12 +156,16 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -133,12 +174,16 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 149e443022..cc1c862ba4 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,10 +19,10 @@ import signedjson.sign
from mock import Mock
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.util import async, logcontext
+from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest, utils
-from twisted.internet import defer
+from twisted.internet import defer, reactor
class MockPerspectiveServer(object):
@@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
+ clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(1) # XXX find out why this takes so long!
+ yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(1)
+ yield clock.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/http/__init__.py
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
new file mode 100644
index 0000000000..cd74825c85
--- /dev/null
+++ b/tests/http/test_endpoint.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.http.endpoint import parse_server_name
+from tests import unittest
+
+
+class ServerNameTestCase(unittest.TestCase):
+ def test_parse_server_name(self):
+ test_data = {
+ 'localhost': ('localhost', None),
+ 'my-example.com:1234': ('my-example.com', 1234),
+ '1.2.3.4': ('1.2.3.4', None),
+ '[0abc:1def::1234]': ('[0abc:1def::1234]', None),
+ '1.2.3.4:1': ('1.2.3.4', 1),
+ '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
+ }
+
+ for i, o in test_data.items():
+ self.assertEqual(parse_server_name(i), o)
+
+ def test_parse_bad_server_names(self):
+ test_data = [
+ "", # empty
+ "localhost:http", # non-numeric port
+ "1234]", # smells like ipv6 literal but isn't
+ ]
+ for i in test_data:
+ try:
+ parse_server_name(i)
+ self.fail(
+ "Expected parse_server_name(\"%s\") to throw" % i,
+ )
+ except ValueError:
+ pass
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index da54d478ce..f47a42e45d 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -37,10 +37,6 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 1}
)
- yield self.check(
- "get_global_account_data_by_type_for_users",
- [TYPE, [USER_ID]], {USER_ID: {"a": 1}}
- )
yield self.master_store.add_account_data_for_user(
USER_ID, TYPE, {"a": 2}
@@ -50,7 +46,3 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 2}
)
- yield self.check(
- "get_global_account_data_by_type_for_users",
- [TYPE, [USER_ID]], {USER_ID: {"a": 2}}
- )
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index b5bc2fa255..6a757289db 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,9 +1,9 @@
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from mock import Mock, call
-from synapse.util import async
+from synapse.util import Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_logcontexts_with_async_result(self):
@defer.inlineCallbacks
def cb():
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
defer.returnValue("yay")
@defer.inlineCallbacks
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index eef38b6781..c5e2f5549a 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock()
+ hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(
@@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
- self.primary_base_path, self.filepaths, storage_providers,
+ hs, self.primary_base_path, self.filepaths, storage_providers,
)
def tearDown(self):
diff --git a/tests/server.py b/tests/server.py
new file mode 100644
index 0000000000..73069dff52
--- /dev/null
+++ b/tests/server.py
@@ -0,0 +1,181 @@
+from io import BytesIO
+
+import attr
+import json
+from six import text_type
+
+from twisted.python.failure import Failure
+from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.http.site import SynapseRequest
+from twisted.internet import threads
+from tests.utils import setup_test_homeserver as _sth
+
+
+@attr.s
+class FakeChannel(object):
+ """
+ A fake Twisted Web Channel (the part that interfaces with the
+ wire).
+ """
+
+ result = attr.ib(factory=dict)
+
+ @property
+ def json_body(self):
+ if not self.result:
+ raise Exception("No result yet.")
+ return json.loads(self.result["body"])
+
+ def writeHeaders(self, version, code, reason, headers):
+ self.result["version"] = version
+ self.result["code"] = code
+ self.result["reason"] = reason
+ self.result["headers"] = headers
+
+ def write(self, content):
+ if "body" not in self.result:
+ self.result["body"] = b""
+
+ self.result["body"] += content
+
+ def requestDone(self, _self):
+ self.result["done"] = True
+
+ def getPeer(self):
+ return None
+
+ def getHost(self):
+ return None
+
+ @property
+ def transport(self):
+ return self
+
+
+class FakeSite:
+ """
+ A fake Twisted Web Site, with mocks of the extra things that
+ Synapse adds.
+ """
+
+ server_version_string = b"1"
+ site_tag = "test"
+
+ @property
+ def access_logger(self):
+ class FakeLogger:
+ def info(self, *args, **kwargs):
+ pass
+
+ return FakeLogger()
+
+
+def make_request(method, path, content=b""):
+ """
+ Make a web request using the given method and path, feed it the
+ content, and return the Request and the Channel underneath.
+ """
+
+ if isinstance(content, text_type):
+ content = content.encode('utf8')
+
+ site = FakeSite()
+ channel = FakeChannel()
+
+ req = SynapseRequest(site, channel)
+ req.process = lambda: b""
+ req.content = BytesIO(content)
+ req.requestReceived(method, path, b"1.1")
+
+ return req, channel
+
+
+def wait_until_result(clock, channel, timeout=100):
+ """
+ Wait until the channel has a result.
+ """
+ clock.run()
+ x = 0
+
+ while not channel.result:
+ x += 1
+
+ if x > timeout:
+ raise Exception("Timed out waiting for request to finish.")
+
+ clock.advance(0.1)
+
+
+class ThreadedMemoryReactorClock(MemoryReactorClock):
+ """
+ A MemoryReactorClock that supports callFromThread.
+ """
+ def callFromThread(self, callback, *args, **kwargs):
+ """
+ Make the callback fire in the next reactor iteration.
+ """
+ d = Deferred()
+ d.addCallback(lambda x: callback(*args, **kwargs))
+ self.callLater(0, d.callback, True)
+ return d
+
+
+def setup_test_homeserver(*args, **kwargs):
+ """
+ Set up a synchronous test server, driven by the reactor used by
+ the homeserver.
+ """
+ d = _sth(*args, **kwargs).result
+
+ # Make the thread pool synchronous.
+ clock = d.get_clock()
+ pool = d.get_db_pool()
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
+
+ pool.runWithConnection = runWithConnection
+ pool.runInteraction = runInteraction
+
+ class ThreadPool:
+ """
+ Threadless thread pool.
+ """
+ def start(self):
+ pass
+
+ def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
+ def _(res):
+ if isinstance(res, Failure):
+ onResult(False, res)
+ else:
+ onResult(True, res)
+
+ d = Deferred()
+ d.addCallback(lambda x: function(*args, **kwargs))
+ d.addBoth(_)
+ clock._reactor.callLater(0, d.callback, True)
+ return d
+
+ clock.threadpool = ThreadPool()
+ pool.threadpool = ThreadPool()
+ return d
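+
+# A minimal sketch of wiring these helpers together, following the pattern
+# of tests/test_federation.py below (Clock here is synapse.util.Clock):
+#
+#     reactor = ThreadedMemoryReactorClock()
+#     hs_clock = Clock(reactor)
+#     hs = setup_test_homeserver(
+#         http_client=None, clock=hs_clock, reactor=reactor
+#     )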
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 010aeaee7e..c066381698 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
from mock import Mock, patch
from synapse.util.distributor import Distributor
-from synapse.util.async import run_on_reactor
class DistributorTestCase(unittest.TestCase):
@@ -95,7 +94,6 @@ class DistributorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def observer():
- yield run_on_reactor()
raise MyException("Oopsie")
self.dist.observe("whail", observer)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
new file mode 100644
index 0000000000..d08e19c53a
--- /dev/null
+++ b/tests/test_event_auth.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from synapse import event_auth
+from synapse.api.errors import AuthError
+from synapse.events import FrozenEvent
+
+
+class EventAuthTestCase(unittest.TestCase):
+ def test_random_users_cannot_send_state_before_first_pl(self):
+ """
+ Check that, before the first PL lands, the creator is the only user
+ that can send a state event.
+ """
+ creator = "@creator:example.com"
+ joiner = "@joiner:example.com"
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.member", joiner): _join_event(joiner),
+ }
+
+ # creator should be able to send state
+ event_auth.check(
+ _random_state_event(creator), auth_events,
+ do_sig_check=False,
+ )
+
+ # joiner should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(joiner),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_state_default_level(self):
+ """
+ Check that users above the state_default level can send state and
+ those below cannot
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+ king = "@joiner2:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(creator, {
+ "state_default": "30",
+ "users": {
+ pleb: "29",
+ king: "30",
+ },
+ }),
+ ("m.room.member", pleb): _join_event(pleb),
+ ("m.room.member", king): _join_event(king),
+ }
+
+ # pleb should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # king should be able to send state
+ event_auth.check(
+ _random_state_event(king), auth_events,
+ do_sig_check=False,
+ )
+
+
+# helpers for making events
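+# (deliberately minimal: the tests above pass do_sig_check=False, so
+# event_auth.check does not demand signatures on these events)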
+
+TEST_ROOM_ID = "!test:room"
+
+
+def _create_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.create",
+ "sender": user_id,
+ "content": {
+ "creator": user_id,
+ },
+ })
+
+
+def _join_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+def _power_levels_event(sender, content):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.power_levels",
+ "sender": sender,
+ "state_key": "",
+ "content": content,
+ })
+
+
+def _random_state_event(sender):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "test.state",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+event_count = 0
+
+
+def _get_event_id():
+ global event_count
+ c = event_count
+ event_count += 1
+ return "!%i:example.com" % (c, )
diff --git a/tests/test_federation.py b/tests/test_federation.py
new file mode 100644
index 0000000000..fc80a69369
--- /dev/null
+++ b/tests/test_federation.py
@@ -0,0 +1,243 @@
+from twisted.internet.defer import succeed, maybeDeferred
+
+from synapse.util import Clock
+from synapse.events import FrozenEvent
+from synapse.types import Requester, UserID
+
+from tests import unittest
+from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock
+
+from mock import Mock
+
+
+class MessageAcceptTests(unittest.TestCase):
+ def setUp(self):
+
+ self.http_client = Mock()
+ self.reactor = ThreadedMemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
+ self.homeserver = setup_test_homeserver(
+ http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor
+ )
+
+ user_id = UserID("us", "test")
+ our_user = Requester(user_id, None, False, None, None)
+ room_creator = self.homeserver.get_room_creation_handler()
+ room = room_creator.create_room(
+ our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ )
+ self.reactor.advance(0.1)
+ self.room_id = self.successResultOf(room)["room_id"]
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ join_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$join:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "prev_state": [(most_recent, {})],
+ "prev_events": [(most_recent, {})],
+ }
+ )
+
+ self.handler = self.homeserver.get_handlers().federation_handler
+ self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.client = self.homeserver.get_federation_client()
+ self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
+ pdus
+ )
+
+ # Send the join; a None result means success, not an error
+ d = self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+ self.assertEqual(self.successResultOf(d), None)
+
+ # Make sure we actually joined the room
+ self.assertEqual(
+ self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0],
+ "$join:test.serv",
+ )
+
+ def test_cant_hide_direct_ancestors(self):
+ """
+ If you send a message, you must be able to provide the direct
+ prev_events that said event references.
+ """
+
+ def post_json(destination, path, data, headers=None, timeout=0):
+ # If it asks us for new missing events, give them NOTHING
+ if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+ return {"events": []}
+
+ self.http_client.post_json = post_json
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ # Now lie about an event
+ lying_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "one:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("two:test.serv", {}), (most_recent, {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
+
+ # Step the reactor, so the database fetches come back
+ self.reactor.advance(1)
+
+ # on_receive_pdu should throw an error
+ failure = self.failureResultOf(d)
+ self.assertEqual(
+ failure.value.args[0],
+ (
+ "ERROR 403: Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ )
+
+ # Make sure the invalid event isn't there
+ extrem = maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+ def test_cant_hide_past_history(self):
+ """
+ If you send a message, you must be able to provide the direct
+ prev_events that said event references.
+ """
+
+ def post_json(destination, path, data, headers=None, timeout=0):
+ if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+ return {
+ "events": [
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "three:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("four:test.serv", {})],
+ }
+ ]
+ }
+
+ self.http_client.post_json = post_json
+
+ def get_json(destination, path, args, headers=None):
+ if path.startswith("/_matrix/federation/v1/state_ids/"):
+ d = self.successResultOf(
+ self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
+ )
+
+ return succeed(
+ {
+ "pdu_ids": [
+ y
+ for x, y in d.items()
+ if x == ("m.room.member", "@us:test")
+ ],
+ "auth_chain_ids": d.values(),
+ }
+ )
+
+ self.http_client.get_json = get_json
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ # Make a good event
+ good_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "one:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [(most_recent, {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", good_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+ self.assertEqual(self.successResultOf(d), None)
+
+ bad_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "two:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", bad_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+
+ extrem = maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
+
+ state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
+ self.reactor.advance(1)
+ self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())
diff --git a/tests/test_server.py b/tests/test_server.py
new file mode 100644
index 0000000000..8ad822c43b
--- /dev/null
+++ b/tests/test_server.py
@@ -0,0 +1,128 @@
+import json
+import re
+
+from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.util import Clock
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import JsonResource
+from tests import unittest
+from tests.server import make_request, setup_test_homeserver
+
+
+class JsonResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = MemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
+ self.homeserver = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.reactor
+ )
+
+ def test_handler_for_request(self):
+ """
+ JsonResource.handler_for_request gives correctly decoded URL args to
+ the callback, whereas Twisted leaves URL query arguments as raw bytes.
+ """
+ got_kwargs = {}
+
+ def _callback(request, **kwargs):
+ got_kwargs.update(kwargs)
+ return (200, kwargs)
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
+
+ request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
+ request.render(res)
+
+ self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
+ self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
+
+ def test_callback_direct_exception(self):
+ """
+ If the web callback raises an uncaught exception, it will be translated
+ into a 500.
+ """
+
+ def _callback(request, **kwargs):
+ raise Exception("boo")
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/foo")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'500')
+
+ def test_callback_indirect_exception(self):
+ """
+ If the web callback raises an uncaught exception in a Deferred, it will
+ be translated into a 500.
+ """
+
+ def _throw(*args):
+ raise Exception("boo")
+
+ def _callback(request, **kwargs):
+ d = Deferred()
+ d.addCallback(_throw)
+ self.reactor.callLater(1, d.callback, True)
+ return d
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/foo")
+ request.render(res)
+
+ # No error has been raised yet
+ self.assertTrue("code" not in channel.result)
+
+ # Advance time, now there's an error
+ self.reactor.advance(1)
+ self.assertEqual(channel.result["code"], b'500')
+
+ def test_callback_synapseerror(self):
+ """
+ If the web callback raises a SynapseError, it returns the appropriate
+ status code and message set in it.
+ """
+
+ def _callback(request, **kwargs):
+ raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/foo")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'403')
+ reply_body = json.loads(channel.result["body"])
+ self.assertEqual(reply_body["error"], "Forbidden!!one!")
+ self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
+
+ def test_no_handler(self):
+ """
+ If there is no handler to process the request, Synapse will return 400.
+ """
+
+ def _callback(request, **kwargs):
+ """
+ Not ever actually called!
+ """
+ self.fail("shouldn't ever get here")
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/foobar")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'400')
+ reply_body = json.loads(channel.result["body"])
+ self.assertEqual(reply_body["error"], "Unrecognized request")
+ self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")
diff --git a/tests/test_state.py b/tests/test_state.py
index a5c5e55951..71c412faf4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase):
}
)
+ power_levels = create_event(
+ type=EventTypes.PowerLevels, state_key="",
+ content={"users": {
+ "@foo:bar": "100",
+ "@user_id:example.com": "100",
+ }}
+ )
+
creation = create_event(
type=EventTypes.Create, state_key="",
content={"creator": "@foo:bar"}
@@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
@@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
@@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
diff --git a/tests/unittest.py b/tests/unittest.py
index 184fe880f3..b25f2db5d5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
- self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry)
+ self.tx_log.emit(
+ twisted.logger.LogLevel.levelWithName(log_level),
+ # twisted.logger treats "{...}" in messages as format placeholders,
+ # so neutralise any braces before handing the entry over
+ log_entry.replace("{", "(").replace("}", ")"),
+ )
handler = ToTwistedHandler()
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..24754591df 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -18,7 +18,6 @@ import logging
import mock
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
from twisted.internet import defer
from synapse.util.caches import descriptors
@@ -195,7 +194,6 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bc92f85fa6..543ac5bed9 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
@@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_hit_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
@@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_miss_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
@@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
- self.cache.update(seq, key, test_value_1, full=False)
+ self.cache.update(seq, key, test_value_1, fetched_keys={"test"})
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
- self.cache.update(seq, key, test_value_2, full=False)
+ self.cache.update(seq, key, test_value_2, fetched_keys={"test2"})
c = self.cache.get(key)
self.assertEqual(
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..c2aae8f54c 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = DummyPullProducer()
@@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
@@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..bf7e3aa885 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util import async, logcontext
+
+from synapse.util import logcontext, Clock
from tests import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from synapse.util.async import Linearizer
from six.moves import range
@@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..9cf90fcfc4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -3,8 +3,7 @@ from twisted.internet import defer
from twisted.internet import reactor
from .. import unittest
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_sleep(self):
+ clock = Clock(reactor)
+
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
competing_context.request = "competing"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
context_one.request = "one"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
- yield sleep(0)
+ yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
diff --git a/tests/utils.py b/tests/utils.py
index 262c4a5714..189fd2711c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
+def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
+ **kargs):
"""Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS.
"""
+ if reactor is None:
+ from twisted.internet import reactor
+
if config is None:
config = Mock()
config.signing_key = [MockKey()]
@@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
+ reactor=reactor,
**kargs
)
db_conn = hs.get_db_conn()
diff --git a/tox.ini b/tox.ini
index 5d79098d2f..61a20a10cb 100644
--- a/tox.ini
+++ b/tox.ini
@@ -102,3 +102,11 @@ basepython = python2.7
deps =
flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
+
+
+[testenv:check-newsfragment]
+skip_install = True
+basepython = python3.6
+deps = towncrier>=18.6.0rc1
+commands =
+    python -m towncrier.check --compare-with=origin/develop
|