diff options
author | Richard van der Hoff <richard@matrix.org> | 2019-03-25 16:48:56 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2019-03-25 16:48:56 +0000 |
commit | ce0ce1add3dde3da2ff366ee873174f5b5e70763 (patch) | |
tree | 8573d760dbaa68676e95e18698639537e75d75b2 | |
parent | Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes (diff) | |
parent | Fix ClientReplicationStreamProtocol.__str__ (#4929) (diff) | |
download | synapse-ce0ce1add3dde3da2ff366ee873174f5b5e70763.tar.xz |
Merge branch 'develop' into matrix-org-hotfixes
39 files changed, 686 insertions, 405 deletions
diff --git a/changelog.d/4840.feature b/changelog.d/4840.feature new file mode 100644 index 0000000000..9d1fd59053 --- /dev/null +++ b/changelog.d/4840.feature @@ -0,0 +1 @@ +Remove trailing slashes from certain outbound federation requests. Retry if receiving a 404. Context: #3622. \ No newline at end of file diff --git a/changelog.d/4869.misc b/changelog.d/4869.misc new file mode 100644 index 0000000000..d8186cc520 --- /dev/null +++ b/changelog.d/4869.misc @@ -0,0 +1 @@ +Fix yaml library warnings by using safe_load. diff --git a/changelog.d/4912.misc b/changelog.d/4912.misc new file mode 100644 index 0000000000..f05a239187 --- /dev/null +++ b/changelog.d/4912.misc @@ -0,0 +1 @@ +Allow newsfragments to end with exclamation marks. Exciting! diff --git a/changelog.d/4913.misc b/changelog.d/4913.misc new file mode 100644 index 0000000000..9e835badc0 --- /dev/null +++ b/changelog.d/4913.misc @@ -0,0 +1 @@ +Refactor some more tests to use HomeserverTestCase. diff --git a/changelog.d/4917.misc b/changelog.d/4917.misc new file mode 100644 index 0000000000..338d8a9a0c --- /dev/null +++ b/changelog.d/4917.misc @@ -0,0 +1 @@ +Refactor out the state deltas portion of the user directory store and handler. diff --git a/changelog.d/4923.misc b/changelog.d/4923.misc new file mode 100644 index 0000000000..8b5e1e3c81 --- /dev/null +++ b/changelog.d/4923.misc @@ -0,0 +1 @@ +Fix nginx example in ACME doc. diff --git a/changelog.d/4927.feature b/changelog.d/4927.feature new file mode 100644 index 0000000000..8d74262250 --- /dev/null +++ b/changelog.d/4927.feature @@ -0,0 +1 @@ +Batch up outgoing read-receipts to reduce federation traffic. diff --git a/changelog.d/4928.misc b/changelog.d/4928.misc new file mode 100644 index 0000000000..3a21752551 --- /dev/null +++ b/changelog.d/4928.misc @@ -0,0 +1 @@ +Use an explicit dbname for postgres connections in the tests. diff --git a/changelog.d/4929.misc b/changelog.d/4929.misc new file mode 100644 index 0000000000..aaf02078b9 --- /dev/null +++ b/changelog.d/4929.misc @@ -0,0 +1 @@ +Fix `ClientReplicationStreamProtocol.__str__()`. diff --git a/docs/ACME.md b/docs/ACME.md index 46136a9f2c..9eb18a9cf5 100644 --- a/docs/ACME.md +++ b/docs/ACME.md @@ -67,7 +67,7 @@ For nginx users, add the following line to your existing `server` block: ``` location /.well-known/acme-challenge { - proxy_pass http://localhost:8009/; + proxy_pass http://localhost:8009; } ``` diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index e0ac84198e..0ec5075e79 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -31,8 +31,8 @@ echo # check that any new newsfiles on this branch end with a full stop. for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do lastchar=`tr -d '\n' < $f | tail -c 1` - if [ $lastchar != '.' ]; then - echo -e "\e[31mERROR: newsfragment $f does not end with a '.'\e[39m" >&2 + if [ $lastchar != '.' -a $lastchar != '!' ]; then + echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 exit 1 fi done diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index dde8596697..ac152b5c42 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -76,7 +76,7 @@ def rows_v2(server, json): def main(): - config = yaml.load(open(sys.argv[1])) + config = yaml.safe_load(open(sys.argv[1])) valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24 server_name = config["server_name"] diff --git a/synapse/config/_base.py b/synapse/config/_base.py index a219a83550..f7d7f153bb 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -137,7 +137,7 @@ class Config(object): @staticmethod def read_config_file(file_path): with open(file_path) as file_stream: - return yaml.load(file_stream) + return yaml.safe_load(file_stream) def invoke_all(self, name, *args, **kargs): results = [] @@ -318,7 +318,7 @@ class Config(object): ) config_file.write(config_str) - config = yaml.load(config_str) + config = yaml.safe_load(config_str) obj.invoke_all("generate_files", config) print( @@ -390,7 +390,7 @@ class Config(object): server_name=server_name, generate_secrets=False, ) - config = yaml.load(config_string) + config = yaml.safe_load(config_string) config.pop("log_config") config.update(specified_config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 9e64c76544..7e89d345d8 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -68,7 +68,7 @@ def load_appservices(hostname, config_files): try: with open(config_file, 'r') as f: appservice = _load_appservice( - hostname, yaml.load(f), config_file + hostname, yaml.safe_load(f), config_file ) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 464c28c2d9..c1febbe9d3 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -195,7 +195,7 @@ def setup_logging(config, use_worker_options=False): else: def load_log_config(): with open(log_config, 'r') as f: - logging.config.dictConfig(yaml.load(f)) + logging.config.dictConfig(yaml.safe_load(f)) def sighup(*args): # it might be better to use a file watcher or something for this. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8e2be218e2..0cdb31178f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -51,9 +51,10 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state/%s/", room_id) + path = _create_v1_path("/state/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -73,9 +74,10 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state_ids/%s/", room_id) + path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -95,8 +97,11 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = _create_v1_path("/event/%s/", event_id) - return self.client.get_json(destination, path=path, timeout=timeout) + path = _create_v1_path("/event/%s", event_id) + return self.client.get_json( + destination, path=path, timeout=timeout, + try_trailing_slash_on_400=True, + ) @log_function def backfill(self, destination, room_id, event_tuples, limit): @@ -121,7 +126,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = _create_v1_path("/backfill/%s/", room_id) + path = _create_v1_path("/backfill/%s", room_id) args = { "v": event_tuples, @@ -132,6 +137,7 @@ class TransportLayerClient(object): destination, path=path, args=args, + try_trailing_slash_on_400=True, ) @defer.inlineCallbacks @@ -176,6 +182,7 @@ class TransportLayerClient(object): json_data_callback=json_data_callback, long_retries=True, backoff_on_404=True, # If we get a 404 the other side has gone + try_trailing_slash_on_400=True, ) defer.returnValue(response) @@ -959,7 +966,7 @@ def _create_v1_path(path, *args): Example: - _create_v1_path("/event/%s/", event_id) + _create_v1_path("/event/%s", event_id) Args: path (str): String template for the path @@ -980,7 +987,7 @@ def _create_v2_path(path, *args): Example: - _create_v2_path("/event/%s/", event_id) + _create_v2_path("/event/%s", event_id) Args: path (str): String template for the path diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py new file mode 100644 index 0000000000..b268bbcb2c --- /dev/null +++ b/synapse/handlers/state_deltas.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +class StateDeltasHandler(object): + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + """Given two events check if the `key_name` field in content changed + from not matching `public_value` to doing so. + + For example, check if `history_visibility` (`key_name`) changed from + `shared` to `world_readable` (`public_value`). + + Returns: + None if the field in the events either both match `public_value` + or if neither do, i.e. there has been no change. + True if it didnt match `public_value` but now does + False if it did match `public_value` but now doesn't + """ + prev_event = None + event = None + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + + if not event and not prev_event: + logger.debug("Neither event exists: %r %r", prev_event_id, event_id) + defer.returnValue(None) + + prev_value = None + value = None + + if prev_event: + prev_value = prev_event.content.get(key_name) + + if event: + value = event.content.get(key_name) + + logger.debug("prev_value: %r -> value: %r", prev_value, value) + + if value == public_value and prev_value != public_value: + defer.returnValue(True) + elif value != public_value and prev_value == public_value: + defer.returnValue(False) + else: + defer.returnValue(None) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 7dc0e236e7..b689979b4b 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import get_localpart_from_id @@ -29,7 +30,7 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -class UserDirectoryHandler(object): +class UserDirectoryHandler(StateDeltasHandler): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,6 +42,8 @@ class UserDirectoryHandler(object): """ def __init__(self, hs): + super(UserDirectoryHandler, self).__init__(hs) + self.store = hs.get_datastore() self.state = hs.get_state_handler() self.server_name = hs.hostname @@ -360,7 +363,7 @@ class UserDirectoryHandler(object): @defer.inlineCallbacks def _handle_remove_user(self, room_id, user_id): - """Called when we might need to remove user to directory + """Called when we might need to remove user from directory Args: room_id (str): room_id that user left or stopped being public that @@ -402,47 +405,3 @@ class UserDirectoryHandler(object): if prev_name != new_name or prev_avatar != new_avatar: yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) - - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): - """Given two events check if the `key_name` field in content changed - from not matching `public_value` to doing so. - - For example, check if `history_visibility` (`key_name`) changed from - `shared` to `world_readable` (`public_value`). - - Returns: - None if the field in the events either both match `public_value` - or if neither do, i.e. there has been no change. - True if it didnt match `public_value` but now does - False if it did match `public_value` but now doesn't - """ - prev_event = None - event = None - if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - - if event_id: - event = yield self.store.get_event(event_id, allow_none=True) - - if not event and not prev_event: - logger.debug("Neither event exists: %r %r", prev_event_id, event_id) - defer.returnValue(None) - - prev_value = None - value = None - - if prev_event: - prev_value = prev_event.content.get(key_name) - - if event: - value = event.content.get(key_name) - - logger.debug("prev_value: %r -> value: %r", prev_value, value) - - if value == public_value and prev_value != public_value: - defer.returnValue(True) - elif value != public_value and prev_value == public_value: - defer.returnValue(False) - else: - defer.returnValue(None) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1682c9af13..8e855d13d6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -189,6 +189,57 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks + def _send_request_with_optional_trailing_slash( + self, + request, + try_trailing_slash_on_400=False, + **send_request_args + ): + """Wrapper for _send_request which can optionally retry the request + upon receiving a combination of a 400 HTTP response code and a + 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 + due to #3622. + + Args: + request (MatrixFederationRequest): details of request to be sent + try_trailing_slash_on_400 (bool): Whether on receiving a 400 + 'M_UNRECOGNIZED' from the server to retry the request with a + trailing slash appended to the request path. + send_request_args (Dict): A dictionary of arguments to pass to + `_send_request()`. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + + Returns: + Deferred[Dict]: Parsed JSON response body. + """ + try: + response = yield self._send_request( + request, **send_request_args + ) + except HttpResponseException as e: + # Received an HTTP error > 300. Check if it meets the requirements + # to retry with a trailing slash + if not try_trailing_slash_on_400: + raise + + if e.code != 400 or e.to_synapse_error().errcode != "M_UNRECOGNIZED": + raise + + # Retry with a trailing slash if we received a 400 with + # 'M_UNRECOGNIZED' which some endpoints can return when omitting a + # trailing slash on Synapse <= v0.99.3. + request.path += "/" + + response = yield self._send_request( + request, **send_request_args + ) + + defer.returnValue(response) + + @defer.inlineCallbacks def _send_request( self, request, @@ -196,7 +247,7 @@ class MatrixFederationHttpClient(object): timeout=None, long_retries=False, ignore_backoff=False, - backoff_on_404=False + backoff_on_404=False, ): """ Sends a request to the given server. @@ -473,7 +524,8 @@ class MatrixFederationHttpClient(object): json_data_callback=None, long_retries=False, timeout=None, ignore_backoff=False, - backoff_on_404=False): + backoff_on_404=False, + try_trailing_slash_on_400=False): """ Sends the specifed json data using PUT Args: @@ -493,7 +545,12 @@ class MatrixFederationHttpClient(object): and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as a failure of the server (and should therefore back off future - requests) + requests). + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end + of the request. Workaround for #3622 in Synapse <= v0.99.3. This + will be attempted before backing off if backing off has been + enabled. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The @@ -509,7 +566,6 @@ class MatrixFederationHttpClient(object): RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ - request = MatrixFederationRequest( method="PUT", destination=destination, @@ -519,17 +575,19 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=backoff_on_404, + ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, - ignore_backoff=ignore_backoff, - backoff_on_404=backoff_on_404, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks @@ -592,7 +650,8 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False): + timeout=None, ignore_backoff=False, + try_trailing_slash_on_400=False): """ GETs some json from the given host homeserver and path Args: @@ -606,6 +665,9 @@ class MatrixFederationHttpClient(object): be retried. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end of + the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -631,16 +693,19 @@ class MatrixFederationHttpClient(object): query=args, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=False, + ignore_backoff=ignore_backoff, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, - ignore_backoff=ignore_backoff, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 55630ba9a7..02e5bf6cc8 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -223,14 +223,25 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return # Now lets try and call on_<CMD_NAME> function - try: - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - getattr(self, "on_%s" % (cmd_name,)), - cmd, - ) - except Exception: - logger.exception("[%s] Failed to handle line: %r", self.id(), line) + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), + self.handle_command, + cmd, + ) + + def handle_command(self, cmd): + """Handle a command we have received over the replication stream. + + By default delegates to on_<COMMAND> + + Args: + cmd (synapse.replication.tcp.commands.Command): received command + + Returns: + Deferred + """ + handler = getattr(self, "on_%s" % (cmd.NAME,)) + return handler(cmd) def close(self): logger.warn("[%s] Closing connection", self.id()) @@ -364,8 +375,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.unregisterProducer() def __str__(self): + addr = None + if self.transport: + addr = str(self.transport.getPeer()) return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % ( - self.name, self.conn_id, self.addr, + self.name, self.conn_id, addr, ) def id(self): @@ -381,12 +395,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer, addr): + def __init__(self, server_name, clock, streamer): BaseReplicationStreamProtocol.__init__(self, clock) # Old style class self.server_name = server_name self.streamer = streamer - self.addr = addr # The streams the client has subscribed to and is up to date with self.replication_streams = set() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 47cdf30bd3..7fc346c7b6 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -57,7 +57,6 @@ class ReplicationStreamProtocolFactory(Factory): self.server_name, self.clock, self.streamer, - addr ) diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 728746bd12..a6e9d6709e 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -23,7 +23,7 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ - +import itertools import logging from collections import namedtuple @@ -195,8 +195,8 @@ class Stream(object): limit=MAX_EVENTS_BEHIND + 1, ) - if len(rows) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) + # never turn more than MAX_EVENTS_BEHIND + 1 into updates. + rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: rows = yield self.update_function( from_token, current_token, @@ -204,6 +204,11 @@ class Stream(object): updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + # check we didn't get more rows than the limit. + # doing it like this allows the update_function to be a generator. + if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: + raise Exception("stream %s has fallen behind" % (self.NAME)) + defer.returnValue((updates, current_token)) def current_token(self): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 0fd1ccc40a..89a1f7e3d7 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -301,7 +301,9 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return txn.fetchall() + return ( + r[0:5] + (json.loads(r[5]), ) for r in txn + ) return self.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py new file mode 100644 index 0000000000..57bc45cdb9 --- /dev/null +++ b/synapse/storage/state_deltas.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class StateDeltasStore(SQLBaseStore): + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index d360e857d1..65bdb1b4a5 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, JoinRules from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.state import StateFilter +from synapse.storage.state_deltas import StateDeltasStore from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -31,7 +32,7 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryStore(BackgroundUpdateStore): +class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -488,16 +489,6 @@ class UserDirectoryStore(BackgroundUpdateStore): defer.returnValue(user_ids) - @defer.inlineCallbacks - def get_all_local_users(self): - """Get all local users - """ - sql = """ - SELECT name FROM users - """ - rows = yield self._execute("get_all_local_users", None, sql) - defer.returnValue([name for name, in rows]) - def add_users_who_share_private_room(self, room_id, user_id_tuples): """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. @@ -675,59 +666,6 @@ class UserDirectoryStore(BackgroundUpdateStore): desc="update_user_directory_stream_pos", ) - def get_current_state_deltas(self, prev_stream_id): - prev_stream_id = int(prev_stream_id) - if not self._curr_state_delta_stream_cache.has_any_entity_changed( - prev_stream_id - ): - return [] - - def get_current_state_deltas_txn(txn): - # First we calculate the max stream id that will give us less than - # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't - # select toooo many. - sql = """ - SELECT stream_id, count(*) - FROM current_state_delta_stream - WHERE stream_id > ? - GROUP BY stream_id - ORDER BY stream_id ASC - LIMIT 100 - """ - txn.execute(sql, (prev_stream_id,)) - - total = 0 - max_stream_id = prev_stream_id - for max_stream_id, count in txn: - total += count - if total > 100: - # We arbitarily limit to 100 entries to ensure we don't - # select toooo many. - break - - # Now actually get the deltas - sql = """ - SELECT stream_id, room_id, type, state_key, event_id, prev_event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - return self.cursor_to_dict(txn) - - return self.runInteraction( - "get_current_state_deltas", get_current_state_deltas_txn - ) - - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( - table="current_state_delta_stream", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", - ) - @defer.inlineCallbacks def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory diff --git a/synctl b/synctl index 816c898b36..07a68e6d85 100755 --- a/synctl +++ b/synctl @@ -164,7 +164,7 @@ def main(): sys.exit(1) with open(configfile) as stream: - config = yaml.load(stream) + config = yaml.safe_load(stream) pidfile = config["pid_file"] cache_factor = config.get("synctl_cache_factor") @@ -206,7 +206,7 @@ def main(): workers = [] for worker_configfile in worker_configfiles: with open(worker_configfile) as stream: - worker_config = yaml.load(stream) + worker_config = yaml.safe_load(stream) worker_app = worker_config["worker_app"] if worker_app == "synapse.app.homeserver": # We need to special case all of this to pick up options that may diff --git a/tests/config/test_load.py b/tests/config/test_load.py index d5f1777093..6bfc1970ad 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): self.generate_config() with open(self.file, "r") as f: - raw = yaml.load(f) + raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) config = HomeServerConfig.load_config("", ["-c", self.file]) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 3dc2631523..47fffcfeb2 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -22,7 +22,7 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -74,7 +74,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): )) def test_room_publish_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: [] room_list_publication_rules: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 2217eb2a10..017ea0385e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester -from tests.utils import default_config, setup_test_homeserver - from .. import unittest @@ -32,26 +30,23 @@ class RegistrationHandlers(object): self.registration_handler = RegistrationHandler(hs) -class RegistrationTestCase(unittest.TestCase): +class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ - @defer.inlineCallbacks - def setUp(self): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() - - hs_config = default_config("test") + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") # some of the tests rely on us having a user consent version hs_config.user_consent_version = "test_consent_version" hs_config.max_mau_value = 50 - self.hs = yield setup_test_homeserver( - self.addCleanup, - config=hs_config, - expire_access_token=True, - ) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): + self.mock_distributor = Mock() + self.mock_distributor.declare("registered_user") + self.mock_captcha_client = Mock() self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) @@ -63,136 +58,133 @@ class RegistrationTestCase(unittest.TestCase): self.requester = create_requester("@requester:test") - @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, frank.localpart, "Frankie" + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, frank.localpart, "Frankie") ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) self.assertEquals(result_token, 'secret') - @defer.inlineCallbacks def test_if_user_exists(self): store = self.hs.get_datastore() frank = UserID.from_string("@frank:test") - yield store.register( - user_id=frank.to_string(), - token="jkv;g498752-43gj['eamb!-5", - password_hash=None, + self.get_success( + store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None, + ) ) local_part = frank.localpart user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, None + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, local_part, None) ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - @defer.inlineCallbacks def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.get_success( + self.handler.get_or_create_user(self.requester, 'a', "display_name") + ) - @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'c', "User") + self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) - @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - room_id = yield directory_handler.get_association(room_alias) + room_id = self.get_success(directory_handler.get_association(room_alias)) self.assertTrue(room_id['room_id'] in rooms) self.assertEqual(len(rooms), 1) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_with_no_rooms(self): self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_room_is_another_domain(self): self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_when_support_user_exists(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = yield self.handler.register(localpart='support') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='support')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - with self.assertRaises(SynapseError): - yield directory_handler.get_association(room_alias) + self.get_failure(directory_handler.get_association(room_alias), SynapseError) - @defer.inlineCallbacks def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -208,27 +200,27 @@ class RegistrationTestCase(unittest.TestCase): # (Messing with the internals of event_creation_handler is fragile # but can't see a better way to do this. One option could be to subclass # the test with custom config.) - event_creation_handler._block_events_without_consent_error = ("Error") + event_creation_handler._block_events_without_consent_error = "Error" event_creation_handler._consent_uri_builder = Mock() room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] # When:- # * the user is registered and post consent actions are called - res = yield self.handler.register(localpart='jeff') - yield self.handler.post_consent_actions(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + self.get_success(self.handler.post_consent_actions(res[0])) # Then:- # * Ensure that they have not been joined to the room - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_register_support_user(self): - res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + res = self.get_success( + self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + ) self.assertTrue(self.store.is_support_user(res[0])) - @defer.inlineCallbacks def test_register_not_support_user(self): - res = yield self.handler.register(localpart='user') + res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 13486930fb..7decb22933 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) def test_started_typing_remote_recv(self): @@ -269,6 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index b03b37affe..cd8e086f86 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -268,6 +268,105 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, TimeoutError) + def test_client_requires_trailing_slashes(self): + """ + If a connection is made to a client but the client rejects it due to + requiring a trailing slash. We need to retry the request with a + trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 400 Bad Request\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 59\r\n" + b"\r\n" + b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + ) + + # We should get another request with a trailing slash + self.assertRegex(conn.value(), b"^GET /foo/bar/") + + # Send a happy response this time + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b'{}' + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + def test_client_does_not_retry_on_400_plus(self): + """ + Another test for trailing slashes but now test that we don't retry on + trailing slashes on a non-400/M_UNRECOGNIZED response. + + See test_client_requires_trailing_slashes() for context. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 404 Not Found\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should not get another request + self.assertEqual(conn.value(), b"") + + # We should get a 404 failure response + self.failureResultOf(d) + def test_client_sends_body(self): self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, diff --git a/tests/replication/tcp/__init__.py b/tests/replication/tcp/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/__init__.py b/tests/replication/tcp/streams/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/streams/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py new file mode 100644 index 0000000000..38b368a972 --- /dev/null +++ b/tests/replication/tcp/streams/_base.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.server import FakeTransport + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): + # build a replication server + server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = server_factory.streamer + server = server_factory.buildProtocol(None) + + # build a replication client, with a dummy handler + self.test_handler = TestReplicationClientHandler() + self.client = ClientReplicationStreamProtocol( + "client", "test", clock, self.test_handler + ) + + # wire them together + self.client.makeConnection(FakeTransport(server, reactor)) + server.makeConnection(FakeTransport(self.client, reactor)) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def replicate_stream(self, stream, token="NOW"): + """Make the client end a REPLICATE command to set up a subscription to a stream""" + self.client.send_command(ReplicateCommand(stream, token)) + + +class TestReplicationClientHandler(object): + """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): + self.received_rdata_rows = [] + + def get_streams_to_replicate(self): + return {} + + def get_currently_syncing_users(self): + return [] + + def update_connection(self, connection): + pass + + def finished_connecting(self): + pass + + def on_rdata(self, stream_name, token, rows): + for r in rows: + self.received_rdata_rows.append( + (stream_name, token, r) + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py new file mode 100644 index 0000000000..9aa9dfe82e --- /dev/null +++ b/tests/replication/tcp/streams/test_receipts.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.streams import ReceiptsStreamRow + +from tests.replication.tcp.streams._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" +ROOM_ID = "!room:blue" +EVENT_ID = "$event:blue" + + +class ReceiptsStreamTestCase(BaseStreamTestCase): + def test_receipt(self): + # make the client subscribe to the receipts stream + self.replicate_stream("receipts", "NOW") + + # tell the master to send a new receipt + self.get_success( + self.hs.get_datastore().insert_receipt( + ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + ) + ) + self.replicate() + + # there should be one RDATA command + rdata_rows = self.test_handler.received_rdata_rows + self.assertEqual(1, len(rdata_rows)) + self.assertEqual(rdata_rows[0][0], "receipts") + row = rdata_rows[0][2] # type: ReceiptsStreamRow + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual({"a": 1}, row.data) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 9c401bf300..05b0143c42 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -18,136 +18,11 @@ import time import attr -from twisted.internet import defer - from synapse.api.constants import Membership -from tests import unittest from tests.server import make_request, render -class RestTestCase(unittest.TestCase): - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - - This subclass assumes there are mock_resource and auth_user_id attributes. - """ - - def __init__(self, *args, **kwargs): - super(RestTestCase, self).__init__(*args, **kwargs) - self.mock_resource = None - self.auth_user_id = None - - @defer.inlineCallbacks - def create_room_as(self, room_creator, is_public=True, tok=None): - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/createRoom" - content = "{}" - if not is_public: - content = '{"visibility":"private"}' - if tok: - path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("POST", path, content) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = temp_id - defer.returnValue(response["room_id"]) - - @defer.inlineCallbacks - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def join(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def leave(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - - (code, response) = yield self.mock_resource.trigger( - "PUT", path, json.dumps(data) - ) - self.assertEquals( - expect_code, - code, - msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response), - ) - - self.auth_user_id = temp_id - - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code, msg=response) - defer.returnValue(response) - - @defer.inlineCallbacks - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - if body is None: - body = "body_text_here" - - path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body - if tok: - path = path + "?access_token=%s" % tok - - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) - - def assert_dict(self, required, actual): - """Does a partial assert of a dict. - - Args: - required (dict): The keys and value which MUST be in 'actual'. - actual (dict): The test result. Extra keys will not be checked. - """ - for key in required: - self.assertEquals( - required[key], actual[key], msg="%s mismatch. %s" % (key, actual) - ) - - @attr.s class RestHelper(object): """Contains extra helper functions to quickly and clearly perform a given diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3bd9f1e9c1..be73e718c2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet import defer @@ -9,16 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest -from tests.utils import default_config, setup_test_homeserver -class TestResourceLimitsServerNotices(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs_config = default_config(name="test") +class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): + + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" - self.hs = yield setup_test_homeserver(self.addCleanup, config=hs_config) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -53,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_flag_off(self): """Tests cases where the flags indicate nothing to do""" # test hs disabled case self.hs.config.hs_disabled = True - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False self.hs.limit_usage_by_mau = False - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" @@ -81,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase): return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): - """Test when user has blocked notice, but notice ought to be there (NOOP)""" + """ + Test when user has blocked notice, but notice ought to be there (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) @@ -98,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice(self): - """Test when user does not have blocked notice, but should have one""" + """ + Test when user does not have blocked notice, but should have one + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event self.assertTrue(self._send_notice.call_count == 2) - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): - """Test when user does not have blocked notice, nor should they (NOOP)""" - + """ + Test when user does not have blocked notice, nor should they (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock() - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): - - """Test when user is not part of the MAU cohort - this should not ever + """ + Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() -class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() @@ -168,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Now lets get the last load of messages in the service notice room and # check that there is only one server notice - room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id + room_id = self.get_success( + self.server_notices_manager.get_notice_room_for_user(self.user_id) ) - token = yield self.event_source.get_current_token() - events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key + token = self.get_success(self.event_source.get_current_token()) + events, _ = self.get_success( + self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key + ) ) count = 0 diff --git a/tests/unittest.py b/tests/unittest.py index 7772a47078..27403de908 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -314,6 +314,9 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) + if "config" not in kwargs: + config = self.default_config() + kwargs["config"] = config hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -336,6 +339,15 @@ class HomeserverTestCase(TestCase): self.pump(by=by) return self.successResultOf(d) + def get_failure(self, d, exc): + """ + Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + """ + if not isinstance(d, Deferred): + return d + self.pump() + return self.failureResultOf(d, exc) + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. diff --git a/tests/utils.py b/tests/utils.py index d4ab4209ed..615b9f8cca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,6 +44,10 @@ from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. +# +# When running under postgres, we first create a base database with the name +# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we +# create another unique database, using the base database as a template. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None) @@ -50,28 +55,20 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None) POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) +# the dbname we will connect to in order to create the base database. +POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -def setupdb(): +def setupdb(): # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: - pgconfig = { - "name": "psycopg2", - "args": { - "database": POSTGRES_BASE_DB, - "user": POSTGRES_USER, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "cp_min": 1, - "cp_max": 5, - }, - } - config = Mock() - config.password_providers = [] - config.database_config = pgconfig - db_engine = create_engine(pgconfig) + # create a PostgresEngine + db_engine = create_engine({"name": "psycopg2", "args": {}}) + + # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True cur = db_conn.cursor() @@ -96,7 +93,8 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True cur = db_conn.cursor() |