From d619b113edf2942185a502a91cbf5b51642f6814 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 22 Jan 2019 16:52:29 +0000 Subject: Fix None guard in config.server.is_threepid_reserved --- tests/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 08d6faa0a6..df73c539c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -154,7 +154,9 @@ def default_config(name): config.update_user_directory = False def is_threepid_reserved(threepid): - return ServerConfig.is_threepid_reserved(config, threepid) + return ServerConfig.is_threepid_reserved( + config.mau_limits_reserved_threepids, threepid + ) config.is_threepid_reserved.side_effect = is_threepid_reserved -- cgit 1.4.1 From 97fd29c019ae92cd3dc0635de249acfc9c892340 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Jan 2019 09:34:44 +0000 Subject: Don't send IP addresses as SNI (#4452) The problem here is that we have cut-and-pasted an impl from Twisted, and then failed to maintain it. It was fixed in Twisted in https://github.com/twisted/twisted/pull/1047/files; let's do the same here. --- changelog.d/4452.bugfix | 1 + synapse/crypto/context_factory.py | 15 ++++-- .../federation/test_matrix_federation_agent.py | 63 ++++++++++++++++++++-- 3 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 changelog.d/4452.bugfix (limited to 'tests') diff --git a/changelog.d/4452.bugfix b/changelog.d/4452.bugfix new file mode 100644 index 0000000000..a715ca3788 --- /dev/null +++ b/changelog.d/4452.bugfix @@ -0,0 +1 @@ +Don't send IP addresses as SNI diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 6ba3eca7b2..286ad80100 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -17,6 +17,7 @@ from zope.interface import implementer from OpenSSL import SSL, crypto from twisted.internet._sslverify import _defaultCurveName +from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.internet.ssl import CertificateOptions, ContextFactory from twisted.python.failure import Failure @@ -98,8 +99,14 @@ class ClientTLSOptions(object): def __init__(self, hostname, ctx): self._ctx = ctx - self._hostname = hostname - self._hostnameBytes = _idnaBytes(hostname) + + if isIPAddress(hostname) or isIPv6Address(hostname): + self._hostnameBytes = hostname.encode('ascii') + self._sendSNI = False + else: + self._hostnameBytes = _idnaBytes(hostname) + self._sendSNI = True + ctx.set_info_callback( _tolerateErrors(self._identityVerifyingInfoCallback) ) @@ -111,7 +118,9 @@ class ClientTLSOptions(object): return connection def _identityVerifyingInfoCallback(self, connection, where, ret): - if where & SSL.SSL_CB_HANDSHAKE_START: + # Literal IPv4 and IPv6 addresses are not permitted + # as host names according to the RFCs + if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI: connection.set_tlsext_host_name(self._hostnameBytes) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index eb963d80fb..7a3881f558 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -46,7 +46,7 @@ class MatrixFederationAgentTests(TestCase): _srv_resolver=self.mock_resolver, ) - def _make_connection(self, client_factory): + def _make_connection(self, client_factory, expected_sni): """Builds a test server, and completes the outgoing client connection Returns: @@ -69,9 +69,17 @@ class MatrixFederationAgentTests(TestCase): # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor)) - # finally, give the reactor a pump to get the TLS juices flowing. + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) + # check the SNI + server_name = server_tls_protocol._tlsConnection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + # fish the test server back out of the server-side TLS protocol. return server_tls_protocol.wrappedProtocol @@ -113,7 +121,10 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory) + http_server = self._make_connection( + client_factory, + expected_sni=b"testserv", + ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -150,6 +161,52 @@ class MatrixFederationAgentTests(TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): + """ + Test the behaviour when the server name contains an explicit IP (with no port) + """ + + # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?) + self.mock_resolver.resolve_service.side_effect = lambda _: [] + + # then there will be a getaddrinfo on the IP + self.reactor.lookups["1.2.3.4"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once() + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=None, + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') + # XXX currently broken + # self.assertEqual( + # request.requestHeaders.getRawHeaders(b'host'), + # [b'1.2.3.4:8448'] + # ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + def _check_logcontext(context): current = LoggingContext.current_context() -- cgit 1.4.1 From 58f6c4818337364dd9c6bf01062e7b0dadcb8a25 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 24 Jan 2019 21:31:54 +1100 Subject: Use native UPSERTs where possible (#4306) --- .coveragerc | 6 +- .gitignore | 6 +- changelog.d/4306.misc | 1 + synapse/storage/_base.py | 148 +++++++++++++++++++++++++++++++++--- synapse/storage/client_ips.py | 5 +- synapse/storage/engines/__init__.py | 2 +- synapse/storage/engines/postgres.py | 14 ++++ synapse/storage/engines/sqlite.py | 96 +++++++++++++++++++++++ synapse/storage/engines/sqlite3.py | 87 --------------------- synapse/storage/pusher.py | 9 ++- synapse/storage/user_directory.py | 55 ++++++++++---- tests/storage/test_base.py | 1 + tests/test_server.py | 12 ++- tests/unittest.py | 12 ++- tox.ini | 1 + 15 files changed, 325 insertions(+), 130 deletions(-) create mode 100644 changelog.d/4306.misc create mode 100644 synapse/storage/engines/sqlite.py delete mode 100644 synapse/storage/engines/sqlite3.py (limited to 'tests') diff --git a/.coveragerc b/.coveragerc index 9873a30738..e9460a340a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,11 +1,7 @@ [run] branch = True parallel = True -source = synapse - -[paths] -source= - coverage +include = synapse/* [report] precision = 2 diff --git a/.gitignore b/.gitignore index d739595c3a..1033124f1d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,9 +25,9 @@ homeserver*.pid *.tls.dh *.tls.key -.coverage -.coverage.* -!.coverage.rc +.coverage* +coverage.* +!.coveragerc htmlcov demo/*/*.db diff --git a/changelog.d/4306.misc b/changelog.d/4306.misc new file mode 100644 index 0000000000..58130b6190 --- /dev/null +++ b/changelog.d/4306.misc @@ -0,0 +1 @@ +Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 865b5e915a..254fdc04c6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -192,6 +192,41 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = {"user_ips"} + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later(0.0, self._check_safe_to_upsert) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait a second and check again. + """ + updates = yield self._simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + # The User IPs table in schema #53 was missing a unique index, which we + # run as a background update. + if "user_ips_device_unique_index" not in updates: + self._unsafe_to_upsert_tables.discard("user_id") + + # If there's any tables left to check, reschedule to run. + if self._unsafe_to_upsert_tables: + self._clock.call_later(1.0, self._check_safe_to_upsert) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -494,8 +529,15 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert(self, table, keyvalues, values, - insertion_values={}, desc="_simple_upsert", lock=True): + def _simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="_simple_upsert", + lock=True + ): """ `lock` should generally be set to True (the default), but can be set @@ -516,16 +558,21 @@ class SQLBaseStore(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(bool): True if a new entry was created, False if an - existing one was updated. + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. """ attempts = 0 while True: try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock=lock + self._simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, ) defer.returnValue(result) except self.database_engine.module.IntegrityError as e: @@ -537,12 +584,59 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "IntegrityError when upserting into %s; retrying: %s", - table, e + "%s when upserting into %s; retrying: %s", e.__name__, table, e ) - def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, - lock=True): + def _simple_upsert_txn( + self, + txn, + table, + keyvalues, + values, + insertion_values={}, + lock=True, + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self._simple_upsert_txn_native_upsert( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + ) + else: + return self._simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def _simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): # We need to lock the table :(, unless we're *really* careful if lock: self.database_engine.lock_table(txn, table) @@ -577,12 +671,44 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) + ", ".join("?" for _ in allvalues), ) txn.execute(sql, list(allvalues.values())) # successfully inserted return True + def _simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = ( + "INSERT INTO %s (%s) VALUES (%s) " + "ON CONFLICT (%s) DO UPDATE SET %s" + ) % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + ", ".join(k + "=EXCLUDED." + k for k in values), + ) + txn.execute(sql, list(allvalues.values())) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index b228a20ac2..091d7116c5 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -257,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) def _update_client_ips_batch_txn(self, txn, to_update): - self.database_engine.lock_table(txn, "user_ips") + if "user_ips" in self._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index e2f9de8451..ff5ef97ca8 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,7 +18,7 @@ import platform from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine -from .sqlite3 import Sqlite3Engine +from .sqlite import Sqlite3Engine SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 42225f8a2a..4004427c7b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -38,6 +38,13 @@ class PostgresEngine(object): return sql.replace("?", "%s") def on_new_connection(self, db_conn): + + # Get the version of PostgreSQL that we're using. As per the psycopg2 + # docs: The number is formed by converting the major, minor, and + # revision numbers into two-decimal-digit numbers and appending them + # together. For example, version 8.1.5 will be returned as 80105 + self._version = db_conn.server_version + db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -54,6 +61,13 @@ class PostgresEngine(object): cursor.close() + @property + def can_native_upsert(self): + """ + Can we use native UPSERTs? This requires PostgreSQL 9.5+. + """ + return self._version >= 90500 + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py new file mode 100644 index 0000000000..c64d73ff21 --- /dev/null +++ b/synapse/storage/engines/sqlite.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct +import threading +from sqlite3 import sqlite_version_info + +from synapse.storage.prepare_database import prepare_database + + +class Sqlite3Engine(object): + single_threaded = True + + def __init__(self, database_module, database_config): + self.module = database_module + + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + + @property + def can_native_upsert(self): + """ + Do we support native UPSERTs? This requires SQLite3 3.24+, plus some + more work we haven't done yet to tell what was inserted vs updated. + """ + return sqlite_version_info >= (3, 24, 0) + + def check_database(self, txn): + pass + + def convert_param_style(self, sql): + return sql + + def on_new_connection(self, db_conn): + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) + + def is_deadlock(self, error): + return False + + def is_connection_closed(self, conn): + return False + + def lock_table(self, txn, table): + return + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + + +# Following functions taken from: https://github.com/coleifer/peewee + +def _parse_match_info(buf): + bufsize = len(buf) + return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] + + +def _rank(raw_match_info): + """Handle match_info called w/default args 'pcx' - based on the example rank + function http://sqlite.org/fts3.html#appendix_a + """ + match_info = _parse_match_info(raw_match_info) + score = 0.0 + p, c = match_info[:2] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + col_idx = phrase_info_idx + (col_num * 3) + x1, x2 = match_info[col_idx:col_idx + 2] + if x1 > 0: + score += float(x1) / x2 + return score diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py deleted file mode 100644 index 19949fc474..0000000000 --- a/synapse/storage/engines/sqlite3.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import struct -import threading - -from synapse.storage.prepare_database import prepare_database - - -class Sqlite3Engine(object): - single_threaded = True - - def __init__(self, database_module, database_config): - self.module = database_module - - # The current max state_group, or None if we haven't looked - # in the DB yet. - self._current_state_group_id = None - self._current_state_group_id_lock = threading.Lock() - - def check_database(self, txn): - pass - - def convert_param_style(self, sql): - return sql - - def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) - db_conn.create_function("rank", 1, _rank) - - def is_deadlock(self, error): - return False - - def is_connection_closed(self, conn): - return False - - def lock_table(self, txn, table): - return - - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - - -# Following functions taken from: https://github.com/coleifer/peewee - -def _parse_match_info(buf): - bufsize = len(buf) - return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] - - -def _rank(raw_match_info): - """Handle match_info called w/default args 'pcx' - based on the example rank - function http://sqlite.org/fts3.html#appendix_a - """ - match_info = _parse_match_info(raw_match_info) - score = 0.0 - p, c = match_info[:2] - for phrase_num in range(p): - phrase_info_idx = 2 + (phrase_num * c * 3) - for col_num in range(c): - col_idx = phrase_info_idx + (col_num * 3) - x1, x2 = match_info[col_idx:col_idx + 2] - if x1 > 0: - score += float(x1) / x2 - return score diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 2743b52bad..134297e284 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so _simple_upsert will retry - newly_inserted = yield self._simple_upsert( + yield self._simple_upsert( table="pushers", keyvalues={ "app_id": app_id, @@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore): lock=False, ) - if newly_inserted: + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since we the user might not have had a pusher before yield self.runInteraction( "add_pusher", self._invalidate_cache_and_stream, diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index a8781b0e5d..ce48212265 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if new_entry: + if self.database_engine.can_native_upsert: sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('english', ?), 'A') || setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( sql, @@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore): ) ) else: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, user_id, + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, user_id, + ) + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" ) - ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name,) if display_name else user_id self._simple_upsert_txn( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 829f47d2e8..452d76ddd5 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection config = Mock() + config._enable_native_upserts = False config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} hs = TestHomeServer( diff --git a/tests/test_server.py b/tests/test_server.py index 634a8fbca5..08fb3fe02f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -19,7 +19,7 @@ from six import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -30,12 +30,18 @@ from synapse.util import Clock from synapse.util.logcontext import make_deferred_yieldable from tests import unittest -from tests.server import FakeTransport, make_request, render, setup_test_homeserver +from tests.server import ( + FakeTransport, + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) class JsonResourceTests(unittest.TestCase): def setUp(self): - self.reactor = MemoryReactorClock() + self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor diff --git a/tests/unittest.py b/tests/unittest.py index 78d2f740f9..cda549c783 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -96,7 +96,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING)) @around(self) def setUp(orig): @@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) - return setup_test_homeserver(self.addCleanup, *args, **kwargs) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not self.get_success(stor.has_completed_background_updates()): + self.get_success(stor.do_next_background_update(1)) + + return hs def pump(self, by=0.0): """ diff --git a/tox.ini b/tox.ini index a0f5486829..9b2d78ed6d 100644 --- a/tox.ini +++ b/tox.ini @@ -149,4 +149,5 @@ deps = codecov commands = coverage combine + coverage xml codecov -X gcov -- cgit 1.4.1 From e1c8440e0cd6d25d09cee71a56067b658eca97ee Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 24 Jan 2019 13:28:07 +0000 Subject: lots more tests for MatrixFederationAgent --- .../federation/test_matrix_federation_agent.py | 89 +++++++++++++++++++--- 1 file changed, 79 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 7a3881f558..bfae69a978 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.federation.srv_resolver import Server from synapse.util.logcontext import LoggingContext from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -105,7 +106,7 @@ class MatrixFederationAgentTests(TestCase): def test_get(self): """ - happy-path test of a GET request + happy-path test of a GET request with an explicit port """ self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") @@ -130,10 +131,6 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv:8448'] - ) content = request.content.read() self.assertEqual(content, b'') @@ -196,11 +193,83 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - # XXX currently broken - # self.assertEqual( - # request.requestHeaders.getRawHeaders(b'host'), - # [b'1.2.3.4:8448'] - # ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_get_hostname_no_srv(self): + """ + Test the behaviour when the server name has no port, and no SRV record + """ + + self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.reactor.lookups["testserv"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once() + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=b'testserv', + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_get_hostname_srv(self): + """ + Test the behaviour when there is a single SRV record + """ + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host="srvtarget", port=8443) + ] + self.reactor.lookups["srvtarget"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once() + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=b'testserv', + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') # finish the request request.finish() -- cgit 1.4.1 From afd69a0920d16bdd9ca0c5cf9238e48986424ecb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 24 Jan 2019 13:29:33 +0000 Subject: Look up the right SRV record --- changelog.d/4464.misc | 1 + synapse/http/federation/matrix_federation_agent.py | 3 ++- tests/http/federation/test_matrix_federation_agent.py | 12 +++++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 changelog.d/4464.misc (limited to 'tests') diff --git a/changelog.d/4464.misc b/changelog.d/4464.misc new file mode 100644 index 0000000000..9a51434755 --- /dev/null +++ b/changelog.d/4464.misc @@ -0,0 +1 @@ +Move SRV logic into the Agent layer diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 64c780a341..0ec28c6696 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -101,7 +101,8 @@ class MatrixFederationAgent(object): if port is not None: target = (host, port) else: - server_list = yield self._srv_resolver.resolve_service(server_name_bytes) + service_name = b"_matrix._tcp.%s" % (server_name_bytes, ) + server_list = yield self._srv_resolver.resolve_service(service_name) if not server_list: target = (host, 8448) logger.debug("No SRV record for %s, using %s", host, target) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index bfae69a978..b32d7566a5 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -174,7 +174,9 @@ class MatrixFederationAgentTests(TestCase): # Nothing happened yet self.assertNoResult(test_d) - self.mock_resolver.resolve_service.assert_called_once() + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.1.2.3.4", + ) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -212,7 +214,9 @@ class MatrixFederationAgentTests(TestCase): # Nothing happened yet self.assertNoResult(test_d) - self.mock_resolver.resolve_service.assert_called_once() + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv", + ) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -251,7 +255,9 @@ class MatrixFederationAgentTests(TestCase): # Nothing happened yet self.assertNoResult(test_d) - self.mock_resolver.resolve_service.assert_called_once() + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv", + ) # Make sure treq is trying to connect clients = self.reactor.tcpClients -- cgit 1.4.1