From 34ea14139d57895431bbd14190472ccb5111c92d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Jan 2019 09:25:59 +0000 Subject: Fixup docstrings for matrixfederationclient --- synapse/http/matrixfederationclient.py | 141 +++++++++++++++++---------------- 1 file changed, 73 insertions(+), 68 deletions(-) (limited to 'synapse') diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index be4076fc6a..f2a42f97a6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -229,19 +229,18 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred: resolves with the http response object on success. - - Fails with ``HttpResponseException``: if we get an HTTP response - code >= 300 (except 429). - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist - - Fails with ``RequestSendFailed`` if there were problems connecting to - the remote, due to e.g. DNS failures, connection timeouts etc. + Deferred[twisted.web.client.Response]: resolves with the HTTP + response object on success. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ if timeout: _sec_timeout = timeout / 1000 @@ -516,17 +515,18 @@ class MatrixFederationHttpClient(object): requests) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -570,17 +570,18 @@ class MatrixFederationHttpClient(object): try the request anyway. args (dict): query params Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. 
+ FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( @@ -625,17 +626,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ logger.debug("get_json args: %s", args) @@ -676,17 +678,18 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. - - Fails with ``HttpResponseException`` if we get an HTTP response - code >= 300. - - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + result will be the decoded JSON body. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( method="DELETE", @@ -719,18 +722,20 @@ class MatrixFederationHttpClient(object): args (dict): Optional dictionary used to create the query string. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. - Returns: - Deferred: resolves with an (int,dict) tuple of the file length and - a dict of the response headers. - - Fails with ``HttpResponseException`` if we get an HTTP response code - >= 300 - Fails with ``NotRetryingDestination`` if we are not yet ready - to retry this server. - - Fails with ``FederationDeniedError`` if this destination - is not on our federation whitelist + Returns: + Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + the file length and a dict of the response headers. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. 
""" request = MatrixFederationRequest( method="GET", -- cgit 1.5.1 From 7960c26fda8dd8d2d1333b45ef4f5c644930a619 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 9 Jan 2019 22:26:25 +1100 Subject: Fix adding new rows instead of updating them if one of the key values is a NULL in upserts. (#4369) --- changelog.d/4369.bugfix | 1 + synapse/storage/_base.py | 10 +++++- tests/storage/test_client_ips.py | 71 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4369.bugfix (limited to 'synapse') diff --git a/changelog.d/4369.bugfix b/changelog.d/4369.bugfix new file mode 100644 index 0000000000..a78d557932 --- /dev/null +++ b/changelog.d/4369.bugfix @@ -0,0 +1 @@ +Prevent users with access tokens predating the introduction of device IDs from creating spurious entries in the user_ips table. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1d3069b143..865b5e915a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -547,11 +547,19 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues) ) sqlargs = list(values.values()) + list(keyvalues.values()) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 4577e9422b..858efe4992 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -62,6 +62,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) + def test_insert_new_client_ip_none_device_id(self): + """ + An insert with a device ID of NULL will not create a new entry, but + update an existing entry in the user_ips table. + """ + self.reactor.advance(12345678) + + user_id = "@user:id" + + # Add & trigger the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", None + ) + ) + self.reactor.advance(200) + self.pump(0) + + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + 'access_token': 'access_token', + 'ip': 'ip', + 'user_agent': 'user_agent', + 'device_id': None, + 'last_seen': 12345678000, + } + ], + ) + + # Add another & trigger the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", None + ) + ) + self.reactor.advance(10) + self.pump(0) + + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + # Only one result, has been upserted. 
+ self.assertEqual( + result, + [ + { + 'access_token': 'access_token', + 'ip': 'ip', + 'user_agent': 'user_agent', + 'device_id': None, + 'last_seen': 12345878000, + } + ], + ) + def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 -- cgit 1.5.1 From a35c66a00bbee0bb6185c4d37914a41c52c83ddf Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 12 Jan 2019 06:21:50 +1100 Subject: Remove duplicates in the user_ips table and add an index (#4370) --- .travis.yml | 3 + changelog.d/4370.misc | 1 + synapse/storage/client_ips.py | 138 ++++++++++++++++++++- synapse/storage/schema/delta/53/user_ips_index.sql | 26 ++++ 4 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 changelog.d/4370.misc create mode 100644 synapse/storage/schema/delta/53/user_ips_index.sql (limited to 'synapse') diff --git a/.travis.yml b/.travis.yml index 84d5efff9b..728f0e248a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,9 @@ cache: # - $HOME/.cache/pip/wheels +addons: + postgresql: "9.4" + # don't clone the whole repo history, one commit will do git: depth: 1 diff --git a/changelog.d/4370.misc b/changelog.d/4370.misc new file mode 100644 index 0000000000..047061ed3c --- /dev/null +++ b/changelog.d/4370.misc @@ -0,0 +1 @@ +Apply a unique index to the user_ips table, preventing duplicates. diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 9ad17b7c25..5d548f250a 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -65,7 +65,27 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["last_seen"], ) - # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) + self.register_background_update_handler( + "user_ips_remove_dupes", + self._remove_user_ip_dupes, + ) + + # Register a unique index + self.register_background_index_update( + "user_ips_device_unique_index", + index_name="user_ips_user_token_ip_unique_index", + table="user_ips", + columns=["user_id", "access_token", "ip"], + unique=True, + ) + + # Drop the old non-unique index + self.register_background_update_handler( + "user_ips_drop_nonunique_index", + self._remove_user_ip_nonunique, + ) + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} self._client_ip_looper = self._clock.looping_call( @@ -75,6 +95,116 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) + @defer.inlineCallbacks + def _remove_user_ip_nonunique(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS user_ips_user_ip" + ) + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update("user_ips_drop_nonunique_index") + defer.returnValue(1) + + @defer.inlineCallbacks + def _remove_user_ip_dupes(self, progress, batch_size): + + last_seen_progress = progress.get("last_seen", 0) + + def get_last_seen(txn): + txn.execute( + """ + SELECT last_seen FROM user_ips + WHERE last_seen > ? + ORDER BY last_seen + LIMIT 1 + OFFSET ? + """, + (last_seen_progress, batch_size) + ) + results = txn.fetchone() + return results + + # Get a last seen that's sufficiently far away enough from the last one + last_seen = yield self.runInteraction( + "user_ips_dups_get_last_seen", get_last_seen + ) + + if not last_seen: + # If we get a None then we're reaching the end and just need to + # delete the last batch. 
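+ # (The fake upper bound set below exceeds any real last_seen value,
+ # so this final pass sweeps up every remaining entry.)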
+ last = True + + # We fake not having an upper bound by using a future date, by + # just multiplying the current time by two.... + last_seen = int(self.clock.time_msec()) * 2 + else: + last = False + last_seen = last_seen[0] + + def remove(txn, last_seen_progress, last_seen): + # This works by looking at all entries in the given time span, and + # then for each (user_id, access_token, ip) tuple in that range + # checking for any duplicates in the rest of the table (via a join). + # It then only returns entries which have duplicates, and the max + # last_seen across all duplicates, which can the be used to delete + # all other duplicates. + # It is efficient due to the existence of (user_id, access_token, + # ip) and (last_seen) indices. + txn.execute( + """ + SELECT user_id, access_token, ip, + MAX(device_id), MAX(user_agent), MAX(last_seen) + FROM ( + SELECT user_id, access_token, ip + FROM user_ips + WHERE ? <= last_seen AND last_seen < ? + ORDER BY last_seen + ) c + INNER JOIN user_ips USING (user_id, access_token, ip) + GROUP BY user_id, access_token, ip + HAVING count(*) > 1""", + (last_seen_progress, last_seen) + ) + res = txn.fetchall() + + # We've got some duplicates + for i in res: + user_id, access_token, ip, device_id, user_agent, last_seen = i + + # Drop all the duplicates + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? + """, + (user_id, access_token, ip) + ) + + # Add in one to be the last_seen + txn.execute( + """ + INSERT INTO user_ips + (user_id, access_token, ip, device_id, user_agent, last_seen) + VALUES (?, ?, ?, ?, ?, ?) + """, + (user_id, access_token, ip, device_id, user_agent, last_seen) + ) + + self._background_update_progress_txn( + txn, "user_ips_remove_dupes", {"last_seen": last_seen} + ) + + yield self.runInteraction( + "user_ips_dups_remove", remove, last_seen_progress, last_seen + ) + if last: + yield self._end_background_update("user_ips_remove_dupes") + + defer.returnValue(batch_size) + @defer.inlineCallbacks def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, now=None): @@ -127,10 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "user_id": user_id, "access_token": access_token, "ip": ip, - "user_agent": user_agent, - "device_id": device_id, }, values={ + "user_agent": user_agent, + "device_id": device_id, "last_seen": last_seen, }, lock=False, @@ -227,7 +357,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): results = {} for key in self._batch_row_update: - uid, access_token, ip = key + uid, access_token, ip, = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/schema/delta/53/user_ips_index.sql new file mode 100644 index 0000000000..4ca346c111 --- /dev/null +++ b/synapse/storage/schema/delta/53/user_ips_index.sql @@ -0,0 +1,26 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- delete duplicates +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_remove_dupes', '{}'); + +-- add a new unique index to user_ips table +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); + +-- drop the old original index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); \ No newline at end of file -- cgit 1.5.1 From bb63e7ca4f989d980de9925adcf14f0393ac0406 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 15 Jan 2019 11:14:34 +0000 Subject: Add groundwork for new versions of federation APIs --- synapse/api/urls.py | 3 +- synapse/federation/transport/client.py | 132 +++++++++++++++++---------------- synapse/federation/transport/server.py | 6 +- 3 files changed, 73 insertions(+), 68 deletions(-) (limited to 'synapse') diff --git a/synapse/api/urls.py b/synapse/api/urls.py index f78695b657..ba019001ff 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -24,7 +24,8 @@ from synapse.config import ConfigError CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" -FEDERATION_PREFIX = "/_matrix/federation/v1" +FEDERATION_PREFIX = "/_matrix/federation" +FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index edba5a9808..260178c47b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -21,7 +21,7 @@ from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership -from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.util.logutils import log_function logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = _create_path(PREFIX, "/state/%s/", room_id) + path = _create_v1_path("/state/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -73,7 +73,7 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = _create_path(PREFIX, "/state_ids/%s/", room_id) + path = _create_v1_path("/state_ids/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -95,7 +95,7 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = _create_path(PREFIX, "/event/%s/", event_id) + path = _create_v1_path("/event/%s/", event_id) return self.client.get_json(destination, path=path, timeout=timeout) @log_function @@ -121,7 +121,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = _create_path(PREFIX, "/backfill/%s/", room_id) + path = _create_v1_path("/backfill/%s/", room_id) args = { "v": event_tuples, @@ -167,7 +167,7 @@ class TransportLayerClient(object): # generated by the json_data_callback. 
json_data = transaction.get_dict() - path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id) + path = _create_v1_path("/send/%s/", transaction.transaction_id) response = yield self.client.put_json( transaction.destination, @@ -184,7 +184,7 @@ class TransportLayerClient(object): @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False): - path = _create_path(PREFIX, "/query/%s", query_type) + path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( destination=destination, @@ -231,7 +231,7 @@ class TransportLayerClient(object): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id) + path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -258,7 +258,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_join(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id) + path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -271,7 +271,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id) + path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -290,7 +290,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id) + path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -306,7 +306,7 @@ class TransportLayerClient(object): def get_public_rooms(self, remote_server, limit, since_token, search_filter=None, include_all_networks=False, third_party_instance_id=None): - path = PREFIX + "/publicRooms" + path = _create_v1_path("/publicRooms") args = { "include_all_networks": "true" if include_all_networks else "false", @@ -332,7 +332,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) response = yield self.client.put_json( destination=destination, @@ -345,7 +345,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id) + path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) content = yield self.client.get_json( destination=destination, @@ -357,7 +357,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_query_auth(self, destination, room_id, event_id, content): - path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id) + path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( destination=destination, @@ -392,7 +392,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. 
""" - path = PREFIX + "/user/keys/query" + path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( destination=destination, @@ -419,7 +419,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = _create_path(PREFIX, "/user/devices/%s", user_id) + path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( destination=destination, @@ -455,7 +455,7 @@ class TransportLayerClient(object): A dict containg the one-time keys. """ - path = PREFIX + "/user/keys/claim" + path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( destination=destination, @@ -469,7 +469,7 @@ class TransportLayerClient(object): @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth, timeout): - path = _create_path(PREFIX, "/get_missing_events/%s", room_id,) + path = _create_v1_path("/get_missing_events/%s", room_id,) content = yield self.client.post_json( destination=destination, @@ -489,7 +489,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id,) return self.client.get_json( destination=destination, @@ -508,7 +508,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id,) return self.client.post_json( destination=destination, @@ -522,7 +522,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_path(PREFIX, "/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id,) return self.client.get_json( destination=destination, @@ -535,7 +535,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_path(PREFIX, "/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id,) return self.client.get_json( destination=destination, @@ -548,7 +548,7 @@ class TransportLayerClient(object): content): """Add a room to a group """ - path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -562,8 +562,8 @@ class TransportLayerClient(object): config_key, content): """Update room in group """ - path = _create_path( - PREFIX, "/groups/%s/room/%s/config/%s", + path = _create_v1_path( + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key, ) @@ -578,7 +578,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -591,7 +591,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_path(PREFIX, "/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id,) return 
self.client.get_json( destination=destination, @@ -604,7 +604,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id,) return self.client.get_json( destination=destination, @@ -617,8 +617,8 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_path( - PREFIX, "/groups/%s/users/%s/accept_invite", + path = _create_v1_path( + "/groups/%s/users/%s/accept_invite", group_id, user_id, ) @@ -633,7 +633,7 @@ class TransportLayerClient(object): def join_group(self, destination, group_id, user_id, content): """Attempts to join a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( destination=destination, @@ -646,7 +646,7 @@ class TransportLayerClient(object): def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): """Invite a user to a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -662,7 +662,7 @@ class TransportLayerClient(object): invited. """ - path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id) + path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -676,7 +676,7 @@ class TransportLayerClient(object): user_id, content): """Remove a user fron a group """ - path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id) + path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -693,7 +693,7 @@ class TransportLayerClient(object): kicked from the group. 
""" - path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id) + path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -708,7 +708,7 @@ class TransportLayerClient(object): the attestations """ - path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id) + path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -723,12 +723,12 @@ class TransportLayerClient(object): """Update a room entry in a group summary """ if category_id: - path = _create_path( - PREFIX, "/groups/%s/summary/categories/%s/rooms/%s", + path = _create_v1_path( + "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -744,12 +744,12 @@ class TransportLayerClient(object): """Delete a room entry in a group summary """ if category_id: - path = _create_path( - PREFIX + "/groups/%s/summary/categories/%s/rooms/%s", + path = _create_v1_path( + "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -762,7 +762,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_path(PREFIX, "/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id,) return self.client.get_json( destination=destination, @@ -775,7 +775,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.get_json( destination=destination, @@ -789,7 +789,7 @@ class TransportLayerClient(object): content): """Update a category in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.post_json( destination=destination, @@ -804,7 +804,7 @@ class TransportLayerClient(object): category_id): """Delete a category in a group """ - path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) return self.client.delete_json( destination=destination, @@ -817,7 +817,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_path(PREFIX, "/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id,) return self.client.get_json( destination=destination, @@ -830,7 +830,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", 
group_id, role_id,) return self.client.get_json( destination=destination, @@ -844,7 +844,7 @@ class TransportLayerClient(object): content): """Update a role in a group """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) return self.client.post_json( destination=destination, @@ -858,7 +858,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) return self.client.delete_json( destination=destination, @@ -873,12 +873,12 @@ class TransportLayerClient(object): """Update a users entry in a group """ if role_id: - path = _create_path( - PREFIX, "/groups/%s/summary/roles/%s/users/%s", + path = _create_v1_path( + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) return self.client.post_json( destination=destination, @@ -893,7 +893,7 @@ class TransportLayerClient(object): content): """Sets the join policy for a group """ - path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) return self.client.put_json( destination=destination, @@ -909,12 +909,12 @@ class TransportLayerClient(object): """Delete a users entry in a group """ if role_id: - path = _create_path( - PREFIX, "/groups/%s/summary/roles/%s/users/%s", + path = _create_v1_path( + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) return self.client.delete_json( destination=destination, @@ -927,7 +927,7 @@ class TransportLayerClient(object): """Get the groups a list of users are publicising """ - path = PREFIX + "/get_groups_publicised" + path = _create_v1_path("/get_groups_publicised") content = {"user_ids": user_ids} @@ -939,20 +939,22 @@ class TransportLayerClient(object): ) -def _create_path(prefix, path, *args): - """Creates a path from the prefix, path template and args. Ensures that - all args are url encoded. +def _create_v1_path(path, *args): + """Creates a path against V1 federation API from the path template and + args. Ensures that all args are url encoded. Example: - _create_path(PREFIX, "/event/%s/", event_id) + _create_v1_path("/event/%s/", event_id) Args: - prefix (str) path (str): String template for the path args: ([str]): Args to insert into path. 
Each arg will be url encoded Returns: str """ - return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return ( + FEDERATION_V1_PREFIX + + path % tuple(urllib.parse.quote(arg, "") for arg in args) + ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 3553f418f1..e592269cf4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -22,7 +22,7 @@ from twisted.internet import defer import synapse from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -227,6 +227,8 @@ class BaseFederationServlet(object): """ REQUIRE_AUTH = True + PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator @@ -286,7 +288,7 @@ class BaseFederationServlet(object): return new_func def register(self, server): - pattern = re.compile("^" + PREFIX + self.PATH + "$") + pattern = re.compile("^" + self.PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): code = getattr(self, "on_%s" % (method), None) -- cgit 1.5.1 From 4a4d2e17bc28562cc6d0de55bc8c8c05335414cd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 15 Jan 2019 13:22:44 +0000 Subject: Add /v2/invite federation API --- synapse/api/urls.py | 1 + synapse/federation/federation_server.py | 4 ++-- synapse/federation/transport/server.py | 42 +++++++++++++++++++++++++++++---- 3 files changed, 41 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/api/urls.py b/synapse/api/urls.py index ba019001ff..8102176653 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -26,6 +26,7 @@ CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" FEDERATION_PREFIX = "/_matrix/federation" FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" +FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 98722ae543..37d29e7027 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -369,13 +369,13 @@ class FederationServer(FederationBase): }) @defer.inlineCallbacks - def on_invite_request(self, origin, content): + def on_invite_request(self, origin, content, room_version): pdu = event_from_pdu_json(content) origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) + defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)}) @defer.inlineCallbacks def on_send_join_request(self, origin, content): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index e592269cf4..4557a9e66e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -21,8 +21,9 @@ import re from twisted.internet import defer import synapse +from synapse.api.constants import RoomVersions from 
synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.urls import FEDERATION_V1_PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -490,14 +491,46 @@ class FederationSendJoinServlet(BaseFederationServlet): defer.returnValue((200, content)) -class FederationInviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServlet): PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): + # We don't get a room version, so we have to assume its EITHER v1 or + # v2. This is "fine" as the only difference between V1 and V2 is the + # state resolution algorithm, and we don't use that for processing + # invites + content = yield self.handler.on_invite_request( + origin, content, room_version=RoomVersions.V1, + ) + + # V1 federation API is defined to return a content of `[200, {...}]` + # due to a historical bug. + defer.returnValue((200, (200, content))) + + +class FederationV2InviteServlet(BaseFederationServlet): + PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content - content = yield self.handler.on_invite_request(origin, content) + + room_version = content["room_version"] + event = content["event"] + invite_room_state = content["invite_room_state"] + + # Synapse expects invite_room_state to be in unsigned, as it is in v1 + # API + + event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state + + content = yield self.handler.on_invite_request( + origin, event, room_version=room_version, ) defer.returnValue((200, content)) @@ -1265,7 +1298,8 @@ FEDERATION_SERVLET_CLASSES = ( FederationEventServlet, FederationSendJoinServlet, FederationSendLeaveServlet, - FederationInviteServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, -- cgit 1.5.1 From 9ec56d693577232dbb9f75f58a44c01a999fbea9 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 15 Jan 2019 14:38:15 +0000 Subject: ALL_USER_TYPES should be a tuple --- changelog.d/4392.bugfix | 1 + synapse/api/constants.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4392.bugfix (limited to 'synapse') diff --git a/changelog.d/4392.bugfix b/changelog.d/4392.bugfix new file mode 100644 index 0000000000..2223f7dcd6 --- /dev/null +++ b/changelog.d/4392.bugfix @@ -0,0 +1 @@ +Fix typo in ALL_USER_TYPES definition to ensure type is a tuple diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 87bc1cb53d..022f772714 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -128,4 +128,4 @@ class UserTypes(object): 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ SUPPORT = "support" - ALL_USER_TYPES = (SUPPORT) + ALL_USER_TYPES = (SUPPORT,) -- cgit 1.5.1 From fab948120f88e86df73d6bdd46b22bc5bc8583f9 Mon Sep 17 00:00:00 2001 From: Andrej Shadura Date: Wed, 16 Jan 2019 09:41:29 +0100 Subject: Use msgpack instead of msgpack-python The package msgpack-python has been deprecated.
Signed-off-by: Andrej Shadura --- synapse/python_dependencies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 69c5f9fe2e..c7dd4917d8 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -53,7 +53,7 @@ REQUIREMENTS = [ "sortedcontainers>=1.4.4", "psutil>=2.0.0", "pymacaroons-pynacl>=0.9.3", - "msgpack-python>=0.4.2", + "msgpack>=0.5.0", "phonenumbers>=8.2.0", "six>=1.10", # prometheus_client 0.4.0 changed the format of counter metrics -- cgit 1.5.1 From 64cf6788d91dbd2a1a3b3613881e9c5609dc10af Mon Sep 17 00:00:00 2001 From: Andrej Shadura Date: Wed, 16 Jan 2019 09:44:22 +0100 Subject: Depend on pymacaroons >= 0.13.0 instead on pymacaroons-pynacl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since 0.13.0, pymacaroons works correctly with pynacl, so there isn’t any more reason to depend on an outdated pynacl fork. Signed-off-by: Andrej Shadura --- synapse/python_dependencies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index c7dd4917d8..2fe790f95d 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -52,7 +52,7 @@ REQUIREMENTS = [ "pillow>=3.1.2", "sortedcontainers>=1.4.4", "psutil>=2.0.0", - "pymacaroons-pynacl>=0.9.3", + "pymacaroons>=0.13.0", "msgpack>=0.5.0", "phonenumbers>=8.2.0", "six>=1.10", -- cgit 1.5.1 From 05e129664931c114fcaae8bebe0a26685dcd9c6d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 16 Jan 2019 23:14:11 +0000 Subject: don't store more remote device lists if they have more than 1K devices (#4397) --- changelog.d/4397.bugfix | 1 + synapse/handlers/device.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 changelog.d/4397.bugfix (limited to 'synapse') diff --git a/changelog.d/4397.bugfix b/changelog.d/4397.bugfix new file mode 100644 index 0000000000..e7526d4454 --- /dev/null +++ b/changelog.d/4397.bugfix @@ -0,0 +1 @@ +Fix high CPU usage due to remote devicelist updates diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9e017116a9..8955cde4ed 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -532,6 +532,25 @@ class DeviceListEduUpdater(object): stream_id = result["stream_id"] devices = result["devices"] + + # If the remote server has more than ~1000 devices for this user + # we assume that something is going horribly wrong (e.g. a bot + # that logs in and creates a new device every time it tries to + # send a message). Maintaining lots of devices per user in the + # cache can cause serious performance issues as if this request + # takes more than 60s to complete, internal replication from the + # inbound federation worker to the synapse master may time out + # causing the inbound federation to fail and causing the remote + # server to retry, causing a DoS. So in this scenario we give + # up on storing the total list of devices and only handle the + # delta instead. 
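+ # (The 1000-device cut-off below is a heuristic sanity bound for
+ # this mitigation, not a limit imposed by the federation protocol.)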
+ if len(devices) > 1000: + logger.warn( + "Ignoring device list snapshot for %s as it has >1K devs (%d)", + user_id, len(devices) + ) + devices = [] + yield self.store.update_remote_device_list_cache( user_id, devices, stream_id, ) -- cgit 1.5.1 From 3982a6ee078bdec3b25ab7850161cb0cb7157590 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 16 Jan 2019 23:14:41 +0000 Subject: Changing macaroon_secret_key no longer logs you out (#4387) --- changelog.d/4387.misc | 1 + synapse/config/key.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) create mode 100644 changelog.d/4387.misc (limited to 'synapse') diff --git a/changelog.d/4387.misc b/changelog.d/4387.misc new file mode 100644 index 0000000000..0c04a0fa9b --- /dev/null +++ b/changelog.d/4387.misc @@ -0,0 +1 @@ +Fix a comment in the generated config file diff --git a/synapse/config/key.py b/synapse/config/key.py index 3b11f0cfa9..dce4b19a2d 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -83,9 +83,6 @@ class KeyConfig(Config): # a secret which is used to sign access tokens. If none is specified, # the registration_shared_secret is used, if one is given; otherwise, # a secret key is derived from the signing key. - # - # Note that changing this will invalidate any active access tokens, so - # all clients will have to log back in. %(macaroon_secret_key)s # Used to enable access token expiration. -- cgit 1.5.1 From 9feb5d0b71104bea4e366d451d5dddd447e16196 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Jan 2019 12:40:09 +0000 Subject: sign_request -> build_auth_headers (#4408) Just got very confused about the fact that the headers are only an output, not an input. --- changelog.d/4408.misc | 1 + synapse/handlers/identity.py | 9 ++++++--- synapse/http/matrixfederationclient.py | 23 +++++++++++------------ 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 changelog.d/4408.misc (limited to 'synapse') diff --git a/changelog.d/4408.misc b/changelog.d/4408.misc new file mode 100644 index 0000000000..729bafd62e --- /dev/null +++ b/changelog.d/4408.misc @@ -0,0 +1 @@ +Refactor 'sign_request' as 'build_auth_headers' \ No newline at end of file diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5feb3f22a6..39184f0e22 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler): "mxid": mxid, "threepid": threepid, } - headers = {} + # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. 
- self.federation_http_client.sign_request( + auth_headers = self.federation_http_client.build_auth_headers( destination=None, method='POST', url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), - headers_dict=headers, content=content, destination_is=id_server, ) + headers = { + b"Authorization": auth_headers, + } + try: yield self.http_client.post_json_get_json( url, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index f2a42f97a6..ea2fc64b99 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -298,9 +298,9 @@ class MatrixFederationHttpClient(object): json = request.get_json() if json: headers_dict[b"Content-Type"] = [b"application/json"] - self.sign_request( + auth_headers = self.build_auth_headers( destination_bytes, method_bytes, url_to_sign_bytes, - headers_dict, json, + json, ) data = encode_canonical_json(json) producer = FileBodyProducer( @@ -309,11 +309,12 @@ class MatrixFederationHttpClient(object): ) else: producer = None - self.sign_request( + auth_headers = self.build_auth_headers( destination_bytes, method_bytes, url_to_sign_bytes, - headers_dict, ) + headers_dict[b"Authorization"] = auth_headers + logger.info( "{%s} [%s] Sending request: %s %s", request.txn_id, request.destination, request.method, @@ -440,24 +441,23 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) - def sign_request(self, destination, method, url_bytes, headers_dict, - content=None, destination_is=None): + def build_auth_headers( + self, destination, method, url_bytes, content=None, destination_is=None, + ): """ - Signs a request by adding an Authorization header to headers_dict + Builds the Authorization headers for a federation request Args: destination (bytes|None): The desination home server of the request. May be None if the destination is an identity server, in which case destination_is must be non-None. 
method (bytes): The HTTP method of the request url_bytes (bytes): The URI path of the request - headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to - append to content (object): The body of the request destination_is (bytes): As 'destination', but if the destination is an identity server Returns: - None + list[bytes]: a list of headers to be added as "Authorization:" headers """ request = { "method": method, @@ -484,8 +484,7 @@ class MatrixFederationHttpClient(object): self.server_name, key, sig, )).encode('ascii') ) - - headers_dict[b"Authorization"] = auth_headers + return auth_headers @defer.inlineCallbacks def put_json(self, destination, path, args={}, data={}, -- cgit 1.5.1 From 676cf2ee267cedd05109678dece0227fd1801423 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Jan 2019 14:00:23 +0000 Subject: Fix incorrect logcontexts after a Deferred was cancelled (#4407) --- changelog.d/4407.bugfix | 1 + synapse/python_dependencies.py | 11 ++++- synapse/util/async_helpers.py | 4 +- tests/util/test_async_utils.py | 104 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 changelog.d/4407.bugfix create mode 100644 tests/util/test_async_utils.py (limited to 'synapse') diff --git a/changelog.d/4407.bugfix b/changelog.d/4407.bugfix new file mode 100644 index 0000000000..54c5e76d1f --- /dev/null +++ b/changelog.d/4407.bugfix @@ -0,0 +1 @@ +Fix incorrect logcontexts after a Deferred was cancelled diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2fe790f95d..882e844eb1 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -40,7 +40,11 @@ REQUIREMENTS = [ "signedjson>=1.0.0", "pynacl>=1.2.1", "service_identity>=16.0.0", - "Twisted>=17.1.0", + + # our logcontext handling relies on the ability to cancel inlineCallbacks + # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. + "Twisted>=18.7.0", + "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", @@ -59,8 +63,11 @@ REQUIREMENTS = [ # prometheus_client 0.4.0 changed the format of counter metrics # (cf https://github.com/matrix-org/synapse/issues/4001) "prometheus_client>=0.0.18,<0.4.0", + # we use attr.s(slots), which arrived in 16.0.0 - "attrs>=16.0.0", + # Twisted 18.7.0 requires attrs>=17.4.0 + "attrs>=17.4.0", + "netaddr>=0.7.18", ] diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index ec7b2c9672..430bb15f51 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -387,12 +387,14 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred that wraps and times out the given deferred, correctly handling the case where the given deferred's canceller throws. + (See https://twistedmatrix.com/trac/ticket/9534) + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: deferred (Deferred) timeout (float): Timeout in seconds - reactor (twisted.internet.reactor): The twisted reactor to use + reactor (twisted.interfaces.IReactorTime): The twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. 
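For illustration, a minimal sketch of how a caller might combine `timeout_deferred` with the logcontext rules exercised by the new tests below — `slow_operation` is a hypothetical stand-in for any function returning a Deferred, not part of this patch:

    from twisted.internet import defer, reactor

    from synapse.util.async_helpers import timeout_deferred
    from synapse.util.logcontext import make_deferred_yieldable

    @defer.inlineCallbacks
    def fetch_with_timeout():
        d = slow_operation()  # hypothetical; returns a Deferred
        # timeout_deferred wraps rather than mutates: it returns a new
        # deferred, and copes with cancellers that throw
        # (https://twistedmatrix.com/trac/ticket/9534).
        d = timeout_deferred(d, 10.0, reactor)
        # Waiting on a deferred from inside a logcontext goes via
        # make_deferred_yieldable, so the context is restored on resume.
        result = yield make_deferred_yieldable(d)
        defer.returnValue(result)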
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py new file mode 100644 index 0000000000..84dd71e47a --- /dev/null +++ b/tests/util/test_async_utils.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.task import Clock + +from synapse.util import logcontext +from synapse.util.async_helpers import timeout_deferred +from synapse.util.logcontext import LoggingContext + +from tests.unittest import TestCase + + +class TimeoutDeferredTest(TestCase): + def setUp(self): + self.clock = Clock() + + def test_times_out(self): + """Basic test case that checks that the original deferred is cancelled and that + the timing-out deferred is errbacked + """ + cancelled = [False] + + def canceller(_d): + cancelled[0] = True + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + self.assertFalse(cancelled[0], "deferred was cancelled prematurely") + + self.clock.pump((1.0, )) + + self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + + def test_times_out_when_canceller_throws(self): + """Test that we have successfully worked around + https://twistedmatrix.com/trac/ticket/9534""" + + def canceller(_d): + raise Exception("can't cancel this deferred") + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + + self.clock.pump((1.0, )) + + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + + def test_logcontext_is_preserved_on_cancellation(self): + blocking_was_cancelled = [False] + + @defer.inlineCallbacks + def blocking(): + non_completing_d = Deferred() + with logcontext.PreserveLoggingContext(): + try: + yield non_completing_d + except CancelledError: + blocking_was_cancelled[0] = True + raise + + with logcontext.LoggingContext("one") as context_one: + # the errbacks should be run in the test logcontext + def errback(res, deferred_name): + self.assertIs( + LoggingContext.current_context(), context_one, + "errback %s run in unexpected logcontext %s" % ( + deferred_name, LoggingContext.current_context(), + ) + ) + return res + + original_deferred = blocking() + original_deferred.addErrback(errback, "orig") + timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) + self.assertNoResult(timing_out_d) + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + timing_out_d.addErrback(errback, "timingout") + + self.clock.pump((1.0, )) + + self.assertTrue( + blocking_was_cancelled[0], + "non-completing deferred was not cancelled", + ) + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.assertIs(LoggingContext.current_context(), context_one) -- cgit 1.5.1 From 
de6888e7ce7ec58e8b7935e35bb6572504cae873 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 18 Jan 2019 12:07:38 +0000 Subject: Remove redundant WrappedConnection (#4409) * Remove redundant WrappedConnection The matrix federation client uses an HTTP connection pool, which times out its idle HTTP connections, so there is no need for any of this business. --- changelog.d/4409.misc | 1 + synapse/http/endpoint.py | 75 ++-------------------------------- synapse/http/matrixfederationclient.py | 30 +++++++------- tests/http/test_fedclient.py | 54 ++++++++++++++++++++---- 4 files changed, 67 insertions(+), 93 deletions(-) create mode 100644 changelog.d/4409.misc (limited to 'synapse') diff --git a/changelog.d/4409.misc b/changelog.d/4409.misc new file mode 100644 index 0000000000..9cf2adfbb1 --- /dev/null +++ b/changelog.d/4409.misc @@ -0,0 +1 @@ +Remove redundant federation connection wrapping code diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index f86a0b624e..1c3b7ea28a 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -140,82 +140,15 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= default_port = 8448 if port is None: - return _WrappingEndpointFac(SRVClientEndpoint( + return SRVClientEndpoint( reactor, "matrix", domain, protocol="tcp", default_port=default_port, endpoint=transport_endpoint, endpoint_kw_args=endpoint_kw_args - ), reactor) + ) else: - return _WrappingEndpointFac(transport_endpoint( + return transport_endpoint( reactor, domain, port, **endpoint_kw_args - ), reactor) - - -class _WrappingEndpointFac(object): - def __init__(self, endpoint_fac, reactor): - self.endpoint_fac = endpoint_fac - self.reactor = reactor - - @defer.inlineCallbacks - def connect(self, protocolFactory): - conn = yield self.endpoint_fac.connect(protocolFactory) - conn = _WrappedConnection(conn, self.reactor) - defer.returnValue(conn) - - -class _WrappedConnection(object): - """Wraps a connection and calls abort on it if it hasn't seen any action - for 2.5-3 minutes. - """ - __slots__ = ["conn", "last_request"] - - def __init__(self, conn, reactor): - object.__setattr__(self, "conn", conn) - object.__setattr__(self, "last_request", time.time()) - self._reactor = reactor - - def __getattr__(self, name): - return getattr(self.conn, name) - - def __setattr__(self, name, value): - setattr(self.conn, name, value) - - def _time_things_out_maybe(self): - # We use a slightly shorter timeout here just in case the callLater is - # triggered early. Paranoia ftw. - # TODO: Cancel the previous callLater rather than comparing time.time()? - if time.time() - self.last_request >= 2.5 * 60: - self.abort() - # Abort the underlying TLS connection. The abort() method calls - # loseConnection() on the TLS connection which tries to - # shutdown the connection cleanly. We call abortConnection() - # since that will promptly close the TLS connection. - # - # In Twisted >18.4; the TLS connection will be None if it has closed - # which will make abortConnection() throw. Check that the TLS connection - # is not None before trying to close it. - if self.transport.getHandle() is not None: - self.transport.abortConnection() - - def request(self, request): - self.last_request = time.time() - - # Time this connection out if we haven't send a request in the last - # N minutes - # TODO: Cancel the previous callLater? 
- self._reactor.callLater(3 * 60, self._time_things_out_maybe) - - d = self.conn.request(request) - - def update_request_time(res): - self.last_request = time.time() - # TODO: Cancel the previous callLater? - self._reactor.callLater(3 * 60, self._time_things_out_maybe) - return res - - d.addCallback(update_request_time) - - return d + ) class SRVClientEndpoint(object): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ea2fc64b99..250bb1ef91 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -321,23 +321,23 @@ class MatrixFederationHttpClient(object): url_str, ) - # we don't want all the fancy cookie and redirect handling that - # treq.request gives: just use the raw Agent. - request_deferred = self.agent.request( - method_bytes, - url_bytes, - headers=Headers(headers_dict), - bodyProducer=producer, - ) - - request_deferred = timeout_deferred( - request_deferred, - timeout=_sec_timeout, - reactor=self.hs.get_reactor(), - ) - try: with Measure(self.clock, "outbound_request"): + # we don't want all the fancy cookie and redirect handling + # that treq.request gives: just use the raw Agent. + request_deferred = self.agent.request( + method_bytes, + url_bytes, + headers=Headers(headers_dict), + bodyProducer=producer, + ) + + request_deferred = timeout_deferred( + request_deferred, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + response = yield make_deferred_yieldable( request_deferred, ) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index b2e38276d8..8426eee400 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -17,6 +17,7 @@ from mock import Mock from twisted.internet.defer import TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError +from twisted.test.proto_helpers import StringTransport from twisted.web.client import ResponseNeverReceived from twisted.web.http import HTTPChannel @@ -44,7 +45,7 @@ class FederationClientTests(HomeserverTestCase): def test_dns_error(self): """ - If the DNS raising returns an error, it will bubble up. + If the DNS lookup returns an error, it will bubble up. 
""" d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) self.pump() @@ -63,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.pump() # Nothing happened yet - self.assertFalse(d.called) + self.assertNoResult(d) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -72,7 +73,7 @@ class FederationClientTests(HomeserverTestCase): self.assertEqual(clients[0][1], 8008) # Deferred is still without a result - self.assertFalse(d.called) + self.assertNoResult(d) # Push by enough to time it out self.reactor.advance(10.5) @@ -94,7 +95,7 @@ class FederationClientTests(HomeserverTestCase): self.pump() # Nothing happened yet - self.assertFalse(d.called) + self.assertNoResult(d) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -107,7 +108,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred is still without a result - self.assertFalse(d.called) + self.assertNoResult(d) # Push by enough to time it out self.reactor.advance(10.5) @@ -135,7 +136,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred does not have a result - self.assertFalse(d.called) + self.assertNoResult(d) # Send it the HTTP response client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n") @@ -159,7 +160,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred does not have a result - self.assertFalse(d.called) + self.assertNoResult(d) # Send it the HTTP response client.dataReceived( @@ -195,3 +196,42 @@ class FederationClientTests(HomeserverTestCase): request = server.requests[0] content = request.content.read() self.assertEqual(content, b'{"a":"b"}') + + def test_closes_connection(self): + """Check that the client closes unused HTTP connections""" + d = self.cl.get_json("testserv:8008", "foo/bar") + + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + self.assertFalse(conn.disconnecting) + + # wait for a while + self.pump(120) + + self.assertTrue(conn.disconnecting) -- cgit 1.5.1 From 25dd56ace36e9d3dbb717a34c9b9051b243f84ad Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Jan 2019 12:17:04 +0000 Subject: Fix race when persisting create event (#4404) * Fix race when persisting create event When persisting a chunk of DAG it is sometimes requried to do a state resolution, which requires knowledge of the room version. If this happens while we're persisting the create event then we need to use that event rather than attempting to look it up in the database. 
--- changelog.d/4404.bugfix | 1 + synapse/storage/events.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4404.bugfix (limited to 'synapse') diff --git a/changelog.d/4404.bugfix b/changelog.d/4404.bugfix new file mode 100644 index 0000000000..5d40a3a60b --- /dev/null +++ b/changelog.d/4404.bugfix @@ -0,0 +1 @@ +Fix potential bug where creating or joining a room could fail diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2047110b1d..79e0276de6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -739,7 +739,18 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore } events_map = {ev.event_id: ev for ev, _ in events_context} - room_version = yield self.get_room_version(room_id) + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.get_room_version(room_id) logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( -- cgit 1.5.1 From 702c4b750c4db529d3789a899aa9badfa8c9df6e Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 21 Jan 2019 09:42:59 +0000 Subject: Migrate encryption state on room upgrade (#4411) * Migrate encryption state on room upgrade Signed-off-by: Andrew Morgan * Add changelog file --- changelog.d/4411.bugfix | 1 + synapse/api/constants.py | 1 + synapse/handlers/room.py | 1 + 3 files changed, 3 insertions(+) create mode 100644 changelog.d/4411.bugfix (limited to 'synapse') diff --git a/changelog.d/4411.bugfix b/changelog.d/4411.bugfix new file mode 100644 index 0000000000..219e98a924 --- /dev/null +++ b/changelog.d/4411.bugfix @@ -0,0 +1 @@ +Ensure encrypted room state is persisted across room upgrades. 
\ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 022f772714..46c4b4b9dc 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -68,6 +68,7 @@ class EventTypes(object): Aliases = "m.room.aliases" Redaction = "m.room.redaction" ThirdPartyInvite = "m.room.third_party_invite" + Encryption = "m.room.encryption" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 581e96c743..cb8c5f77dd 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -269,6 +269,7 @@ class RoomCreationHandler(BaseHandler): (EventTypes.RoomHistoryVisibility, ""), (EventTypes.GuestAccess, ""), (EventTypes.RoomAvatar, ""), + (EventTypes.Encryption, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( -- cgit 1.5.1 From 534926230257ef65b683f439dd71156c359d5004 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 21 Jan 2019 14:59:37 +0000 Subject: Config option to disable requesting MSISDN on registration --- changelog.d/4423.feature | 1 + synapse/config/registration.py | 7 +++++++ synapse/rest/client/v2_alpha/register.py | 16 +++++----------- 3 files changed, 13 insertions(+), 11 deletions(-) create mode 100644 changelog.d/4423.feature (limited to 'synapse') diff --git a/changelog.d/4423.feature b/changelog.d/4423.feature new file mode 100644 index 0000000000..74aeab6d39 --- /dev/null +++ b/changelog.d/4423.feature @@ -0,0 +1 @@ +Config option to disable requesting MSISDN on registration. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 6c2b543b8c..e725773735 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -50,6 +50,8 @@ class RegistrationConfig(Config): raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.disable_msisdn_registration = config.get("disable_msisdn_registration", False) + def default_config(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -70,6 +72,11 @@ class RegistrationConfig(Config): # - email # - msisdn + # Explicitly disable asking for MSISDNs from the registration + # flow (overrides registrations_require_3pid if MSISDNs are set as required) + # + # disable_msisdn_registration = True + # Mandate that users are only allowed to associate certain formats of # 3PIDs with accounts on this server. # diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index aec0c6b075..14025cd219 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - # Only give msisdn flows if the x_show_msisdn flag is given: - # this is a hack to work around the fact that clients were shipped - # that use fallback registration if they see any flows that they don't - # recognise, which means we break registration for these clients if we - # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot - # Android <=0.6.9 have fallen below an acceptable threshold, this - # parameter should go away and we should always advertise msisdn flows. 
- show_msisdn = False - if 'x_show_msisdn' in body and body['x_show_msisdn']: - show_msisdn = True - # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one require_email = 'email' in self.hs.config.registrations_require_3pid require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + show_msisdn = True + if self.hs.config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + flows = [] if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required -- cgit 1.5.1 From 1b53cc3cb4372d624194360b6039c5f7d05a2007 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 21 Jan 2019 15:17:20 +0000 Subject: fix line length --- synapse/config/registration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index e725773735..fe520d6855 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -50,7 +50,9 @@ class RegistrationConfig(Config): raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_msisdn_registration = config.get("disable_msisdn_registration", False) + self.disable_msisdn_registration = ( + config.get("disable_msisdn_registration", False) + ) def default_config(self, generate_secrets=False, **kwargs): if generate_secrets: -- cgit 1.5.1 From 23b08135998e932d5d600941bd42389db0628a11 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 22 Jan 2019 21:58:50 +1100 Subject: Require ECDH key exchange & remove dh_params (#4429) * remove dh_params and set better cipher string --- README.rst | 2 +- changelog.d/4229.feature | 1 + debian/homeserver.yaml | 3 --- demo/demo.tls.dh | 9 --------- docker/conf/homeserver.yaml | 1 - synapse/config/tls.py | 40 --------------------------------------- synapse/crypto/context_factory.py | 6 ++++-- tests/config/test_generate.py | 1 - 8 files changed, 6 insertions(+), 57 deletions(-) create mode 100644 changelog.d/4229.feature delete mode 100644 demo/demo.tls.dh (limited to 'synapse') diff --git a/README.rst b/README.rst index 8bff55e78e..05a3bb3751 100644 --- a/README.rst +++ b/README.rst @@ -220,7 +220,7 @@ is configured to use TLS with a self-signed certificate. If you would like to do initial test with a client without having to setup a reverse proxy, you can temporarly use another certificate. (Note that a self-signed certificate is fine for `Federation`_). You can do so by changing -``tls_certificate_path``, ``tls_private_key_path`` and ``tls_dh_params_path`` +``tls_certificate_path`` and ``tls_private_key_path`` in ``homeserver.yaml``; alternatively, you can use a reverse-proxy, but be sure to read `Using a reverse proxy with Synapse`_ when doing so. diff --git a/changelog.d/4229.feature b/changelog.d/4229.feature new file mode 100644 index 0000000000..0d1996c7e8 --- /dev/null +++ b/changelog.d/4229.feature @@ -0,0 +1 @@ +Synapse's cipher string has been updated to require ECDH key exchange. Configuring and generating dh_params is no longer required, and they will be ignored. 
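To see the effect of the new cipher string in isolation, here is a minimal pyOpenSSL sketch (assuming only that pyOpenSSL is installed; this is not synapse code). Because every remaining suite uses ECDHE for key exchange, no load_tmp_dh() call — and hence no dh_params file — is needed:

    from OpenSSL import SSL

    ctx = SSL.Context(SSL.TLSv1_2_METHOD)
    # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
    ctx.set_cipher_list(
        b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
    )
    # Note what is absent: no ctx.load_tmp_dh(dh_params_path). Ephemeral keys
    # come from ECDHE, so the generated dh_params file becomes unnecessary.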
diff --git a/debian/homeserver.yaml b/debian/homeserver.yaml index 188a2d5483..0bb2d22a95 100644 --- a/debian/homeserver.yaml +++ b/debian/homeserver.yaml @@ -9,9 +9,6 @@ tls_certificate_path: "/etc/matrix-synapse/homeserver.tls.crt" # PEM encoded private key for TLS tls_private_key_path: "/etc/matrix-synapse/homeserver.tls.key" -# PEM dh parameters for ephemeral keys -tls_dh_params_path: "/etc/matrix-synapse/homeserver.tls.dh" - # Don't bind to the https port no_tls: False diff --git a/demo/demo.tls.dh b/demo/demo.tls.dh deleted file mode 100644 index cbc58272a0..0000000000 --- a/demo/demo.tls.dh +++ /dev/null @@ -1,9 +0,0 @@ -2048-bit DH parameters taken from rfc3526 ------BEGIN DH PARAMETERS----- -MIIBCAKCAQEA///////////JD9qiIWjCNMTGYouA3BzRKQJOCIpnzHQCC76mOxOb -IlFKCHmONATd75UZs806QxswKwpt8l8UN0/hNW1tUcJF5IW1dmJefsb0TELppjft -awv/XLb0Brft7jhr+1qJn6WunyQRfEsf5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT -mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVSu57VKQdwlpZtZww1Tkq8mATxdGwIyhgh -fDKQXkYuNs474553LBgOhgObJ4Oi7Aeij7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq -5RXSJhiY+gUQFXKOWoqsqmj//////////wIBAg== ------END DH PARAMETERS----- diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index c2b8576a32..529118d184 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -4,7 +4,6 @@ tls_certificate_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.crt" tls_private_key_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.key" -tls_dh_params_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.dh" no_tls: {{ "True" if SYNAPSE_NO_TLS else "False" }} tls_fingerprints: [] diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fef1ea99cb..bb8952c672 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -14,7 +14,6 @@ # limitations under the License. import os -import subprocess from hashlib import sha256 from unpaddedbase64 import encode_base64 @@ -23,8 +22,6 @@ from OpenSSL import crypto from ._base import Config -GENERATE_DH_PARAMS = False - class TlsConfig(Config): def read_config(self, config): @@ -42,10 +39,6 @@ class TlsConfig(Config): config.get("tls_private_key_path") ) - self.tls_dh_params_path = self.check_file( - config.get("tls_dh_params_path"), "tls_dh_params" - ) - self.tls_fingerprints = config["tls_fingerprints"] # Check that our own certificate is included in the list of fingerprints @@ -72,7 +65,6 @@ class TlsConfig(Config): tls_certificate_path = base_key_name + ".tls.crt" tls_private_key_path = base_key_name + ".tls.key" - tls_dh_params_path = base_key_name + ".tls.dh" return """\ # PEM encoded X509 certificate for TLS. 
@@ -85,9 +77,6 @@ class TlsConfig(Config): # PEM encoded private key for TLS tls_private_key_path: "%(tls_private_key_path)s" - # PEM dh parameters for ephemeral keys - tls_dh_params_path: "%(tls_dh_params_path)s" - # Don't bind to the https port no_tls: False @@ -131,7 +120,6 @@ class TlsConfig(Config): def generate_files(self, config): tls_certificate_path = config["tls_certificate_path"] tls_private_key_path = config["tls_private_key_path"] - tls_dh_params_path = config["tls_dh_params_path"] if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "wb") as private_key_file: @@ -165,31 +153,3 @@ class TlsConfig(Config): cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) certificate_file.write(cert_pem) - - if not self.path_exists(tls_dh_params_path): - if GENERATE_DH_PARAMS: - subprocess.check_call([ - "openssl", "dhparam", - "-outform", "PEM", - "-out", tls_dh_params_path, - "2048" - ]) - else: - with open(tls_dh_params_path, "w") as dh_params_file: - dh_params_file.write( - "2048-bit DH parameters taken from rfc3526\n" - "-----BEGIN DH PARAMETERS-----\n" - "MIIBCAKCAQEA///////////JD9qiIWjC" - "NMTGYouA3BzRKQJOCIpnzHQCC76mOxOb\n" - "IlFKCHmONATd75UZs806QxswKwpt8l8U" - "N0/hNW1tUcJF5IW1dmJefsb0TELppjft\n" - "awv/XLb0Brft7jhr+1qJn6WunyQRfEsf" - "5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT\n" - "mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVS" - "u57VKQdwlpZtZww1Tkq8mATxdGwIyhgh\n" - "fDKQXkYuNs474553LBgOhgObJ4Oi7Aei" - "j7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq\n" - "5RXSJhiY+gUQFXKOWoqsqmj/////////" - "/wIBAg==\n" - "-----END DH PARAMETERS-----\n" - ) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 02b76dfcfb..6ba3eca7b2 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -46,8 +46,10 @@ class ServerContextFactory(ContextFactory): if not config.no_tls: context.use_privatekey(config.tls_private_key) - context.load_tmp_dh(config.tls_dh_params_path) - context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") + # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ + context.set_cipher_list( + "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1" + ) def getContext(self): return self._context diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 0c23068bcf..b5ad99348d 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -51,7 +51,6 @@ class ConfigGenerationTestCase(unittest.TestCase): "lemurs.win.log.config", "lemurs.win.signing.key", "lemurs.win.tls.crt", - "lemurs.win.tls.dh", "lemurs.win.tls.key", ] ), -- cgit 1.5.1 From 33a55289cb887df88f7b506a8af2ebabd51c59b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 22 Jan 2019 10:59:27 +0000 Subject: Refactor and bugfix for resove_service (#4427) --- changelog.d/4427.misc | 1 + synapse/http/endpoint.py | 75 +---------- synapse/http/federation/__init__.py | 14 ++ synapse/http/federation/srv_resolver.py | 124 +++++++++++++++++ tests/http/federation/__init__.py | 14 ++ tests/http/federation/test_srv_resolver.py | 209 +++++++++++++++++++++++++++++ tests/test_dns.py | 129 ------------------ 7 files changed, 365 insertions(+), 201 deletions(-) create mode 100644 changelog.d/4427.misc create mode 100644 synapse/http/federation/__init__.py create mode 100644 synapse/http/federation/srv_resolver.py create mode 100644 tests/http/federation/__init__.py create mode 100644 tests/http/federation/test_srv_resolver.py delete mode 100644 
tests/test_dns.py (limited to 'synapse') diff --git a/changelog.d/4427.misc b/changelog.d/4427.misc new file mode 100644 index 0000000000..75500bdbc2 --- /dev/null +++ b/changelog.d/4427.misc @@ -0,0 +1 @@ +Refactor and cleanup for SRV record lookup diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 1c3b7ea28a..815f8ff2f7 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,29 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging import random import re -import time from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.error import ConnectError -from twisted.names import client, dns -from twisted.names.error import DNSNameError, DomainError -logger = logging.getLogger(__name__) - -SERVER_CACHE = {} +from synapse.http.federation.srv_resolver import Server, resolve_service -# our record of an individual server which can be tried to reach a destination. -# -# "host" is the hostname acquired from the SRV record. Except when there's -# no SRV record, in which case it is the original hostname. -_Server = collections.namedtuple( - "_Server", "priority weight host port expires" -) +logger = logging.getLogger(__name__) def parse_server_name(server_name): @@ -165,12 +153,9 @@ class SRVClientEndpoint(object): self.service_name = "_%s._%s.%s" % (service, protocol, domain) if default_port is not None: - self.default_server = _Server( + self.default_server = Server( host=domain, port=default_port, - priority=0, - weight=0, - expires=0, ) else: self.default_server = None @@ -240,57 +225,3 @@ class SRVClientEndpoint(object): ) connection = yield endpoint.connect(protocolFactory) defer.returnValue(connection) - - -@defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): - cache_entry = cache.get(service_name, None) - if cache_entry: - if all(s.expires > int(clock.time()) for s in cache_entry): - servers = list(cache_entry) - defer.returnValue(servers) - - servers = [] - - try: - try: - answers, _, _ = yield dns_client.lookupService(service_name) - except DNSNameError: - defer.returnValue([]) - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): - raise ConnectError("Service %s unavailable" % service_name) - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - - payload = answer.payload - - servers.append(_Server( - host=str(payload.target), - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + answer.ttl, - )) - - servers.sort() - cache[service_name] = list(servers) - except DomainError as e: - # We failed to resolve the name (other than a NameError) - # Try something in the cache, else rereaise - cache_entry = cache.get(service_name, None) - if cache_entry: - logger.warn( - "Failed to resolve %r, falling back to cache. 
%r", - service_name, e - ) - servers = list(cache_entry) - else: - raise e - - defer.returnValue(servers) diff --git a/synapse/http/federation/__init__.py b/synapse/http/federation/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/synapse/http/federation/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py new file mode 100644 index 0000000000..c49b82c394 --- /dev/null +++ b/synapse/http/federation/srv_resolver.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +import attr + +from twisted.internet import defer +from twisted.internet.error import ConnectError +from twisted.names import client, dns +from twisted.names.error import DNSNameError, DomainError + +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) + +SERVER_CACHE = {} + + +@attr.s +class Server(object): + """ + Our record of an individual server which can be tried to reach a destination. + + Attributes: + host (bytes): target hostname + port (int): + priority (int): + weight (int): + expires (int): when the cache should expire this record - in *seconds* since + the epoch + """ + host = attr.ib() + port = attr.ib() + priority = attr.ib(default=0) + weight = attr.ib(default=0) + expires = attr.ib(default=0) + + +@defer.inlineCallbacks +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + """Look up a SRV record, with caching + + The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, + but the cache never gets populated), so we add our own caching layer here. + + Args: + service_name (unicode|bytes): record to look up + dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl + cache (dict): cache object + clock (object): clock implementation. must provide a time() method. + + Returns: + Deferred[list[Server]]: a list of the SRV records, or an empty list if none found + """ + # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded + # byteses; however they will obviously end up as separate entries in the cache. We + # should pick one form and stick with it. 
+ cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + + try: + answers, _, _ = yield make_deferred_yieldable( + dns_client.lookupService(service_name), + ) + except DNSNameError: + # TODO: cache this. We can get the SOA out of the exception, and use + # the negative-TTL value. + defer.returnValue([]) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. %r", + service_name, e + ) + defer.returnValue(list(cache_entry)) + else: + raise e + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b'.')): + raise ConnectError("Service %s unavailable" % service_name) + + servers = [] + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + servers.append(Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=int(clock.time()) + answer.ttl, + )) + + servers.sort() # FIXME: get rid of this (it's broken by the attrs change) + cache[service_name] = list(servers) + defer.returnValue(servers) diff --git a/tests/http/federation/__init__.py b/tests/http/federation/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/tests/http/federation/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py new file mode 100644 index 0000000000..1271a495e1 --- /dev/null +++ b/tests/http/federation/test_srv_resolver.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from mock import Mock + +from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.internet.error import ConnectError +from twisted.names import dns, error + +from synapse.http.federation.srv_resolver import resolve_service +from synapse.util.logcontext import LoggingContext + +from tests import unittest +from tests.utils import MockClock + + +class SrvResolverTestCase(unittest.TestCase): + def test_resolve(self): + dns_client_mock = Mock() + + service_name = b"test_service.example.com" + host_name = b"example.com" + + answer_srv = dns.RRHeader( + type=dns.SRV, payload=dns.Record_SRV(target=host_name) + ) + + result_deferred = Deferred() + dns_client_mock.lookupService.return_value = result_deferred + + cache = {} + + @defer.inlineCallbacks + def do_lookup(): + with LoggingContext("one") as ctx: + resolve_d = resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + self.assertNoResult(resolve_d) + + # should have reset to the sentinel context + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + + result = yield resolve_d + + # should have restored our context + self.assertIs(LoggingContext.current_context(), ctx) + + defer.returnValue(result) + + test_d = do_lookup() + self.assertNoResult(test_d) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + result_deferred.callback( + ([answer_srv], None, None) + ) + + servers = self.successResultOf(test_d) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, host_name) + + @defer.inlineCallbacks + def test_from_cache_expired_and_dns_fail(self): + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.example.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 0 + + cache = {service_name: [entry]} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_from_cache(self): + clock = MockClock() + + dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock.lookupService = Mock(spec_set=[]) + + service_name = "test_service.example.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + + cache = {service_name: [entry]} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache, clock=clock + ) + + self.assertFalse(dns_client_mock.lookupService.called) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_empty_cache(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.example.com" + + cache = {} + + with self.assertRaises(error.DNSServerError): + yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) + + @defer.inlineCallbacks + def test_name_error(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) + + service_name = "test_service.example.com" + + cache = {} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + self.assertEquals(len(servers), 0) + self.assertEquals(len(cache), 0) + + 
def test_disabled_service(self):
+ """
+ test the behaviour when there is a single record which is ".".
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+
+ resolve_d = resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache
+ )
+ self.assertNoResult(resolve_d)
+
+ # returning a single "." should make the lookup fail with a ConnectError
+ lookup_deferred.callback((
+ [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
+ None,
+ None,
+ ))
+
+ self.failureResultOf(resolve_d, ConnectError)
+
+ def test_non_srv_answer(self):
+ """
+ test the behaviour when the dns server gives us a spurious non-SRV response
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+
+ resolve_d = resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache
+ )
+ self.assertNoResult(resolve_d)
+
+ lookup_deferred.callback((
+ [
+ dns.RRHeader(type=dns.A, payload=dns.Record_A()),
+ dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
+ ],
+ None,
+ None,
+ ))
+
+ servers = self.successResultOf(resolve_d)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+ self.assertEquals(servers[0].host, b"host")
diff --git a/tests/test_dns.py b/tests/test_dns.py
deleted file mode 100644
index 90bd34be34..0000000000
--- a/tests/test_dns.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mock import Mock
-
-from twisted.internet import defer
-from twisted.names import dns, error
-
-from synapse.http.endpoint import resolve_service
-
-from tests.utils import MockClock
-
-from . 
import unittest - - -@unittest.DEBUG -class DnsTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_resolve(self): - dns_client_mock = Mock() - - service_name = "test_service.example.com" - host_name = "example.com" - - answer_srv = dns.RRHeader( - type=dns.SRV, payload=dns.Record_SRV(target=host_name) - ) - - dns_client_mock.lookupService.return_value = defer.succeed( - ([answer_srv], None, None) - ) - - cache = {} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - dns_client_mock.lookupService.assert_called_once_with(service_name) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) - - @defer.inlineCallbacks - def test_from_cache_expired_and_dns_fail(self): - dns_client_mock = Mock() - dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - - service_name = "test_service.example.com" - - entry = Mock(spec_set=["expires"]) - entry.expires = 0 - - cache = {service_name: [entry]} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - dns_client_mock.lookupService.assert_called_once_with(service_name) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - - @defer.inlineCallbacks - def test_from_cache(self): - clock = MockClock() - - dns_client_mock = Mock(spec_set=['lookupService']) - dns_client_mock.lookupService = Mock(spec_set=[]) - - service_name = "test_service.example.com" - - entry = Mock(spec_set=["expires"]) - entry.expires = 999999999 - - cache = {service_name: [entry]} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache, clock=clock - ) - - self.assertFalse(dns_client_mock.lookupService.called) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - - @defer.inlineCallbacks - def test_empty_cache(self): - dns_client_mock = Mock() - - dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - - service_name = "test_service.example.com" - - cache = {} - - with self.assertRaises(error.DNSServerError): - yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) - - @defer.inlineCallbacks - def test_name_error(self): - dns_client_mock = Mock() - - dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) - - service_name = "test_service.example.com" - - cache = {} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) -- cgit 1.5.1 From 6bfa735a69717cf24f2196f1abb801d031f6993a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 22 Jan 2019 11:04:20 +0000 Subject: Make key fetches use regular federation client (#4426) All this magic is redundant. 
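Stripped of the custom protocol, a key fetch is just a federation GET of /_matrix/key/v2/server/<key_id>, authenticated by the signature inside the JSON body rather than by TLS fingerprints. A rough standalone sketch of what that amounts to, using the `requests` library purely for illustration (synapse's own client additionally handles retries, backoff and federation TLS quirks):

    import requests
    from six.moves import urllib

    def fetch_server_key(server_name, key_id, port=8448):
        url = "https://%s:%i/_matrix/key/v2/server/%s" % (
            server_name, port, urllib.parse.quote(key_id),
        )
        # NB: real federation deployments may use self-signed certificates,
        # which would need custom TLS verification here.
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        body = resp.json()
        # presence check only, as in the patch; full signature verification
        # happens later (in process_v2_response)
        if server_name not in body.get("signatures", {}):
            raise RuntimeError("Key response not signed by remote server")
        return body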
--- changelog.d/4426.misc | 1 + synapse/crypto/keyclient.py | 149 -------------------------------------------- synapse/crypto/keyring.py | 30 +++------ 3 files changed, 8 insertions(+), 172 deletions(-) create mode 100644 changelog.d/4426.misc delete mode 100644 synapse/crypto/keyclient.py (limited to 'synapse') diff --git a/changelog.d/4426.misc b/changelog.d/4426.misc new file mode 100644 index 0000000000..cda50438e0 --- /dev/null +++ b/changelog.d/4426.misc @@ -0,0 +1 @@ +Remove redundant SynapseKeyClientProtocol magic \ No newline at end of file diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py deleted file mode 100644 index d40e4b8591..0000000000 --- a/synapse/crypto/keyclient.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six.moves import urllib - -from canonicaljson import json - -from twisted.internet import defer, reactor -from twisted.internet.error import ConnectError -from twisted.internet.protocol import Factory -from twisted.names.error import DomainError -from twisted.web.http import HTTPClient - -from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util import logcontext - -logger = logging.getLogger(__name__) - -KEY_API_V2 = "/_matrix/key/v2/server/%s" - - -@defer.inlineCallbacks -def fetch_server_key(server_name, tls_client_options_factory, key_id): - """Fetch the keys for a remote server.""" - - factory = SynapseKeyClientFactory() - factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), ) - factory.host = server_name - endpoint = matrix_federation_endpoint( - reactor, server_name, tls_client_options_factory, timeout=30 - ) - - for i in range(5): - try: - with logcontext.PreserveLoggingContext(): - protocol = yield endpoint.connect(factory) - server_response, server_certificate = yield protocol.remote_key - defer.returnValue((server_response, server_certificate)) - except SynapseKeyClientError as e: - logger.warn("Error getting key for %r: %s", server_name, e) - if e.status.startswith(b"4"): - # Don't retry for 4xx responses. 
- raise IOError("Cannot get key for %r" % server_name) - except (ConnectError, DomainError) as e: - logger.warn("Error getting key for %r: %s", server_name, e) - except Exception: - logger.exception("Error getting key for %r", server_name) - raise IOError("Cannot get key for %r" % server_name) - - -class SynapseKeyClientError(Exception): - """The key wasn't retrieved from the remote server.""" - status = None - pass - - -class SynapseKeyClientProtocol(HTTPClient): - """Low level HTTPS client which retrieves an application/json response from - the server and extracts the X.509 certificate for the remote peer from the - SSL connection.""" - - timeout = 30 - - def __init__(self): - self.remote_key = defer.Deferred() - self.host = None - self._peer = None - - def connectionMade(self): - self._peer = self.transport.getPeer() - logger.debug("Connected to %s", self._peer) - - if not isinstance(self.path, bytes): - self.path = self.path.encode('ascii') - - if not isinstance(self.host, bytes): - self.host = self.host.encode('ascii') - - self.sendCommand(b"GET", self.path) - if self.host: - self.sendHeader(b"Host", self.host) - self.endHeaders() - self.timer = reactor.callLater( - self.timeout, - self.on_timeout - ) - - def errback(self, error): - if not self.remote_key.called: - self.remote_key.errback(error) - - def callback(self, result): - if not self.remote_key.called: - self.remote_key.callback(result) - - def handleStatus(self, version, status, message): - if status != b"200": - # logger.info("Non-200 response from %s: %s %s", - # self.transport.getHost(), status, message) - error = SynapseKeyClientError( - "Non-200 response %r from %r" % (status, self.host) - ) - error.status = status - self.errback(error) - self.transport.abortConnection() - - def handleResponse(self, response_body_bytes): - try: - json_response = json.loads(response_body_bytes) - except ValueError: - # logger.info("Invalid JSON response from %s", - # self.transport.getHost()) - self.transport.abortConnection() - return - - certificate = self.transport.getPeerCertificate() - self.callback((json_response, certificate)) - self.transport.abortConnection() - self.timer.cancel() - - def on_timeout(self): - logger.debug( - "Timeout waiting for response from %s: %s", - self.host, self._peer, - ) - self.errback(IOError("Timeout waiting for response")) - self.transport.abortConnection() - - -class SynapseKeyClientFactory(Factory): - def protocol(self): - protocol = SynapseKeyClientProtocol() - protocol.path = self.path - protocol.host = self.host - return protocol diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 515ebbc148..3a96980bed 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import hashlib
 import logging
 from collections import namedtuple
 
+from six.moves import urllib
+
 from signedjson.key import (
 decode_verify_key_bytes,
 encode_verify_key_base64,
@@ -30,13 +31,11 @@ from signedjson.sign import (
 signature_ids,
 verify_signed_json,
 )
-from unpaddedbase64 import decode_base64, encode_base64
+from unpaddedbase64 import decode_base64
 
-from OpenSSL import crypto
 from twisted.internet import defer
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyclient import fetch_server_key
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
 LoggingContext,
@@ -503,31 +502,16 @@ class Keyring(object):
 if requested_key_id in keys:
 continue
 
- (response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_client_options_factory, requested_key_id
+ response = yield self.client.get_json(
+ destination=server_name,
+ path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
+ ignore_backoff=True,
 )
 
 if (u"signatures" not in response
 or server_name not in response[u"signatures"]):
 raise KeyLookupError("Key response not signed by remote server")
 
- if "tls_fingerprints" not in response:
- raise KeyLookupError("Key response missing TLS fingerprints")
-
- certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1, tls_certificate
- )
- sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
- sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
-
- response_sha256_fingerprints = set()
- for fingerprint in response[u"tls_fingerprints"]:
- if u"sha256" in fingerprint:
- response_sha256_fingerprints.add(fingerprint[u"sha256"])
-
- if sha256_fingerprint_b64 not in response_sha256_fingerprints:
- raise KeyLookupError("TLS certificate not allowed by fingerprints")
-
 response_keys = yield self.process_v2_response(
 from_server=server_name,
 requested_ids=[requested_key_id],
-- cgit 1.5.1


From 2557531f0f60ee1891a74babed54800cb1bcfd06 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 22 Jan 2019 11:55:53 +0000
Subject: Fix bug when removing duplicate rows from user_ips

This was caused by accidentally overwriting a `last_seen` variable in a
for loop, causing the wrong value to be written to the progress table. As a
result we didn't scan some sections of the table when searching for
duplicates, and so some duplicates did not get deleted.

---
 synapse/storage/client_ips.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

(limited to 'synapse')

diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 5d548f250a..9548cfd3f8 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -110,8 +110,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
 @defer.inlineCallbacks
 def _remove_user_ip_dupes(self, progress, batch_size):
+ # This function works by scanning the user_ips table in batches
+ # based on `last_seen`. For each row in a batch it searches the rest of
+ # the table to see if there are any duplicates; if there are then they
+ # are removed and replaced with a suitable row.
 
- last_seen_progress = progress.get("last_seen", 0)
+ # Fetch the start of the batch
+ begin_last_seen = progress.get("last_seen", 0)
 
 def get_last_seen(txn):
 txn.execute(
@@ -122,17 +127,20 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 LIMIT 1
 OFFSET ?
""", - (last_seen_progress, batch_size) + (begin_last_seen, batch_size) ) - results = txn.fetchone() - return results - - # Get a last seen that's sufficiently far away enough from the last one - last_seen = yield self.runInteraction( + row = txn.fetchone() + if row: + return row[0] + else: + return None + + # Get a last seen that has roughly `batch_size` since `begin_last_seen` + end_last_seen = yield self.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) - if not last_seen: + if end_last_seen is None: # If we get a None then we're reaching the end and just need to # delete the last batch. last = True @@ -142,9 +150,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): last_seen = int(self.clock.time_msec()) * 2 else: last = False - last_seen = last_seen[0] - def remove(txn, last_seen_progress, last_seen): + def remove(txn, begin_last_seen, end_last_seen): # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range # checking for any duplicates in the rest of the table (via a join). @@ -166,7 +173,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip HAVING count(*) > 1""", - (last_seen_progress, last_seen) + (begin_last_seen, end_last_seen) ) res = txn.fetchall() @@ -194,11 +201,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) self._background_update_progress_txn( - txn, "user_ips_remove_dupes", {"last_seen": last_seen} + txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) yield self.runInteraction( - "user_ips_dups_remove", remove, last_seen_progress, last_seen + "user_ips_dups_remove", remove, begin_last_seen, end_last_seen ) if last: yield self._end_background_update("user_ips_remove_dupes") -- cgit 1.5.1 From 1c9704f8ab72047a83c6a3b364f3693b332434f2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Jan 2019 16:20:33 +0000 Subject: Don't shadow params --- synapse/storage/client_ips.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 9548cfd3f8..c485211175 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -151,7 +151,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): else: last = False - def remove(txn, begin_last_seen, end_last_seen): + def remove(txn): # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range # checking for any duplicates in the rest of the table (via a join). 
@@ -204,9 +204,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction( - "user_ips_dups_remove", remove, begin_last_seen, end_last_seen - ) + yield self.runInteraction("user_ips_dups_remove", remove) + if last: yield self._end_background_update("user_ips_remove_dupes") -- cgit 1.5.1 From 7f503f83b92150edeb4cc5ae98f6d9bb2e9cdb49 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Jan 2019 16:31:05 +0000 Subject: Refactor to rewrite the SQL instead --- synapse/storage/client_ips.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index c485211175..78721a941a 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -140,16 +140,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "user_ips_dups_get_last_seen", get_last_seen ) - if end_last_seen is None: - # If we get a None then we're reaching the end and just need to - # delete the last batch. - last = True - - # We fake not having an upper bound by using a future date, by - # just multiplying the current time by two.... - last_seen = int(self.clock.time_msec()) * 2 - else: - last = False + # If it returns None, then we're processing the last batch + last = end_last_seen is None def remove(txn): # This works by looking at all entries in the given time span, and @@ -160,6 +152,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # all other duplicates. # It is efficient due to the existence of (user_id, access_token, # ip) and (last_seen) indices. + + # Define the search space, which requires handling the last batch in + # a different way + if last: + clause = "? <= last_seen" + args = (begin_last_seen,) + else: + clause = "? <= last_seen AND last_seen < ?" + args = (begin_last_seen, end_last_seen) + txn.execute( """ SELECT user_id, access_token, ip, @@ -167,13 +169,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): FROM ( SELECT user_id, access_token, ip FROM user_ips - WHERE ? <= last_seen AND last_seen < ? + WHERE {} ORDER BY last_seen ) c INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip - HAVING count(*) > 1""", - (begin_last_seen, end_last_seen) + HAVING count(*) > 1 + """.format(clause), + args ) res = txn.fetchall() -- cgit 1.5.1 From 44be7513bff501055cc2e667a5cca3fb87c23f70 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 21 Jan 2019 23:27:57 +0000 Subject: MatrixFederationAgent Pull the magic that is currently in matrix_federation_endpoint and friends into an agent-like thing --- synapse/http/federation/matrix_federation_agent.py | 114 +++++++++++++++++++++ synapse/http/federation/srv_resolver.py | 33 ++++++ 2 files changed, 147 insertions(+) create mode 100644 synapse/http/federation/matrix_federation_agent.py (limited to 'synapse') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py new file mode 100644 index 0000000000..32bfd68ed1 --- /dev/null +++ b/synapse/http/federation/matrix_federation_agent.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.iweb import IAgent + +from synapse.http.endpoint import parse_server_name +from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) + + +@implementer(IAgent) +class MatrixFederationAgent(object): + """An Agent-like thing which provides a `request` method which will look up a matrix + server and send an HTTP request to it. + + Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) + + Args: + reactor (IReactor): twisted reactor to use for underlying requests + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + """ + + def __init__(self, reactor, tls_client_options_factory): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory + + self._pool = HTTPConnectionPool(reactor) + self._pool.retryAutomatically = False + self._pool.maxPersistentPerHost = 5 + self._pool.cachedConnectionTimeout = 2 * 60 + + @defer.inlineCallbacks + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Args: + method (bytes): HTTP method: GET/POST/etc + + uri (bytes): Absolute URI to be retrieved + + headers (twisted.web.http_headers.Headers|None): + HTTP headers to send with the request, or None to + send no extra headers. + + bodyProducer (twisted.web.iweb.IBodyProducer|None): + An object which can generate bytes to make up the + body of this request (for example, the properly encoded contents of + a file for a file upload). Or None if the request is to have + no body. + + Returns: + Deferred[twisted.web.iweb.IResponse]: + fires when the header of the response has been received (regardless of the + response status code). Fails if there is any problem which prevents that + response from being received (including problems that prevent the request + from being sent). + """ + + parsed_uri = URI.fromBytes(uri) + server_name_bytes = parsed_uri.netloc + host, port = parse_server_name(server_name_bytes.decode("ascii")) + + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. 
+ if self._tls_client_options_factory is None: + tls_options = None + else: + tls_options = self._tls_client_options_factory.get_options(host) + + if port is not None: + target = (host, port) + else: + server_list = yield resolve_service(server_name_bytes) + if not server_list: + target = (host, 8448) + logger.debug("No SRV record for %s, using %s", host, target) + else: + target = pick_server_from_list(server_list) + + class EndpointFactory(object): + @staticmethod + def endpointForURI(_uri): + logger.info("Connecting to %s:%s", target[0], target[1]) + ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) + if tls_options is not None: + ep = wrapClientTLS(tls_options, ep) + return ep + + agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) + res = yield make_deferred_yieldable( + agent.request(method, uri, headers, bodyProducer) + ) + defer.returnValue(res) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index c49b82c394..ded0b32832 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import random import time import attr @@ -51,6 +52,38 @@ class Server(object): expires = attr.ib(default=0) +def pick_server_from_list(server_list): + """Randomly choose a server from the server list + + Args: + server_list (list[Server]): list of candidate servers + + Returns: + Tuple[bytes, int]: (host, port) pair for the chosen server + """ + if not server_list: + raise RuntimeError("pick_server_from_list called with empty list") + + # TODO: currently we only use the lowest-priority servers. We should maintain a + # cache of servers known to be "down" and filter them out + + min_priority = min(s.priority for s in server_list) + eligible_servers = list(s for s in server_list if s.priority == min_priority) + total_weight = sum(s.weight for s in eligible_servers) + target_weight = random.randint(0, total_weight) + + for s in eligible_servers: + target_weight -= s.weight + + if target_weight <= 0: + return s.host, s.port + + # this should be impossible. + raise RuntimeError( + "pick_server_from_list got to end of eligible server list.", + ) + + @defer.inlineCallbacks def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): """Look up a SRV record, with caching -- cgit 1.5.1 From 7871146667afd76576939c91986a8f8bacb49446 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 21 Jan 2019 23:29:47 +0000 Subject: Make MatrixFederationClient use MatrixFederationAgent ... 
instead of the matrix_federation_endpoint --- synapse/http/matrixfederationclient.py | 37 ++++--------- tests/http/test_fedclient.py | 96 ++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 27 deletions(-) (limited to 'synapse') diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 250bb1ef91..980e912348 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -32,7 +32,7 @@ from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool +from twisted.web.client import FileBodyProducer from twisted.web.http_headers import Headers import synapse.metrics @@ -44,7 +44,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.http.endpoint import matrix_federation_endpoint +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable from synapse.util.metrics import Measure @@ -66,20 +66,6 @@ else: MAXINT = sys.maxint -class MatrixFederationEndpointFactory(object): - def __init__(self, hs): - self.reactor = hs.get_reactor() - self.tls_client_options_factory = hs.tls_client_options_factory - - def endpointForURI(self, uri): - destination = uri.netloc.decode('ascii') - - return matrix_federation_endpoint( - self.reactor, destination, timeout=10, - tls_client_options_factory=self.tls_client_options_factory - ) - - _next_id = 1 @@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object): self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname reactor = hs.get_reactor() - pool = HTTPConnectionPool(reactor) - pool.retryAutomatically = False - pool.maxPersistentPerHost = 5 - pool.cachedConnectionTimeout = 2 * 60 - self.agent = Agent.usingEndpointFactory( - reactor, MatrixFederationEndpointFactory(hs), pool=pool + + self.agent = MatrixFederationAgent( + hs.get_reactor(), + hs.tls_client_options_factory, ) self.clock = hs.get_clock() self._store = hs.get_datastore() @@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers logger.info( - "{%s} [%s] Sending request: %s %s", + "{%s} [%s] Sending request: %s %s; timeout %fs", request.txn_id, request.destination, request.method, - url_str, + url_str, _sec_timeout, ) try: @@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object): reactor=self.hs.get_reactor(), ) - response = yield make_deferred_yieldable( - request_deferred, - ) + response = yield request_deferred except DNSLookupError as e: raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) except Exception as e: + logger.info("Failed to send request: %s", e) raise_from(RequestSendFailed(e, can_retry=True), e) logger.info( diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 8426eee400..d37f8f9981 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -15,6 +15,7 @@ from mock import Mock +from twisted.internet import defer from twisted.internet.defer import TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError from twisted.test.proto_helpers import StringTransport @@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import ( 
MatrixFederationHttpClient, MatrixFederationRequest, ) +from synapse.util.logcontext import LoggingContext from tests.server import FakeTransport from tests.unittest import HomeserverTestCase +def check_logcontext(context): + current = LoggingContext.current_context() + if current is not context: + raise AssertionError( + "Expected logcontext %s but was %s" % (context, current), + ) + + class FederationClientTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): @@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase): self.cl = MatrixFederationHttpClient(self.hs) self.reactor.lookups["testserv"] = "1.2.3.4" + def test_client_get(self): + """ + happy-path test of a GET request + """ + @defer.inlineCallbacks + def do_request(): + with LoggingContext("one") as context: + fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + + # Nothing happened yet + self.assertNoResult(fetch_d) + + # should have reset logcontext to the sentinel + check_logcontext(LoggingContext.sentinel) + + try: + fetch_res = yield fetch_d + defer.returnValue(fetch_res) + finally: + check_logcontext(context) + + test_d = do_request() + + self.pump() + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # that should have made it send the request to the transport + self.assertRegex(transport.value(), b"^GET /foo/bar") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # Send it the HTTP response + res_json = '{ "a": 1 }'.encode('ascii') + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: %i\r\n" + b"\r\n" + b"%s" % (len(res_json), res_json) + ) + + self.pump() + + res = self.successResultOf(test_d) + + # check the response is as expected + self.assertEqual(res, {"a": 1}) + def test_dns_error(self): """ If the DNS lookup returns an error, it will bubble up. @@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, DNSLookupError) + def test_client_connection_refused(self): + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8008) + e = Exception("go away") + factory.clientConnectionFailed(None, e) + self.pump(0.5) + + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestSendFailed) + self.assertIs(f.value.inner_exception, e) + def test_client_never_connect(self): """ If the HTTP request is not connected and is timed out, it'll give a -- cgit 1.5.1 From fe212bbe4a6774e34b0a474e1d3c6aee4ae4e3c4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 21 Jan 2019 23:42:22 +0000 Subject: Kill off matrix_federation_endpoint this thing is now redundant. 
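With this and the previous commit, the SRV lookup and TLS wrapping that
used to live in matrix_federation_endpoint are driven entirely through
MatrixFederationAgent. As a rough sketch (not part of this patch), assuming
a configured reactor and a ClientTLSOptionsFactory, direct use of the agent
looks something like:

    agent = MatrixFederationAgent(reactor, tls_client_options_factory)

    # Resolves the matrix server name (consulting SRV records when no port
    # is given) and fires with a twisted.web.iweb.IResponse once the
    # response headers arrive.
    d = agent.request(
        b"GET", b"matrix://remote.example.com/_matrix/federation/v1/version",
    )

MatrixFederationHttpClient layers request signing, timeouts and retries on
top of this.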
--- synapse/http/endpoint.py | 144 -------------------------------- synapse/http/federation/srv_resolver.py | 1 - 2 files changed, 145 deletions(-) (limited to 'synapse') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 815f8ff2f7..cd79ebab62 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -13,15 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import random import re -from twisted.internet import defer -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.error import ConnectError - -from synapse.http.federation.srv_resolver import Server, resolve_service - logger = logging.getLogger(__name__) @@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name): )) return host, port - - -def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None, - timeout=None): - """Construct an endpoint for the given matrix destination. - - Args: - reactor: Twisted reactor. - destination (unicode): The name of the server to connect to. - tls_client_options_factory - (synapse.crypto.context_factory.ClientTLSOptionsFactory): - Factory which generates TLS options for client connections. - timeout (int): connection timeout in seconds - """ - - domain, port = parse_server_name(destination) - - endpoint_kw_args = {} - - if timeout is not None: - endpoint_kw_args.update(timeout=timeout) - - if tls_client_options_factory is None: - transport_endpoint = HostnameEndpoint - default_port = 8008 - else: - # the SNI string should be the same as the Host header, minus the port. - # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777, - # the Host header and SNI should therefore be the server_name of the remote - # server. - tls_options = tls_client_options_factory.get_options(domain) - - def transport_endpoint(reactor, host, port, timeout): - return wrapClientTLS( - tls_options, - HostnameEndpoint(reactor, host, port, timeout=timeout), - ) - default_port = 8448 - - if port is None: - return SRVClientEndpoint( - reactor, "matrix", domain, protocol="tcp", - default_port=default_port, endpoint=transport_endpoint, - endpoint_kw_args=endpoint_kw_args - ) - else: - return transport_endpoint( - reactor, domain, port, **endpoint_kw_args - ) - - -class SRVClientEndpoint(object): - """An endpoint which looks up SRV records for a service. - Cycles through the list of servers starting with each call to connect - picking the next server. - Implements twisted.internet.interfaces.IStreamClientEndpoint. 
- """ - - def __init__(self, reactor, service, domain, protocol="tcp", - default_port=None, endpoint=HostnameEndpoint, - endpoint_kw_args={}): - self.reactor = reactor - self.service_name = "_%s._%s.%s" % (service, protocol, domain) - - if default_port is not None: - self.default_server = Server( - host=domain, - port=default_port, - ) - else: - self.default_server = None - - self.endpoint = endpoint - self.endpoint_kw_args = endpoint_kw_args - - self.servers = None - self.used_servers = None - - @defer.inlineCallbacks - def fetch_servers(self): - self.used_servers = [] - self.servers = yield resolve_service(self.service_name) - - def pick_server(self): - if not self.servers: - if self.used_servers: - self.servers = self.used_servers - self.used_servers = [] - self.servers.sort() - elif self.default_server: - return self.default_server - else: - raise ConnectError( - "No server available for %s" % self.service_name - ) - - # look for all servers with the same priority - min_priority = self.servers[0].priority - weight_indexes = list( - (index, server.weight + 1) - for index, server in enumerate(self.servers) - if server.priority == min_priority - ) - - total_weight = sum(weight for index, weight in weight_indexes) - target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: - target_weight -= weight - if target_weight <= 0: - server = self.servers[index] - # XXX: this looks totally dubious: - # - # (a) we never reuse a server until we have been through - # all of the servers at the same priority, so if the - # weights are A: 100, B:1, we always do ABABAB instead of - # AAAA...AAAB (approximately). - # - # (b) After using all the servers at the lowest priority, - # we move onto the next priority. We should only use the - # second priority if servers at the top priority are - # unreachable. - # - del self.servers[index] - self.used_servers.append(server) - return server - - @defer.inlineCallbacks - def connect(self, protocolFactory): - if self.servers is None: - yield self.fetch_servers() - server = self.pick_server() - logger.info("Connecting to %s:%s", server.host, server.port) - endpoint = self.endpoint( - self.reactor, server.host, server.port, **self.endpoint_kw_args - ) - connection = yield endpoint.connect(protocolFactory) - defer.returnValue(connection) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index ded0b32832..a6e92fdf40 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -152,6 +152,5 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t expires=int(clock.time()) + answer.ttl, )) - servers.sort() # FIXME: get rid of this (it's broken by the attrs change) cache[service_name] = list(servers) defer.returnValue(servers) -- cgit 1.5.1 From 53a327b4d5bdcab36651aa86ddf1815ff86e5db2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 22 Jan 2019 17:35:09 +0000 Subject: Require that service_name be a byte string it is only ever a bytes now, so let's enforce that. 
--- synapse/http/federation/srv_resolver.py | 8 ++++---- tests/http/federation/test_srv_resolver.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'synapse') diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index a6e92fdf40..e05f934d0b 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -92,7 +92,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t but the cache never gets populated), so we add our own caching layer here. Args: - service_name (unicode|bytes): record to look up + service_name (bytes): record to look up dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl cache (dict): cache object clock (object): clock implementation. must provide a time() method. @@ -100,9 +100,9 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t Returns: Deferred[list[Server]]: a list of the SRV records, or an empty list if none found """ - # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded - # byteses; however they will obviously end up as separate entries in the cache. We - # should pick one form and stick with it. + if not isinstance(service_name, bytes): + raise TypeError("%r is not a byte string" % (service_name,)) + cache_entry = cache.get(service_name, None) if cache_entry: if all(s.expires > int(clock.time()) for s in cache_entry): diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 1271a495e1..de4d0089c8 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,7 +83,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.example.com" + service_name = b"test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 0 @@ -106,7 +106,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock = Mock(spec_set=['lookupService']) dns_client_mock.lookupService = Mock(spec_set=[]) - service_name = "test_service.example.com" + service_name = b"test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 999999999 @@ -128,7 +128,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.example.com" + service_name = b"test_service.example.com" cache = {} @@ -141,7 +141,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) - service_name = "test_service.example.com" + service_name = b"test_service.example.com" cache = {} -- cgit 1.5.1 From 7021784d4612c328ef174963c6d5ca9a37d24bc7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 22 Jan 2019 17:42:26 +0000 Subject: put resolve_service in an object this makes it easier to stub things out for tests. 
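A sketch of the kind of stubbing this enables (the Mock-based rewrites in
the test diff below do essentially the same thing; the fake resolver here
is invented for illustration):

    from mock import Mock

    from twisted.internet import defer

    fake_resolver = Mock(spec_set=["resolve_service"])
    # Pretend there are no SRV records, so the agent falls back to port 8448.
    fake_resolver.resolve_service.return_value = defer.succeed([])

    agent = MatrixFederationAgent(
        reactor, tls_client_options_factory, _srv_resolver=fake_resolver,
    )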
--- synapse/http/federation/matrix_federation_agent.py | 16 ++- synapse/http/federation/srv_resolver.py | 133 +++++++++++---------- tests/http/federation/test_srv_resolver.py | 38 +++--- 3 files changed, 104 insertions(+), 83 deletions(-) (limited to 'synapse') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 32bfd68ed1..64c780a341 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.iweb import IAgent from synapse.http.endpoint import parse_server_name -from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service +from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) @@ -37,13 +37,23 @@ class MatrixFederationAgent(object): Args: reactor (IReactor): twisted reactor to use for underlying requests + tls_client_options_factory (ClientTLSOptionsFactory|None): factory to use for fetching client tls options, or none to disable TLS. + + srv_resolver (SrvResolver|None): + SRVResolver impl to use for looking up SRV records. None to use a default + implementation. """ - def __init__(self, reactor, tls_client_options_factory): + def __init__( + self, reactor, tls_client_options_factory, _srv_resolver=None, + ): self._reactor = reactor self._tls_client_options_factory = tls_client_options_factory + if _srv_resolver is None: + _srv_resolver = SrvResolver() + self._srv_resolver = _srv_resolver self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False @@ -91,7 +101,7 @@ class MatrixFederationAgent(object): if port is not None: target = (host, port) else: - server_list = yield resolve_service(server_name_bytes) + server_list = yield self._srv_resolver.resolve_service(server_name_bytes) if not server_list: target = (host, 8448) logger.debug("No SRV record for %s, using %s", host, target) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index e05f934d0b..71830c549d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -84,73 +84,86 @@ def pick_server_from_list(server_list): ) -@defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): - """Look up a SRV record, with caching +class SrvResolver(object): + """Interface to the dns client to do SRV lookups, with result caching. The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, but the cache never gets populated), so we add our own caching layer here. Args: - service_name (bytes): record to look up dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl cache (dict): cache object - clock (object): clock implementation. must provide a time() method. - - Returns: - Deferred[list[Server]]: a list of the SRV records, or an empty list if none found + get_time (callable): clock implementation. 
Should return seconds since the epoch """ - if not isinstance(service_name, bytes): - raise TypeError("%r is not a byte string" % (service_name,)) - - cache_entry = cache.get(service_name, None) - if cache_entry: - if all(s.expires > int(clock.time()) for s in cache_entry): - servers = list(cache_entry) - defer.returnValue(servers) - - try: - answers, _, _ = yield make_deferred_yieldable( - dns_client.lookupService(service_name), - ) - except DNSNameError: - # TODO: cache this. We can get the SOA out of the exception, and use - # the negative-TTL value. - defer.returnValue([]) - except DomainError as e: - # We failed to resolve the name (other than a NameError) - # Try something in the cache, else rereaise - cache_entry = cache.get(service_name, None) + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): + self._dns_client = dns_client + self._cache = cache + self._get_time = get_time + + @defer.inlineCallbacks + def resolve_service(self, service_name): + """Look up a SRV record + + Args: + service_name (bytes): record to look up + + Returns: + Deferred[list[Server]]: + a list of the SRV records, or an empty list if none found + """ + now = int(self._get_time()) + + if not isinstance(service_name, bytes): + raise TypeError("%r is not a byte string" % (service_name,)) + + cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + if all(s.expires > now for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + + try: + answers, _, _ = yield make_deferred_yieldable( + self._dns_client.lookupService(service_name), ) - defer.returnValue(list(cache_entry)) - else: - raise e - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): - raise ConnectError("Service %s unavailable" % service_name) - - servers = [] - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - - payload = answer.payload - - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=int(clock.time()) + answer.ttl, - )) - - cache[service_name] = list(servers) - defer.returnValue(servers) + except DNSNameError: + # TODO: cache this. We can get the SOA out of the exception, and use + # the negative-TTL value. + defer.returnValue([]) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = self._cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. 
%r", + service_name, e + ) + defer.returnValue(list(cache_entry)) + else: + raise e + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b'.')): + raise ConnectError("Service %s unavailable" % service_name) + + servers = [] + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + servers.append(Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + )) + + self._cache[service_name] = list(servers) + defer.returnValue(servers) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index de4d0089c8..a872e2441e 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred from twisted.internet.error import ConnectError from twisted.names import dns, error -from synapse.http.federation.srv_resolver import resolve_service +from synapse.http.federation.srv_resolver import SrvResolver from synapse.util.logcontext import LoggingContext from tests import unittest @@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = result_deferred cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @defer.inlineCallbacks def do_lookup(): + with LoggingContext("one") as ctx: - resolve_d = resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) @@ -89,10 +89,9 @@ class SrvResolverTestCase(unittest.TestCase): entry.expires = 0 cache = {service_name: [entry]} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + servers = yield resolver.resolve_service(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -112,11 +111,12 @@ class SrvResolverTestCase(unittest.TestCase): entry.expires = 999999999 cache = {service_name: [entry]} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache, clock=clock + resolver = SrvResolver( + dns_client=dns_client_mock, cache=cache, get_time=clock.time, ) + servers = yield resolver.resolve_service(service_name) + self.assertFalse(dns_client_mock.lookupService.called) self.assertEquals(len(servers), 1) @@ -131,9 +131,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) + yield resolver.resolve_service(service_name) @defer.inlineCallbacks def test_name_error(self): @@ -144,10 +145,9 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + servers = yield resolver.resolve_service(service_name) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred cache = {} + 
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) # returning a single "." should make the lookup fail with a ConenctError @@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) lookup_deferred.callback(( -- cgit 1.5.1 From 6129e52f437c2e03b711453434924e170f3d11bf Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 23 Jan 2019 19:39:06 +1100 Subject: Support ACME for certificate provisioning (#4384) --- changelog.d/4384.feature | 1 + scripts-dev/build_debian_packages | 2 +- synapse/app/homeserver.py | 56 ++++++++++++--- synapse/config/_base.py | 4 +- synapse/config/tls.py | 115 ++++++++++++++++++++++------- synapse/handlers/acme.py | 147 ++++++++++++++++++++++++++++++++++++++ synapse/python_dependencies.py | 4 ++ synapse/server.py | 5 ++ 8 files changed, 298 insertions(+), 36 deletions(-) create mode 100644 changelog.d/4384.feature create mode 100644 synapse/handlers/acme.py (limited to 'synapse') diff --git a/changelog.d/4384.feature b/changelog.d/4384.feature new file mode 100644 index 0000000000..daedcd58c4 --- /dev/null +++ b/changelog.d/4384.feature @@ -0,0 +1 @@ +Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt). diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 577d93e6f6..6b9be99060 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -10,12 +10,12 @@ # can be passed on the commandline for debugging. import argparse -from concurrent.futures import ThreadPoolExecutor import os import signal import subprocess import sys import threading +from concurrent.futures import ThreadPoolExecutor DISTS = ( "debian:stretch", diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f3ac3d19f0..ffc49d77cc 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -13,10 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import gc import logging import os import sys +import traceback from six import iteritems @@ -324,17 +326,12 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - tls_server_context_factory = context_factory.ServerContextFactory(config) - tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) - database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( config.server_name, db_config=config.database_config, - tls_server_context_factory=tls_server_context_factory, - tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, @@ -361,12 +358,53 @@ def setup(config_options): logger.info("Database prepared in %s.", config.database_config['name']) hs.setup() - hs.start_listening() + @defer.inlineCallbacks def start(): - hs.get_pusherpool().start() - hs.get_datastore().start_profiling() - hs.get_datastore().start_doing_background_updates() + try: + # Check if the certificate is still valid. + cert_days_remaining = hs.config.is_disk_cert_valid() + + if hs.config.acme_enabled: + # If ACME is enabled, we might need to provision a certificate + # before starting. + acme = hs.get_acme_handler() + + # Start up the webservices which we will respond to ACME + # challenges with. + yield acme.start_listening() + + # We want to reprovision if cert_days_remaining is None (meaning no + # certificate exists), or the days remaining number it returns + # is less than our re-registration threshold. + if (cert_days_remaining is None) or ( + not cert_days_remaining > hs.config.acme_reprovision_threshold + ): + yield acme.provision_certificate() + + # Read the certificate from disk and build the context factories for + # TLS. + hs.config.read_certificate_from_disk() + hs.tls_server_context_factory = context_factory.ServerContextFactory(config) + hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + config + ) + + # It is now safe to start your Synapse. + hs.start_listening() + hs.get_pusherpool().start() + hs.get_datastore().start_profiling() + hs.get_datastore().start_doing_background_updates() + except Exception as e: + # If a DeferredList failed (like in listening on the ACME listener), + # we need to print the subfailure explicitly. + if isinstance(e, defer.FirstError): + e.subFailure.printTraceback(sys.stderr) + sys.exit(1) + + # Something else went wrong when starting. Print it and bail out. 
+ traceback.print_exc(file=sys.stderr) + sys.exit(1) reactor.callWhenRunning(start) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fd2d6d52ef..5858fb92b4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -367,7 +367,7 @@ class Config(object): if not keys_directory: keys_directory = os.path.dirname(config_files[-1]) - config_dir_path = os.path.abspath(keys_directory) + self.config_dir_path = os.path.abspath(keys_directory) specified_config = {} for config_file in config_files: @@ -379,7 +379,7 @@ class Config(object): server_name = specified_config["server_name"] config_string = self.generate_config( - config_dir_path=config_dir_path, + config_dir_path=self.config_dir_path, data_dir_path=os.getcwd(), server_name=server_name, generate_secrets=False, diff --git a/synapse/config/tls.py b/synapse/config/tls.py index bb8952c672..a75e233aa0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -13,60 +13,110 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os +from datetime import datetime from hashlib import sha256 from unpaddedbase64 import encode_base64 from OpenSSL import crypto -from ._base import Config +from synapse.config._base import Config + +logger = logging.getLogger() class TlsConfig(Config): def read_config(self, config): - self.tls_certificate = self.read_tls_certificate( - config.get("tls_certificate_path") - ) - self.tls_certificate_file = config.get("tls_certificate_path") + acme_config = config.get("acme", {}) + self.acme_enabled = acme_config.get("enabled", False) + self.acme_url = acme_config.get( + "url", "https://acme-v01.api.letsencrypt.org/directory" + ) + self.acme_port = acme_config.get("port", 8449) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"]) + self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) + + self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path")) + self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path")) + self._original_tls_fingerprints = config["tls_fingerprints"] + self.tls_fingerprints = list(self._original_tls_fingerprints) self.no_tls = config.get("no_tls", False) - if self.no_tls: - self.tls_private_key = None - else: - self.tls_private_key = self.read_tls_private_key( - config.get("tls_private_key_path") - ) + # This config option applies to non-federation HTTP clients + # (e.g. for talking to recaptcha, identity servers, and such) + # It should never be used in production, and is intended for + # use only when running tests. + self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( + "use_insecure_ssl_client_just_for_testing_do_not_use" + ) + + self.tls_certificate = None + self.tls_private_key = None + + def is_disk_cert_valid(self): + """ + Is the certificate we have on disk valid, and if so, for how long? + + Returns: + int: Days remaining of certificate validity. + None: No certificate exists. 
+ """ + if not os.path.exists(self.tls_certificate_file): + return None + + try: + with open(self.tls_certificate_file, 'rb') as f: + cert_pem = f.read() + except Exception: + logger.exception("Failed to read existing certificate off disk!") + raise + + try: + tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + except Exception: + logger.exception("Failed to parse existing certificate off disk!") + raise + + # YYYYMMDDhhmmssZ -- in UTC + expires_on = datetime.strptime( + tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" + ) + now = datetime.utcnow() + days_remaining = (expires_on - now).days + return days_remaining - self.tls_fingerprints = config["tls_fingerprints"] + def read_certificate_from_disk(self): + """ + Read the certificates from disk. + """ + self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file) + + if not self.no_tls: + self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file) + + self.tls_fingerprints = list(self._original_tls_fingerprints) # Check that our own certificate is included in the list of fingerprints # and include it if it is not. x509_certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, - self.tls_certificate + crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) - # This config option applies to non-federation HTTP clients - # (e.g. for talking to recaptcha, identity servers, and such) - # It should never be used in production, and is intended for - # use only when running tests. - self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( - "use_insecure_ssl_client_just_for_testing_do_not_use" - ) - def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" tls_private_key_path = base_key_name + ".tls.key" - return """\ + return ( + """\ # PEM encoded X509 certificate for TLS. # You can replace the self-signed certificate that synapse # autogenerates on launch with your own SSL certificate + key pair @@ -107,7 +157,24 @@ class TlsConfig(Config): # tls_fingerprints: [] # tls_fingerprints: [{"sha256": ""}] - """ % locals() + + ## Support for ACME certificate auto-provisioning. + # acme: + # enabled: false + ## ACME path. + ## If you only want to test, use the staging url: + ## https://acme-staging.api.letsencrypt.org/directory + # url: 'https://acme-v01.api.letsencrypt.org/directory' + ## Port number (to listen for the HTTP-01 challenge). + ## Using port 80 requires utilising something like authbind, or proxying to it. + # port: 8449 + ## Hosts to bind to. + # bind_addresses: ['127.0.0.1'] + ## How many days remaining on a certificate before it is renewed. + # reprovision_threshold: 30 + """ + % locals() + ) def read_tls_certificate(self, cert_path): cert_pem = self.read_file(cert_path, "tls_certificate") diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py new file mode 100644 index 0000000000..73ea7ed018 --- /dev/null +++ b/synapse/handlers/acme.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import serverFromString +from twisted.python.filepath import FilePath +from twisted.python.url import URL +from twisted.web import server, static +from twisted.web.resource import Resource + +logger = logging.getLogger(__name__) + +try: + from txacme.interfaces import ICertificateStore + + @attr.s + @implementer(ICertificateStore) + class ErsatzStore(object): + """ + A store that only stores in memory. + """ + + certs = attr.ib(default=attr.Factory(dict)) + + def store(self, server_name, pem_objects): + self.certs[server_name] = [o.as_bytes() for o in pem_objects] + return defer.succeed(None) + + +except ImportError: + # txacme is missing + pass + + +class AcmeHandler(object): + def __init__(self, hs): + self.hs = hs + self.reactor = hs.get_reactor() + + @defer.inlineCallbacks + def start_listening(self): + + # Configure logging for txacme, if you need to debug + # from eliot import add_destinations + # from eliot.twisted import TwistedDestination + # + # add_destinations(TwistedDestination()) + + from txacme.challenges import HTTP01Responder + from txacme.service import AcmeIssuingService + from txacme.endpoint import load_or_create_client_key + from txacme.client import Client + from josepy.jwa import RS256 + + self._store = ErsatzStore() + responder = HTTP01Responder() + + self._issuer = AcmeIssuingService( + cert_store=self._store, + client_creator=( + lambda: Client.from_url( + reactor=self.reactor, + url=URL.from_text(self.hs.config.acme_url), + key=load_or_create_client_key( + FilePath(self.hs.config.config_dir_path) + ), + alg=RS256, + ) + ), + clock=self.reactor, + responders=[responder], + ) + + well_known = Resource() + well_known.putChild(b'acme-challenge', responder.resource) + responder_resource = Resource() + responder_resource.putChild(b'.well-known', well_known) + responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) + + srv = server.Site(responder_resource) + + listeners = [] + + for host in self.hs.config.acme_bind_addresses: + logger.info( + "Listening for ACME requests on %s:%s", host, self.hs.config.acme_port + ) + endpoint = serverFromString( + self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host) + ) + listeners.append(endpoint.listen(srv)) + + # Make sure we are registered to the ACME server. There's no public API + # for this, it is usually triggered by startService, but since we don't + # want it to control where we save the certificates, we have to reach in + # and trigger the registration machinery ourselves. + self._issuer._registered = False + yield self._issuer._ensure_registered() + + # Return a Deferred that will fire when all the servers have started up. 
+ yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True) + + @defer.inlineCallbacks + def provision_certificate(self): + + logger.warning("Reprovisioning %s", self.hs.hostname) + + try: + yield self._issuer.issue_cert(self.hs.hostname) + except Exception: + logger.exception("Fail!") + raise + logger.warning("Reprovisioned %s, saving.", self.hs.hostname) + cert_chain = self._store.certs[self.hs.hostname] + + try: + with open(self.hs.config.tls_private_key_file, "wb") as private_key_file: + for x in cert_chain: + if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"): + private_key_file.write(x) + + with open(self.hs.config.tls_certificate_file, "wb") as certificate_file: + for x in cert_chain: + if x.startswith(b"-----BEGIN CERTIFICATE-----"): + certificate_file.write(x) + except Exception: + logger.exception("Failed saving!") + raise + + defer.returnValue(True) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 882e844eb1..756721e304 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = { # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], + # ACME support is required to provision TLS certificates from authorities + # that use the protocol, such as Let's Encrypt. + "acme": ["txacme>=0.9.2"], + "saml2": ["pysaml2>=4.5.0"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0"], diff --git a/synapse/server.py b/synapse/server.py index 9985687b95..c8914302cf 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler from synapse.handlers import Handlers +from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.deactivate_account import DeactivateAccountHandler @@ -129,6 +130,7 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'acme_handler', 'auth_handler', 'device_handler', 'e2e_keys_handler', @@ -310,6 +312,9 @@ class HomeServer(object): def build_e2e_room_keys_handler(self): return E2eRoomKeysHandler(self) + def build_acme_handler(self): + return AcmeHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) -- cgit 1.5.1 From 90743c9d8910679bb688070b3929f3005e106d00 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Jan 2019 08:45:18 +0000 Subject: Fixup removal of duplicate `user_ips` rows (#4432) * Remove unnecessary ORDER BY clause * Add logging * Newsfile --- changelog.d/4432.misc | 1 + synapse/storage/client_ips.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4432.misc (limited to 'synapse') diff --git a/changelog.d/4432.misc b/changelog.d/4432.misc new file mode 100644 index 0000000000..047061ed3c --- /dev/null +++ b/changelog.d/4432.misc @@ -0,0 +1 @@ +Apply a unique index to the user_ips table, preventing duplicates. 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 78721a941a..b228a20ac2 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -143,6 +143,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # If it returns None, then we're processing the last batch last = end_last_seen is None + logger.info( + "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", + begin_last_seen, end_last_seen, + ) + def remove(txn): # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range @@ -170,7 +175,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): SELECT user_id, access_token, ip FROM user_ips WHERE {} - ORDER BY last_seen ) c INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip -- cgit 1.5.1 From 82a92ba535b424009ef752add2e0d5e198254e04 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 23 Jan 2019 15:01:09 +0000 Subject: Add metric for user dir current event stream position --- synapse/handlers/user_directory.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 3c40999338..120815b09b 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,6 +19,7 @@ from six import iteritems from twisted.internet import defer +import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo @@ -163,6 +164,11 @@ class UserDirectoryHandler(object): yield self._handle_deltas(deltas) self.pos = deltas[-1]["stream_id"] + + # Expose current event processing position to prometheus + synapse.metrics.event_processing_positions.labels( + "user_dir").set(self.pos) + yield self.store.update_user_directory_stream_pos(self.pos) @defer.inlineCallbacks -- cgit 1.5.1 From 97fd29c019ae92cd3dc0635de249acfc9c892340 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Jan 2019 09:34:44 +0000 Subject: Don't send IP addresses as SNI (#4452) The problem here is that we have cut-and-pasted an impl from Twisted, and then failed to maintain it. It was fixed in Twisted in https://github.com/twisted/twisted/pull/1047/files; let's do the same here. 
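The heart of the fix is the isIPAddress/isIPv6Address check added below. A
minimal standalone illustration of the rule (should_send_sni is invented
for the example):

    from twisted.internet.abstract import isIPAddress, isIPv6Address

    def should_send_sni(hostname):
        # RFC 6066 forbids literal IP addresses in the SNI server_name
        # extension, so only genuine DNS names get SNI.
        return not (isIPAddress(hostname) or isIPv6Address(hostname))

    assert should_send_sni(u"matrix.org")
    assert not should_send_sni(u"1.2.3.4")
    assert not should_send_sni(u"::1")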
--- changelog.d/4452.bugfix | 1 + synapse/crypto/context_factory.py | 15 ++++-- .../federation/test_matrix_federation_agent.py | 63 ++++++++++++++++++++-- 3 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 changelog.d/4452.bugfix (limited to 'synapse') diff --git a/changelog.d/4452.bugfix b/changelog.d/4452.bugfix new file mode 100644 index 0000000000..a715ca3788 --- /dev/null +++ b/changelog.d/4452.bugfix @@ -0,0 +1 @@ +Don't send IP addresses as SNI diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 6ba3eca7b2..286ad80100 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -17,6 +17,7 @@ from zope.interface import implementer from OpenSSL import SSL, crypto from twisted.internet._sslverify import _defaultCurveName +from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.internet.ssl import CertificateOptions, ContextFactory from twisted.python.failure import Failure @@ -98,8 +99,14 @@ class ClientTLSOptions(object): def __init__(self, hostname, ctx): self._ctx = ctx - self._hostname = hostname - self._hostnameBytes = _idnaBytes(hostname) + + if isIPAddress(hostname) or isIPv6Address(hostname): + self._hostnameBytes = hostname.encode('ascii') + self._sendSNI = False + else: + self._hostnameBytes = _idnaBytes(hostname) + self._sendSNI = True + ctx.set_info_callback( _tolerateErrors(self._identityVerifyingInfoCallback) ) @@ -111,7 +118,9 @@ class ClientTLSOptions(object): return connection def _identityVerifyingInfoCallback(self, connection, where, ret): - if where & SSL.SSL_CB_HANDSHAKE_START: + # Literal IPv4 and IPv6 addresses are not permitted + # as host names according to the RFCs + if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI: connection.set_tlsext_host_name(self._hostnameBytes) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index eb963d80fb..7a3881f558 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -46,7 +46,7 @@ class MatrixFederationAgentTests(TestCase): _srv_resolver=self.mock_resolver, ) - def _make_connection(self, client_factory): + def _make_connection(self, client_factory, expected_sni): """Builds a test server, and completes the outgoing client connection Returns: @@ -69,9 +69,17 @@ class MatrixFederationAgentTests(TestCase): # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor)) - # finally, give the reactor a pump to get the TLS juices flowing. + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) + # check the SNI + server_name = server_tls_protocol._tlsConnection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + # fish the test server back out of the server-side TLS protocol. 
return server_tls_protocol.wrappedProtocol @@ -113,7 +121,10 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory) + http_server = self._make_connection( + client_factory, + expected_sni=b"testserv", + ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -150,6 +161,52 @@ class MatrixFederationAgentTests(TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): + """ + Test the behaviour when the server name contains an explicit IP (with no port) + """ + + # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?) + self.mock_resolver.resolve_service.side_effect = lambda _: [] + + # then there will be a getaddrinfo on the IP + self.reactor.lookups["1.2.3.4"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once() + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=None, + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') + # XXX currently broken + # self.assertEqual( + # request.requestHeaders.getRawHeaders(b'host'), + # [b'1.2.3.4:8448'] + # ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + def _check_logcontext(context): current = LoggingContext.current_context() -- cgit 1.5.1 From 58f6c4818337364dd9c6bf01062e7b0dadcb8a25 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 24 Jan 2019 21:31:54 +1100 Subject: Use native UPSERTs where possible (#4306) --- .coveragerc | 6 +- .gitignore | 6 +- changelog.d/4306.misc | 1 + synapse/storage/_base.py | 148 +++++++++++++++++++++++++++++++++--- synapse/storage/client_ips.py | 5 +- synapse/storage/engines/__init__.py | 2 +- synapse/storage/engines/postgres.py | 14 ++++ synapse/storage/engines/sqlite.py | 96 +++++++++++++++++++++++ synapse/storage/engines/sqlite3.py | 87 --------------------- synapse/storage/pusher.py | 9 ++- synapse/storage/user_directory.py | 55 ++++++++++---- tests/storage/test_base.py | 1 + tests/test_server.py | 12 ++- tests/unittest.py | 12 ++- tox.ini | 1 + 15 files changed, 325 insertions(+), 130 deletions(-) create mode 100644 changelog.d/4306.misc create mode 100644 synapse/storage/engines/sqlite.py delete mode 100644 synapse/storage/engines/sqlite3.py (limited to 'synapse') diff --git a/.coveragerc b/.coveragerc index 9873a30738..e9460a340a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,11 +1,7 @@ [run] branch = True parallel = True -source = synapse - -[paths] -source= - coverage +include = synapse/* [report] precision = 2 diff --git a/.gitignore b/.gitignore index d739595c3a..1033124f1d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,9 +25,9 @@ homeserver*.pid *.tls.dh *.tls.key -.coverage -.coverage.* -!.coverage.rc +.coverage* +coverage.* +!.coveragerc htmlcov demo/*/*.db diff --git a/changelog.d/4306.misc b/changelog.d/4306.misc new 
file mode 100644 index 0000000000..58130b6190 --- /dev/null +++ b/changelog.d/4306.misc @@ -0,0 +1 @@ +Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 865b5e915a..254fdc04c6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -192,6 +192,41 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = {"user_ips"} + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later(0.0, self._check_safe_to_upsert) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait a second and check again. + """ + updates = yield self._simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + # The User IPs table in schema #53 was missing a unique index, which we + # run as a background update. + if "user_ips_device_unique_index" not in updates: + self._unsafe_to_upsert_tables.discard("user_ips") + + # If there's any tables left to check, reschedule to run. + if self._unsafe_to_upsert_tables: + self._clock.call_later(1.0, self._check_safe_to_upsert) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -494,8 +529,15 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert(self, table, keyvalues, values, - insertion_values={}, desc="_simple_upsert", lock=True): + def _simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="_simple_upsert", + lock=True + ): """ `lock` should generally be set to True (the default), but can be set @@ -516,16 +558,21 @@ class SQLBaseStore(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(bool): True if a new entry was created, False if an - existing one was updated. + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. """ attempts = 0 while True: try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock=lock + self._simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, ) defer.returnValue(result) except self.database_engine.module.IntegrityError as e: @@ -537,12 +584,59 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "IntegrityError when upserting into %s; retrying: %s", - table, e + "%s when upserting into %s; retrying: %s", e.__class__.__name__, table, e ) - def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, - lock=True): + def _simple_upsert_txn( + self, + txn, + table, + keyvalues, + values, + insertion_values={}, + lock=True, + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
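(Before the docstring's Args block continues below, a hand-expanded example of the statement the native path ends up issuing; table and column names are invented for illustration, this is not output from running the patch:

    # keyvalues = {"a": 1, "b": 2}; values = {"c": 3}; no insertion_values.
    # The generator below joins allvalues into the column and placeholder
    # lists, the keyvalues into the conflict target, and the values into
    # the EXCLUDED assignments:
    sql = (
        "INSERT INTO example (a, b, c) VALUES (?, ?, ?) "
        "ON CONFLICT (a, b) DO UPDATE SET c=EXCLUDED.c"
    )
    args = [1, 2, 3]
)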
+ + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self._simple_upsert_txn_native_upsert( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + ) + else: + return self._simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def _simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): # We need to lock the table :(, unless we're *really* careful if lock: self.database_engine.lock_table(txn, table) @@ -577,12 +671,44 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) + ", ".join("?" for _ in allvalues), ) txn.execute(sql, list(allvalues.values())) # successfully inserted return True + def _simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = ( + "INSERT INTO %s (%s) VALUES (%s) " + "ON CONFLICT (%s) DO UPDATE SET %s" + ) % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + ", ".join(k + "=EXCLUDED."
+ k for k in values), + ) + txn.execute(sql, list(allvalues.values())) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index b228a20ac2..091d7116c5 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -257,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) def _update_client_ips_batch_txn(self, txn, to_update): - self.database_engine.lock_table(txn, "user_ips") + if "user_ips" in self._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index e2f9de8451..ff5ef97ca8 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,7 +18,7 @@ import platform from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine -from .sqlite3 import Sqlite3Engine +from .sqlite import Sqlite3Engine SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 42225f8a2a..4004427c7b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -38,6 +38,13 @@ class PostgresEngine(object): return sql.replace("?", "%s") def on_new_connection(self, db_conn): + + # Get the version of PostgreSQL that we're using. As per the psycopg2 + # docs: The number is formed by converting the major, minor, and + # revision numbers into two-decimal-digit numbers and appending them + # together. For example, version 8.1.5 will be returned as 80105 + self._version = db_conn.server_version + db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -54,6 +61,13 @@ class PostgresEngine(object): cursor.close() + @property + def can_native_upsert(self): + """ + Can we use native UPSERTs? This requires PostgreSQL 9.5+. + """ + return self._version >= 90500 + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py new file mode 100644 index 0000000000..c64d73ff21 --- /dev/null +++ b/synapse/storage/engines/sqlite.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
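(The two can_native_upsert properties, PostgresEngine above and Sqlite3Engine in the new file whose imports follow below, gate the feature on engine versions. A standalone restatement of those checks; the PostgreSQL number follows psycopg2's server_version encoding described above, and postgres_can_upsert is a hypothetical helper:

    import sqlite3

    def postgres_can_upsert(server_version):
        # psycopg2 encodes 9.5.0 as 90500; INSERT ... ON CONFLICT landed in 9.5
        return server_version >= 90500

    def sqlite_can_upsert():
        # SQLite gained UPSERT syntax in 3.24.0
        return sqlite3.sqlite_version_info >= (3, 24, 0)

    assert postgres_can_upsert(100001)      # 10.1
    assert not postgres_can_upsert(90406)   # 9.4.6
)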
+ +import struct +import threading +from sqlite3 import sqlite_version_info + +from synapse.storage.prepare_database import prepare_database + + +class Sqlite3Engine(object): + single_threaded = True + + def __init__(self, database_module, database_config): + self.module = database_module + + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + + @property + def can_native_upsert(self): + """ + Do we support native UPSERTs? This requires SQLite3 3.24+, plus some + more work we haven't done yet to tell what was inserted vs updated. + """ + return sqlite_version_info >= (3, 24, 0) + + def check_database(self, txn): + pass + + def convert_param_style(self, sql): + return sql + + def on_new_connection(self, db_conn): + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) + + def is_deadlock(self, error): + return False + + def is_connection_closed(self, conn): + return False + + def lock_table(self, txn, table): + return + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + + +# Following functions taken from: https://github.com/coleifer/peewee + +def _parse_match_info(buf): + bufsize = len(buf) + return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] + + +def _rank(raw_match_info): + """Handle match_info called w/default args 'pcx' - based on the example rank + function http://sqlite.org/fts3.html#appendix_a + """ + match_info = _parse_match_info(raw_match_info) + score = 0.0 + p, c = match_info[:2] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + col_idx = phrase_info_idx + (col_num * 3) + x1, x2 = match_info[col_idx:col_idx + 2] + if x1 > 0: + score += float(x1) / x2 + return score diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py deleted file mode 100644 index 19949fc474..0000000000 --- a/synapse/storage/engines/sqlite3.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import struct -import threading - -from synapse.storage.prepare_database import prepare_database - - -class Sqlite3Engine(object): - single_threaded = True - - def __init__(self, database_module, database_config): - self.module = database_module - - # The current max state_group, or None if we haven't looked - # in the DB yet. 
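(A quick worked example of the _rank helper added in the new sqlite.py above, with invented values: under the default 'pcx' matchinfo format the blob is an array of native uint32s, [p, c] followed by three counters per (phrase, column) pair, and the score sums hits-in-this-row over hits-in-all-rows:

    import struct

    # p=1 phrase, c=1 column; the phrase hit this row 2 times and appears
    # 8 times across all rows (the third counter is ignored by _rank).
    buf = struct.pack('@5I', 1, 1, 2, 8, 3)
    match_info = [struct.unpack('@I', buf[i:i + 4])[0]
                  for i in range(0, len(buf), 4)]   # -> [1, 1, 2, 8, 3]
    x1, x2 = match_info[2:4]
    assert float(x1) / x2 == 0.25   # the score _rank(buf) would compute
)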
- self._current_state_group_id = None - self._current_state_group_id_lock = threading.Lock() - - def check_database(self, txn): - pass - - def convert_param_style(self, sql): - return sql - - def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) - db_conn.create_function("rank", 1, _rank) - - def is_deadlock(self, error): - return False - - def is_connection_closed(self, conn): - return False - - def lock_table(self, txn, table): - return - - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - - -# Following functions taken from: https://github.com/coleifer/peewee - -def _parse_match_info(buf): - bufsize = len(buf) - return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] - - -def _rank(raw_match_info): - """Handle match_info called w/default args 'pcx' - based on the example rank - function http://sqlite.org/fts3.html#appendix_a - """ - match_info = _parse_match_info(raw_match_info) - score = 0.0 - p, c = match_info[:2] - for phrase_num in range(p): - phrase_info_idx = 2 + (phrase_num * c * 3) - for col_num in range(c): - col_idx = phrase_info_idx + (col_num * 3) - x1, x2 = match_info[col_idx:col_idx + 2] - if x1 > 0: - score += float(x1) / x2 - return score diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 2743b52bad..134297e284 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so _simple_upsert will retry - newly_inserted = yield self._simple_upsert( + yield self._simple_upsert( table="pushers", keyvalues={ "app_id": app_id, @@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore): lock=False, ) - if newly_inserted: + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since the user might not have had a pusher before + yield self.runInteraction( "add_pusher", self._invalidate_cache_and_stream, diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index a8781b0e5d..ce48212265 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if new_entry: + if self.database_engine.can_native_upsert: sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('english', ?), 'A') || setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( sql, @@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore): ) ) else: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - ||
setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, user_id, + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, user_id, + ) + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" ) - ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name,) if display_name else user_id self._simple_upsert_txn( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 829f47d2e8..452d76ddd5 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection config = Mock() + config._enable_native_upserts = False config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} hs = TestHomeServer( diff --git a/tests/test_server.py b/tests/test_server.py index 634a8fbca5..08fb3fe02f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -19,7 +19,7 @@ from six import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -30,12 +30,18 @@ from synapse.util import Clock from synapse.util.logcontext import make_deferred_yieldable from tests import unittest -from tests.server import FakeTransport, make_request, render, setup_test_homeserver +from tests.server import ( + FakeTransport, + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) class JsonResourceTests(unittest.TestCase): def setUp(self): - self.reactor = MemoryReactorClock() + self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor diff --git a/tests/unittest.py b/tests/unittest.py index 78d2f740f9..cda549c783 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -96,7 +96,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING)) @around(self) def setUp(orig): @@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) - return setup_test_homeserver(self.addCleanup, *args, **kwargs) + 
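(A thread running through the pusher and user-directory hunks above, before the test-harness hunk resumes just below: once the native path is taken, _simple_upsert yields None instead of a created-vs-updated bool, so callers can no longer branch on the return value. A sketch of the resulting caller-side pattern, with hypothetical dependency-injected helpers rather than code from the patch:

    def record_pusher(simple_upsert, cache, user_id):
        # On the native path simple_upsert returns None, on the emulated
        # path a bool, so don't branch on "newly inserted".
        result = simple_upsert(
            table="pushers",
            keyvalues={"user_name": user_id, "app_id": "m.example"},
            values={"data": "{}"},
        )
        # Instead, consult the cache (as pusher.py now does) and
        # invalidate it when we can't prove it is already correct.
        if cache.get((user_id,)) is not True:
            cache.invalidate((user_id,))
        return result
)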
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not self.get_success(stor.has_completed_background_updates()): + self.get_success(stor.do_next_background_update(1)) + + return hs def pump(self, by=0.0): """ diff --git a/tox.ini b/tox.ini index a0f5486829..9b2d78ed6d 100644 --- a/tox.ini +++ b/tox.ini @@ -149,4 +149,5 @@ deps = codecov commands = coverage combine + coverage xml codecov -X gcov -- cgit 1.5.1 From 068aa1d22840a1154bb8fbdd445a8c36b290db91 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Jan 2019 12:44:27 +0000 Subject: Time out filtered room dir queries after 60s --- synapse/handlers/room_list.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index dc88620885..ea63fb604c 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -31,6 +31,7 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler +from datetime import datetime, timedelta logger = logging.getLogger(__name__) @@ -73,8 +74,13 @@ class RoomListHandler(BaseHandler): # We explicitly don't bother caching searches or requests for # appservice specific lists. logger.info("Bypassing cache as search request.") + + # XXX: Quick hack to stop room directory queries taking too long. + # Timeout request after 60s. Probably want a more fundamental + # solution at some point + timeout = datetime.now() + timedelta(seconds=60) return self._get_public_room_list( - limit, since_token, search_filter, network_tuple=network_tuple, + limit, since_token, search_filter, network_tuple=network_tuple, timeout=timeout, ) key = (limit, since_token, network_tuple) @@ -87,7 +93,8 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID,): + network_tuple=EMPTY_THIRD_PARTY_ID, + timeout=None,): if since_token and since_token != "END": since_token = RoomListNextBatch.from_token(since_token) else: @@ -202,6 +209,9 @@ class RoomListHandler(BaseHandler): chunk = [] for i in range(0, len(rooms_to_scan), step): + if timeout and datetime.now() > timeout: + raise Exception("Timed out searching room directory") + batch = rooms_to_scan[i:i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( -- cgit 1.5.1 From 5541645e80d2907721f17f648717f0b5a2b6f4fe Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Jan 2019 12:45:32 +0000 Subject: lint --- synapse/handlers/room_list.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ea63fb604c..a99b6e1460 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -80,7 +80,8 @@ class RoomListHandler(BaseHandler): # solution at some point timeout = datetime.now() + timedelta(seconds=60) return self._get_public_room_list( - limit, since_token, search_filter, network_tuple=network_tuple, timeout=timeout, + limit, since_token, search_filter, + network_tuple=network_tuple, timeout=timeout, ) key = (limit, since_token, network_tuple) -- cgit 1.5.1 From a2d85144e54457df2aae2c1a759f1baae910de91 Mon Sep 17 00:00:00 2001 From: Andrew 
Morgan Date: Thu, 24 Jan 2019 14:22:26 +0000 Subject: isort --- synapse/handlers/room_list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index a99b6e1460..5f7b33473e 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from datetime import datetime, timedelta from six import PY3, iteritems from six.moves import range @@ -31,7 +32,6 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler -from datetime import datetime, timedelta logger = logging.getLogger(__name__) -- cgit 1.5.1
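(The room-directory timeout in the three commits above boils down to a wall-clock deadline checked at the top of each batch of the scan. A minimal sketch of the pattern; the batch contents are placeholders:

    from datetime import datetime, timedelta

    def scan_rooms(batches, timeout_seconds=60):
        deadline = datetime.now() + timedelta(seconds=timeout_seconds)
        results = []
        for batch in batches:
            if datetime.now() > deadline:
                # mirrors the hack in RoomListHandler: bail out rather
                # than let a filtered directory query run unbounded
                raise Exception("Timed out searching room directory")
            results.extend(batch)   # stand-in for processing the batch
        return results

    print(scan_rooms([["!a:test", "!b:test"], ["!c:test"]]))
)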