From 91ea0202e6f4a519e332a6c456aedfe4b7d627c9 Mon Sep 17 00:00:00 2001 From: Krombel Date: Wed, 14 Mar 2018 16:45:37 +0100 Subject: move handling of auto_join_rooms to RegisterHandler Currently the handling of auto_join_rooms only works when a user registers itself via public register api. Registrations via registration_shared_secret and ModuleApi do not work This auto_joins the users in the registration handler which enables the auto join feature for all 3 registration paths. This is related to issue #2725 Signed-Off-by: Matthias Kesler --- tests/rest/client/v1/test_events.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 2b89c0a3c7..a8d09600bd 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -123,6 +123,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.ratelimiter.send_message.return_value = (True, 0) hs.config.enable_registration_captcha = False hs.config.enable_registration = True + hs.config.auto_join_rooms = [] hs.get_handlers().federation_handler = Mock() -- cgit 1.4.1 From 72251d1b979db0bc96e5d95ac70b8e1cd78cde7c Mon Sep 17 00:00:00 2001 From: Silke Date: Tue, 20 Mar 2018 10:40:16 +0100 Subject: Remove address resolution of hosts in SRV records Signed-off-by: Silke Hofstra --- synapse/http/endpoint.py | 103 ++++------------------------------------------- tests/test_dns.py | 29 +------------ 2 files changed, 10 insertions(+), 122 deletions(-) (limited to 'tests') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 87639b9151..00572c2897 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import socket - from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor from twisted.internet.error import ConnectError @@ -33,7 +31,7 @@ SERVER_CACHE = {} # our record of an individual server which can be tried to reach a destination. # -# "host" is actually a dotted-quad or ipv6 address string. Except when there's +# "host" is the hostname acquired from the SRV record. Except when there's # no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" @@ -297,20 +295,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t payload = answer.payload - hosts = yield _get_hosts_for_srv_record( - dns_client, str(payload.target) - ) - - for (ip, ttl) in hosts: - host_ttl = min(answer.ttl, ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) @@ -328,81 +319,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t raise e defer.returnValue(servers) - - -@defer.inlineCallbacks -def _get_hosts_for_srv_record(dns_client, host): - """Look up each of the hosts in a SRV record - - Args: - dns_client (twisted.names.dns.IResolver): - host (basestring): host to look up - - Returns: - Deferred[list[(str, int)]]: a list of (host, ttl) pairs - - """ - ip4_servers = [] - ip6_servers = [] - - def cb(res): - # lookupAddress and lookupIP6Address return a three-tuple - # giving the answer, authority, and additional sections of the - # response. - # - # we only care about the answers. - - return res[0] - - def eb(res, record_type): - if res.check(DNSNameError): - return [] - logger.warn("Error looking up %s for %s: %s", record_type, host, res) - return res - - # no logcontexts here, so we can safely fire these off and gatherResults - d1 = dns_client.lookupAddress(host).addCallbacks( - cb, eb, errbackArgs=("A", )) - d2 = dns_client.lookupIPV6Address(host).addCallbacks( - cb, eb, errbackArgs=("AAAA", )) - results = yield defer.DeferredList( - [d1, d2], consumeErrors=True) - - # if all of the lookups failed, raise an exception rather than blowing out - # the cache with an empty result. - if results and all(s == defer.FAILURE for (s, _) in results): - defer.returnValue(results[0][1]) - - for (success, result) in results: - if success == defer.FAILURE: - continue - - for answer in result: - if not answer.payload: - continue - - try: - if answer.type == dns.A: - ip = answer.payload.dottedQuad() - ip4_servers.append((ip, answer.ttl)) - elif answer.type == dns.AAAA: - ip = socket.inet_ntop( - socket.AF_INET6, answer.payload.address, - ) - ip6_servers.append((ip, answer.ttl)) - else: - # the most likely candidate here is a CNAME record. - # rfc2782 says srvs may not point to aliases. - logger.warn( - "Ignoring unexpected DNS record type %s for %s", - answer.type, host, - ) - continue - except Exception as e: - logger.warn("Ignoring invalid DNS response for %s: %s", - host, e) - continue - - # keep the ipv4 results before the ipv6 results, mostly to match historical - # behaviour. - defer.returnValue(ip4_servers + ip6_servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index d08b0f4333..af607d626f 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -33,8 +33,6 @@ class DnsTestCase(unittest.TestCase): service_name = "test_service.example.com" host_name = "example.com" - ip_address = "127.0.0.1" - ip6_address = "::1" answer_srv = dns.RRHeader( type=dns.SRV, @@ -43,29 +41,9 @@ class DnsTestCase(unittest.TestCase): ) ) - answer_a = dns.RRHeader( - type=dns.A, - payload=dns.Record_A( - address=ip_address, - ) - ) - - answer_aaaa = dns.RRHeader( - type=dns.AAAA, - payload=dns.Record_AAAA( - address=ip6_address, - ) - ) - dns_client_mock.lookupService.return_value = defer.succeed( ([answer_srv], None, None), ) - dns_client_mock.lookupAddress.return_value = defer.succeed( - ([answer_a], None, None), - ) - dns_client_mock.lookupIPV6Address.return_value = defer.succeed( - ([answer_aaaa], None, None), - ) cache = {} @@ -74,13 +52,10 @@ class DnsTestCase(unittest.TestCase): ) dns_client_mock.lookupService.assert_called_once_with(service_name) - dns_client_mock.lookupAddress.assert_called_once_with(host_name) - dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name) - self.assertEquals(len(servers), 2) + self.assertEquals(len(servers), 1) self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, ip_address) - self.assertEquals(servers[1].host, ip6_address) + self.assertEquals(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): -- cgit 1.4.1 From 616835187702a0c6f16042e3efb452e1ee3e7826 Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Tue, 3 Apr 2018 20:41:21 +0200 Subject: Add b prefixes to some strings that are bytes in py3 This has no effect on python2 Signed-off-by: Adrian Tschira --- synapse/api/auth.py | 10 +++++----- synapse/app/frontend_proxy.py | 2 +- synapse/http/server.py | 4 ++-- synapse/http/site.py | 6 +++--- synapse/rest/client/v1/register.py | 4 ++-- tests/utils.py | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ac0a3655a5..f17fda6315 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -204,8 +204,8 @@ class Auth(object): ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - "User-Agent", - default=[""] + b"User-Agent", + default=[b""] )[0] if user and access_token and ip_addr: self.store.insert_client_ip( @@ -672,7 +672,7 @@ def has_access_token(request): bool: False if no access_token was given, True otherwise. """ query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders("Authorization") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -692,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401): AuthError: If there isn't an access_token in the request. """ - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") if auth_headers: # Try the get the access_token from a "Authorization: Bearer" # header diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index de889357c3..b349e3e3ce 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -90,7 +90,7 @@ class KeyUploadServlet(RestServlet): # They're actually trying to upload something, proxy to main synapse. # Pass through the auth headers, if any, in case the access token # is there. - auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) headers = { "Authorization": auth_headers, } diff --git a/synapse/http/server.py b/synapse/http/server.py index f19c068ef6..d979e76639 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -324,7 +324,7 @@ class JsonResource(HttpServer, resource.Resource): register_paths, so will return (possibly via Deferred) either None, or a tuple of (http code, response body). """ - if request.method == "OPTIONS": + if request.method == b"OPTIONS": return _options_handler, {} # Loop through all the registered callbacks to check if the method @@ -536,7 +536,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( - "User-Agent", default=[] + b"User-Agent", default=[] ) for user_agent in user_agents: if "curl" in user_agent: diff --git a/synapse/http/site.py b/synapse/http/site.py index e422c8dfae..c8b46e1af2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ import logging import re import time -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') class SynapseRequest(Request): @@ -43,12 +43,12 @@ class SynapseRequest(Request): def get_redacted_uri(self): return ACCESS_TOKEN_RE.sub( - r'\1\3', + br'\1\3', self.uri ) def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def started_processing(self): self.site.access_logger.info( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 5c5fa8f7ab..8a82097178 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -348,9 +348,9 @@ class RegisterRestServlet(ClientV1RestServlet): admin = register_json.get("admin", None) # Its important to check as we use null bytes as HMAC field separators - if "\x00" in user: + if b"\x00" in user: raise SynapseError(400, "Invalid user") - if "\x00" in password: + if b"\x00" in password: raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not diff --git a/tests/utils.py b/tests/utils.py index 8efd3a3475..f15317d27b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -212,7 +212,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it -- cgit 1.4.1 From 01afc563c39006c21bb7752831cd62c146edc135 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 5 Apr 2018 16:24:04 +0100 Subject: Fix overzealous cache invalidation Fixes an issue where a cache invalidation would invalidate *all* pending entries, rather than just the entry that we intended to invalidate. --- synapse/util/caches/descriptors.py | 64 +++++++++++++++++++++-------------- tests/util/caches/test_descriptors.py | 46 +++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 26 deletions(-) (limited to 'tests') diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index bf3a66eae4..68285a7594 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,12 +40,11 @@ _CacheSentinel = object() class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -62,7 +62,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -80,7 +79,6 @@ class Cache(object): self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None self.metrics = register_cache(name, self.cache) @@ -113,11 +111,10 @@ class Cache(object): callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: @@ -137,12 +134,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -150,13 +144,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, result, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -168,25 +174,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -194,8 +204,10 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in self._pending_deferred_cache.itervalues(): + entry.invalidate() + self._pending_deferred_cache.clear() class _CacheDescriptorBase(object): diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3f14ab503f..2516fe40f4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import logging import mock @@ -25,6 +27,50 @@ from tests import unittest logger = logging.getLogger(__name__) +class CacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = descriptors.Cache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return the deferreds + self.assertIs(cache.get("key1"), d1) + self.assertIs(cache.get("key2"), d2) + + # let one of the lookups complete + d2.callback("result2") + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue( + callback_record[0], "Invalidation callback for key1 not called", + ) + self.assertTrue( + callback_record[1], "Invalidation callback for key2 not called", + ) + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) + + class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.4.1 From e939f3bca6d73bc4ec6450aeb41f5e121a29a662 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Apr 2018 14:37:11 +0100 Subject: Fix tests --- tests/handlers/test_appservice.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index a667fb6f0e..b753455943 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,6 +31,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore = Mock(return_value=self.mock_store) + self.mock_store.get_received_ts.return_value = 0 hs.get_application_service_api = Mock(return_value=self.mock_as_api) hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) hs.get_clock.return_value = MockClock() -- cgit 1.4.1 From 1515560f5ce8054a539dd0fed9e88232883f079f Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 17:20:37 +0200 Subject: Use str(e) instead of e.message Doing this I learned e.message was pretty shortlived, added in 2.6, they realized it was a bad idea and deprecated it in 2.7 Signed-off-by: Adrian Tschira --- synapse/crypto/keyring.py | 8 ++++---- tests/storage/test_appservice.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 35f810b07b..fce83d445f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -352,7 +352,7 @@ class Keyring(object): logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) defer.returnValue({}) @@ -384,7 +384,7 @@ class Keyring(object): logger.info( "Unable to get key %r for %r directly: %s %s", key_ids, server_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) if not keys: @@ -734,7 +734,7 @@ def _handle_key_deferred(verify_request): except IOError as e: logger.warn( "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 502, @@ -744,7 +744,7 @@ def _handle_key_deferred(verify_request): except Exception as e: logger.exception( "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 401, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c2e39a7288..00825498b1 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -480,9 +480,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("id", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("id", str(e)) @defer.inlineCallbacks def test_duplicate_as_tokens(self): @@ -504,6 +504,6 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("as_token", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("as_token", str(e)) -- cgit 1.4.1 From cb9cdfecd097c5a0c85840be9d37d590f0f4af2d Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 17:18:01 +0200 Subject: Add some more variables to the unittest config These worked accidentally before (python2 doesn't complain if you compare incompatible types) but under py3 this blows up spectacularly Signed-off-by: Adrian Tschira --- tests/utils.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index f15317d27b..0cd9f7eeee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -59,6 +59,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.email_enable_notifs = False config.block_non_admin_invites = False config.federation_domain_whitelist = None + config.federation_rc_reject_limit = 10 + config.federation_rc_sleep_limit = 10 + config.federation_rc_concurrent = 10 + config.filter_timeline_limit = 5000 config.user_directory_search_all_users = False # disable user directory updates, because they get done in the -- cgit 1.4.1 From 2a3c33ff03aa88317c30da43cd3773c2789f0fcf Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 17:15:16 +0200 Subject: Use six.moves.urlparse The imports were shuffled around a bunch in py3 Signed-off-by: Adrian Tschira --- synapse/config/appservice.py | 4 ++-- synapse/http/matrixfederationclient.py | 3 +-- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v1/room.py | 9 +++++---- synapse/rest/media/v1/_base.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- tests/rest/client/v1/test_rooms.py | 14 +++++++------- tests/utils.py | 5 ++--- 8 files changed, 20 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 9a2359b6fd..277305e184 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -17,11 +17,11 @@ from ._base import Config, ConfigError from synapse.appservice import ApplicationService from synapse.types import UserID -import urllib import yaml import logging from six import string_types +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def _load_appservice(hostname, as_info, config_filename): ) localpart = as_info["sender_localpart"] - if urllib.quote(localpart) != localpart: + if urlparse.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 60a29081e8..c2e5610f50 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -38,8 +38,7 @@ import logging import random import sys import urllib -import urlparse - +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) outbound_logger = logging.getLogger("synapse.http.outbound") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 45844aa2d2..34df5be4e9 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -25,7 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib -import urlparse +from six.moves.urllib import parse as urlparse import logging from saml2 import BINDING_HTTP_POST diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2ad0e5943b..fcf9c9ab44 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -28,8 +28,9 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, parse_integer ) +from six.moves.urllib import parse as urlparse + import logging -import urllib import simplejson as json logger = logging.getLogger(__name__) @@ -433,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args filter_bytes = request.args.get("filter", None) if filter_bytes: - filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -718,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) - room_id = urllib.unquote(room_id) - target_user = UserID.from_string(urllib.unquote(user_id)) + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index e7ac01da01..d9c4af9389 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -28,7 +28,7 @@ import os import logging import urllib -import urlparse +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bb79599379..9800ce7581 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -47,7 +47,7 @@ import shutil import cgi import logging -import urlparse +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7e8966a1a8..d763400eaf 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -24,7 +24,7 @@ from synapse.api.constants import Membership from synapse.types import UserID import json -import urllib +from six.moves.urllib import parse as urlparse from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase @@ -766,7 +766,7 @@ class RoomMemberStateTestCase(RestTestCase): @defer.inlineCallbacks def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.user_id + urlparse.quote(self.room_id), self.user_id ) # valid join message (NOOP since we made the room) @@ -786,7 +786,7 @@ class RoomMemberStateTestCase(RestTestCase): def test_rooms_members_other(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), self.other_id ) # valid invite message @@ -802,7 +802,7 @@ class RoomMemberStateTestCase(RestTestCase): def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), self.other_id ) # valid invite message with custom key @@ -859,7 +859,7 @@ class RoomMessagesTestCase(RestTestCase): @defer.inlineCallbacks def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) # missing keys or invalid json (code, response) = yield self.mock_resource.trigger( "PUT", path, '{}' @@ -894,7 +894,7 @@ class RoomMessagesTestCase(RestTestCase): @defer.inlineCallbacks def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) content = '{"body":"test","msgtype":{"type":"a"}}' (code, response) = yield self.mock_resource.trigger("PUT", path, content) @@ -911,7 +911,7 @@ class RoomMessagesTestCase(RestTestCase): # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) content = '{"body":"test2","msgtype":"m.text"}' (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(200, code, msg=str(response)) diff --git a/tests/utils.py b/tests/utils.py index f15317d27b..9e815d8643 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,8 +15,7 @@ import hashlib from inspect import getcallargs -import urllib -import urlparse +from six.moves.urllib import parse as urlparse from mock import Mock, patch from twisted.internet import defer, reactor @@ -234,7 +233,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urllib.unquote(u).decode("UTF-8") + urlparse.unquote(u).decode("UTF-8") for u in matcher.groups() ] -- cgit 1.4.1 From a1a3c9660ffc1e99f656c1d39ebe57b8857b207d Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 21:46:23 +0200 Subject: Make tests py3 compatible This is a mixed commit that fixes various small issues * print parentheses * 01 is invalid syntax (it was octal in py2) * [x for i in 1, 2] is invalid syntax * six moves Signed-off-by: Adrian Tschira --- tests/config/test_load.py | 2 +- tests/crypto/test_keyring.py | 2 +- tests/test_distributor.py | 2 +- tests/util/test_file_consumer.py | 2 +- tests/util/test_linearizer.py | 3 ++- 5 files changed, 6 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 161a87d7e3..772afd2cf9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -24,7 +24,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print self.dir + print(self.dir) self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d4ec02ffc2..149e443022 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(01) + yield async.sleep(1) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index acebcf4a86..010aeaee7e 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -62,7 +62,7 @@ class DistributorTestCase(unittest.TestCase): def test_signal_catch(self): self.dist.declare("alarm") - observers = [Mock() for i in 1, 2] + observers = [Mock() for i in (1, 2)] for o in observers: self.dist.observe("alarm", o) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 76e2234255..d6e1082779 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -20,7 +20,7 @@ from mock import NonCallableMock from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest -from StringIO import StringIO +from six import StringIO import threading diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 793a88e462..4865eb4bc6 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -18,6 +18,7 @@ from tests import unittest from twisted.internet import defer from synapse.util.async import Linearizer +from six.moves import range class LinearizerTestCase(unittest.TestCase): @@ -58,7 +59,7 @@ class LinearizerTestCase(unittest.TestCase): logcontext.LoggingContext.current_context(), lc) func(0, sleep=True) - for i in xrange(1, 100): + for i in range(1, 100): func(i) return func(1000) -- cgit 1.4.1 From 639480e14a06723adf6817ddd2b2ff9e4f4cdf2a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 Apr 2018 18:41:37 +0100 Subject: Avoid creating events with huge numbers of prev_events In most cases, we limit the number of prev_events for a given event to 10 events. This fixes a particular code path which created events with huge numbers of prev_events. --- synapse/handlers/message.py | 78 +++++++++++++++++++--------------- synapse/handlers/room_member.py | 13 ++++-- synapse/storage/event_federation.py | 57 ++++++++++++++++++------- tests/storage/test_event_federation.py | 68 +++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 54 deletions(-) create mode 100644 tests/storage/test_event_federation.py (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54cd691f91..21628a8540 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -37,7 +37,6 @@ from ._base import BaseHandler from canonicaljson import encode_canonical_json import logging -import random import simplejson logger = logging.getLogger(__name__) @@ -433,7 +432,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_event_ids=None): + prev_events_and_hashes=None): """ Given a dict from a client, create a new event. @@ -447,7 +446,13 @@ class EventCreationHandler(object): event_dict (dict): An entire event token_id (str) txn_id (str) - prev_event_ids (list): The prev event ids to use when creating the event + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. Returns: Tuple of created event (FrozenEvent), Context @@ -485,7 +490,7 @@ class EventCreationHandler(object): event, context = yield self.create_new_client_event( builder=builder, requester=requester, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) defer.returnValue((event, context)) @@ -588,39 +593,44 @@ class EventCreationHandler(object): @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, prev_event_ids=None): - if prev_event_ids: - prev_events = yield self.store.add_event_hashes(prev_event_ids) - prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) - depth = prev_max_depth + 1 - else: - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, + def create_new_client_event(self, builder, requester=None, + prev_events_and_hashes=None): + """Create a new event for a local client + + Args: + builder (EventBuilder): + + requester (synapse.types.Requester|None): + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. + + Returns: + Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + """ + + if prev_events_and_hashes is not None: + assert len(prev_events_and_hashes) <= 10, \ + "Attempting to create an event with %i prev_events" % ( + len(prev_events_and_hashes), ) + else: + prev_events_and_hashes = \ + yield self.store.get_prev_events_for_room(builder.room_id) - # We want to limit the max number of prev events we point to in our - # new event - if len(latest_ret) > 10: - # Sort by reverse depth, so we point to the most recent. - latest_ret.sort(key=lambda a: -a[2]) - new_latest_ret = latest_ret[:5] - - # We also randomly point to some of the older events, to make - # sure that we don't completely ignore the older events. - if latest_ret[5:]: - sample_size = min(5, len(latest_ret[5:])) - new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) - latest_ret = new_latest_ret - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 + if prev_events_and_hashes: + depth = max([d for _, _, d in prev_events_and_hashes]) + 1 + else: + depth = 1 - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in prev_events_and_hashes + ] builder.prev_events = prev_events builder.depth = depth diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c45142d38d..714583f1d5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -149,7 +149,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, - prev_event_ids, + prev_events_and_hashes, txn_id=None, ratelimit=True, content=None, @@ -175,7 +175,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) # Check if this event matches the previous membership event for the user. @@ -314,7 +314,12 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", ) - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -403,7 +408,7 @@ class RoomMemberHandler(object): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_event_ids=latest_event_ids, + prev_events_and_hashes=prev_events_and_hashes, content=content, ) defer.returnValue(res) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 00ee82d300..a183fc6b50 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from twisted.internet import defer @@ -133,7 +134,47 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, retcol="event_id", ) + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -182,22 +223,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py new file mode 100644 index 0000000000..30683e7888 --- /dev/null +++ b/tests/storage/test_event_federation.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils + + +class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_prev_events_for_room(self): + room_id = '@ROOM:local' + + # add a bunch of events and hashes to act as forward extremities + def insert_event(txn, i): + event_id = '$event_%i:local' % i + + txn.execute(( + "INSERT INTO events (" + " room_id, event_id, type, depth, topological_ordering," + " content, processed, outlier) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + ), (room_id, event_id, i, i, True, False)) + + txn.execute(( + 'INSERT INTO event_forward_extremities (room_id, event_id) ' + 'VALUES (?, ?)' + ), (room_id, event_id)) + + txn.execute(( + 'INSERT INTO event_reference_hashes ' + '(event_id, algorithm, hash) ' + "VALUES (?, 'sha256', ?)" + ), (event_id, 'ffff')) + + for i in range(0, 11): + yield self.store.runInteraction("insert", insert_event, i) + + # this should get the last five and five others + r = yield self.store.get_prev_events_for_room(room_id) + self.assertEqual(10, len(r)) + for i in range(0, 5): + el = r[i] + depth = el[2] + self.assertEqual(10 - i, depth) + + for i in range(5, 5): + el = r[i] + depth = el[2] + self.assertLessEqual(5, depth) -- cgit 1.4.1 From 1ea904b9f09ea6a9df7792f61029d0250602d46b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 23 Apr 2018 00:53:18 +0100 Subject: Use deferred.addTimeout instead of time_bound_deferred This doesn't feel like a wheel we need to reinvent. --- synapse/http/__init__.py | 22 +++++++++++++ synapse/http/client.py | 20 +++++------- synapse/http/matrixfederationclient.py | 35 ++++++++++----------- synapse/notifier.py | 23 +++++++------- synapse/util/__init__.py | 56 ---------------------------------- tests/util/test_clock.py | 33 -------------------- 6 files changed, 59 insertions(+), 130 deletions(-) delete mode 100644 tests/util/test_clock.py (limited to 'tests') diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index bfebb0f644..20e568bc43 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.api.errors import SynapseError + + +class RequestTimedOutError(SynapseError): + """Exception representing timeout of an outbound request""" + def __init__(self): + super(RequestTimedOutError, self).__init__(504, "Timed out") + + +def cancelled_to_request_timed_out_error(value): + """Turns CancelledErrors into RequestTimedOutErrors. + + For use with deferred.addTimeout() + """ + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise RequestTimedOutError() + return value diff --git a/synapse/http/client.py b/synapse/http/client.py index f3e4973c2e..35c8d51e71 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +19,9 @@ from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) +from synapse.http import cancelled_to_request_timed_out_error from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable -from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -95,21 +96,16 @@ class SimpleHttpClient(object): # counters to it outgoing_requests_counter.inc(method) - def send_request(): + logger.info("Sending request %s %s", method, uri) + + try: request_deferred = self.agent.request( method, uri, *args, **kwargs ) - - return self.clock.time_bound_deferred( - request_deferred, - time_out=60, + request_deferred.addTimeout( + 60, reactor, cancelled_to_request_timed_out_error, ) - - logger.info("Sending request %s %s", method, uri) - - try: - with logcontext.PreserveLoggingContext(): - response = yield send_request() + response = yield make_deferred_yieldable(request_deferred) incoming_responses_counter.inc(method, response.code) logger.info( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 60a29081e8..fe4b1636a6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.retryutils from twisted.internet import defer, reactor, protocol from twisted.internet.error import DNSLookupError from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone +from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint +import synapse.metrics from synapse.util.async import sleep from synapse.util import logcontext -import synapse.metrics +from synapse.util.logcontext import make_deferred_yieldable +import synapse.util.retryutils from canonicaljson import encode_canonical_json @@ -184,21 +187,19 @@ class MatrixFederationHttpClient(object): producer = body_callback(method, http_url_bytes, headers_dict) try: - def send_request(): - request_deferred = self.agent.request( - method, - url_bytes, - Headers(headers_dict), - producer - ) - - return self.clock.time_bound_deferred( - request_deferred, - time_out=timeout / 1000. if timeout else 60, - ) - - with logcontext.PreserveLoggingContext(): - response = yield send_request() + request_deferred = self.agent.request( + method, + url_bytes, + Headers(headers_dict), + producer + ) + request_deferred.addTimeout( + timeout / 1000. if timeout else 60, + reactor, cancelled_to_request_timed_out_error, + ) + response = yield make_deferred_yieldable( + request_deferred, + ) log_result = "%d %s" % (response.code, response.phrase,) break diff --git a/synapse/notifier.py b/synapse/notifier.py index 0e40a4aad6..1e4f78b993 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor +from twisted.internet.defer import TimeoutError + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state -from synapse.util import DeferredTimedOutError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred from synapse.util.logcontext import PreserveLoggingContext, preserve_fn @@ -331,11 +332,11 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) + listener.deferred.addTimeout( + (end_time - now) / 1000., reactor, + ) with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) + yield listener.deferred current_token = user_stream.current_token @@ -346,7 +347,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimedOutError: + except TimeoutError: break except defer.CancelledError: break @@ -551,13 +552,11 @@ class Notifier(object): if end_time <= now: break + listener.deferred.addTimeout((end_time - now) / 1000., reactor) try: with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) - except DeferredTimedOutError: + yield listener.deferred + except TimeoutError: break except defer.CancelledError: break diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 756d8ffa32..814a7bf71b 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -24,11 +23,6 @@ import logging logger = logging.getLogger(__name__) -class DeferredTimedOutError(SynapseError): - def __init__(self): - super(DeferredTimedOutError, self).__init__(504, "Timed out") - - def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -85,53 +79,3 @@ class Clock(object): except Exception: if not ignore_errs: raise - - def time_bound_deferred(self, given_deferred, time_out): - if given_deferred.called: - return given_deferred - - ret_deferred = defer.Deferred() - - def timed_out_fn(): - e = DeferredTimedOutError() - - try: - ret_deferred.errback(e) - except Exception: - pass - - try: - given_deferred.cancel() - except Exception: - pass - - timer = None - - def cancel(res): - try: - self.cancel_call_later(timer) - except Exception: - pass - return res - - ret_deferred.addBoth(cancel) - - def success(res): - try: - ret_deferred.callback(res) - except Exception: - pass - - return res - - def err(res): - try: - ret_deferred.errback(res) - except Exception: - pass - - given_deferred.addCallbacks(callback=success, errback=err) - - timer = self.call_later(time_out, timed_out_fn) - - return ret_deferred diff --git a/tests/util/test_clock.py b/tests/util/test_clock.py deleted file mode 100644 index 9672603579..0000000000 --- a/tests/util/test_clock.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse import util -from twisted.internet import defer -from tests import unittest - - -class ClockTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_time_bound_deferred(self): - # just a deferred which never resolves - slow_deferred = defer.Deferred() - - clock = util.Clock() - time_bound = clock.time_bound_deferred(slow_deferred, 0.001) - - try: - yield time_bound - self.fail("Expected timedout error, but got nothing") - except util.DeferredTimedOutError: - pass -- cgit 1.4.1 From 6495dbb326dd2b5d58e5de25107f7fe6d13b6ca4 Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Mon, 30 Apr 2018 21:58:30 +0200 Subject: Burminate v1auth This closes #2602 v1auth was created to account for the differences in status code between the v1 and v2_alpha revisions of the protocol (401 vs 403 for invalid tokens). However since those protocols were merged, this makes the r0 version/endpoint internally inconsistent, and violates the specification for the r0 endpoint. This might break clients that rely on this inconsistency with the specification. This is said to affect the legacy angular reference client. However, I feel that restoring parity with the spec is more important. Either way, it is critical to inform developers about this change, in case they rely on the illegal behaviour. Signed-off-by: Adrian Tschira --- synapse/rest/client/v1/base.py | 6 +++++- synapse/rest/client/v1/pusher.py | 2 +- synapse/server.py | 10 ---------- tests/rest/client/v1/test_events.py | 9 +++++++-- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_rooms.py | 18 +++++++++--------- tests/rest/client/v1/test_typing.py | 2 +- 7 files changed, 24 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index c7aa0bbf59..197335d7aa 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """A base Synapse REST Servlet for the client version 1 API. """ + # This subclass was presumably created to allow the auth for the v1 + # protocol version to be different, however this behaviour was removed. + # it may no longer be necessary + def __init__(self, hs): """ Args: @@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet): """ self.hs = hs self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs.get_clock()) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 1819a560cb..0206e664c1 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet): super(RestServlet, self).__init__() self.hs = hs self.notifier = hs.get_notifier() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks diff --git a/synapse/server.py b/synapse/server.py index cd0c1a51be..ebdea6b0c4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -105,7 +105,6 @@ class HomeServer(object): 'federation_client', 'federation_server', 'handlers', - 'v1auth', 'auth', 'state_handler', 'state_resolution_handler', @@ -225,15 +224,6 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) - def build_v1auth(self): - orf = Auth(self) - # Matrix spec makes no reference to what HTTP status code is returned, - # but the V1 API uses 403 where it means 401, and the webclient - # relies on this behaviour, so V1 gets its own copy of the auth - # with backwards compat behaviour. - orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 - return orf - def build_state_handler(self): return StateHandler(self) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index a8d09600bd..f5a7258e68 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -148,11 +148,16 @@ class EventStreamPermissionsTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_basic_permissions(self): - # invalid token, expect 403 + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 (code, response) = yield self.mock_resource.trigger_get( "/events?access_token=%s" % ("invalid" + self.token, ) ) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(401, code, msg=str(response)) # valid token, expect content (code, response) = yield self.mock_resource.trigger_get( diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index deac7f100c..dc94b8bd19 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,7 +52,7 @@ class ProfileTestCase(unittest.TestCase): def _get_user_by_req(request=None, allow_guest=False): return synapse.types.create_requester(myid) - hs.get_v1auth().get_user_by_req = _get_user_by_req + hs.get_auth().get_user_by_req = _get_user_by_req profile.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d763400eaf..61d737725b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -60,7 +60,7 @@ class RoomPermissionsTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -70,7 +70,7 @@ class RoomPermissionsTestCase(RestTestCase): synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -425,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -507,7 +507,7 @@ class RoomsCreateTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -597,7 +597,7 @@ class RoomTopicTestCase(RestTestCase): "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -711,7 +711,7 @@ class RoomMemberStateTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -945,7 +945,7 @@ class RoomInitialSyncTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -1017,7 +1017,7 @@ class RoomMessageListTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 2ec4ecab5b..fe161ee5cb 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -68,7 +68,7 @@ class RoomTypingTestCase(RestTestCase): "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) -- cgit 1.4.1 From e482f8cd8504b36dc1ce2c1e51e0dee479d33249 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 May 2018 09:12:26 +0100 Subject: Fix incorrect reference to StringIO This was introduced in 4f2f5171 --- synapse/util/logformatter.py | 2 +- tests/util/test_logformatter.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/util/test_logformatter.py (limited to 'tests') diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index 59ab3c6968..3e42868ea9 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter): super(LogFormatter, self).__init__(*args, **kwargs) def formatException(self, ei): - sio = StringIO.StringIO() + sio = StringIO() (typ, val, tb) = ei # log the stack above the exception capture point if possible, but diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py new file mode 100644 index 0000000000..1a1a8412f2 --- /dev/null +++ b/tests/util/test_logformatter.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from synapse.util.logformatter import LogFormatter +from tests import unittest + + +class TestException(Exception): + pass + + +class LogFormatterTestCase(unittest.TestCase): + def test_formatter(self): + formatter = LogFormatter() + + try: + raise TestException("testytest") + except TestException: + ei = sys.exc_info() + + output = formatter.formatException(ei) + + # check the output looks vaguely sane + self.assertIn("testytest", output) + self.assertIn("Capture point", output) -- cgit 1.4.1 From f22e7cda2c72d461acba664cb083e8c4e3c7572a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 May 2018 11:46:23 +0100 Subject: Fix a class of logcontext leaks So, it turns out that if you have a first `Deferred` `D1`, you can add a callback which returns another `Deferred` `D2`, and `D2` must then complete before any further callbacks on `D1` will execute (and later callbacks on `D1` get the *result* of `D2` rather than `D2` itself). So, `D1` might have `called=True` (as in, it has started running its callbacks), but any new callbacks added to `D1` won't get run until `D2` completes - so if you `yield D1` in an `inlineCallbacks` function, your `yield` will 'block'. In conclusion: some of our assumptions in `logcontext` were invalid. We need to make sure that we don't optimise out the logcontext juggling when this situation happens. Fortunately, it is easy to detect by checking `D1.paused`. --- synapse/util/logcontext.py | 60 ++++++++++++++++++++++++-------------- tests/util/test_logcontext.py | 67 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 94 insertions(+), 33 deletions(-) (limited to 'tests') diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 01ac71e53e..e086e12213 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -302,7 +302,7 @@ def preserve_fn(f): def run_in_background(f, *args, **kwargs): """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the - deferred returned by the funtion completes. + deferred returned by the function completes. Useful for wrapping functions that return a deferred which you don't yield on (for instance because you want to pass it to deferred.gatherResults()). @@ -320,24 +320,31 @@ def run_in_background(f, *args, **kwargs): # by synchronous exceptions, so let's turn them into Failures. return defer.fail() - if isinstance(res, defer.Deferred) and not res.called: - # The function will have reset the context before returning, so - # we need to restore it now. - LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, LoggingContext.sentinel) + if not isinstance(res, defer.Deferred): + return res + + if res.called and not res.paused: + # The function should have maintained the logcontext, so we can + # optimise out the messing about + return res + + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, ctx) return res @@ -354,9 +361,18 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ - if isinstance(deferred, defer.Deferred) and not deferred.called: - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) - deferred.addBoth(_set_context_cb, prev_context) + if not isinstance(deferred, defer.Deferred): + return deferred + + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred + + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) return deferred diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 4850722bc5..7707fbb1f0 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -36,24 +36,28 @@ class LoggingContextTestCase(unittest.TestCase): yield sleep(0) self._check_test_key("one") - def _test_preserve_fn(self, function): + def _test_run_in_background(self, function): sentinel_context = LoggingContext.current_context() callback_completed = [False] - @defer.inlineCallbacks - def cb(): + def test(): context_one.request = "one" - yield function() - self._check_test_key("one") + d = function() - callback_completed[0] = True + def cb(res): + self._check_test_key("one") + callback_completed[0] = True + return res + d.addCallback(cb) + + return d with LoggingContext() as context_one: context_one.request = "one" # fire off function, but don't wait on it. - logcontext.preserve_fn(cb)() + logcontext.run_in_background(test) self._check_test_key("one") @@ -80,20 +84,31 @@ class LoggingContextTestCase(unittest.TestCase): # test is done once d2 finishes return d2 - def test_preserve_fn_with_blocking_fn(self): + def test_run_in_background_with_blocking_fn(self): @defer.inlineCallbacks def blocking_function(): yield sleep(0) - return self._test_preserve_fn(blocking_function) + return self._test_run_in_background(blocking_function) - def test_preserve_fn_with_non_blocking_fn(self): + def test_run_in_background_with_non_blocking_fn(self): @defer.inlineCallbacks def nonblocking_function(): with logcontext.PreserveLoggingContext(): yield defer.succeed(None) - return self._test_preserve_fn(nonblocking_function) + return self._test_run_in_background(nonblocking_function) + + @unittest.DEBUG + def test_run_in_background_with_chained_deferred(self): + # a function which returns a deferred which looks like it has been + # called, but is actually paused + def testfunc(): + return logcontext.make_deferred_yieldable( + _chained_deferred_function() + ) + + return self._test_run_in_background(testfunc) @defer.inlineCallbacks def test_make_deferred_yieldable(self): @@ -118,6 +133,22 @@ class LoggingContextTestCase(unittest.TestCase): # now it should be restored self._check_test_key("one") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_chained_deferreds(self): + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = logcontext.make_deferred_yieldable(_chained_deferred_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + @defer.inlineCallbacks def test_make_deferred_yieldable_on_non_deferred(self): """Check that make_deferred_yieldable does the right thing when its @@ -132,3 +163,17 @@ class LoggingContextTestCase(unittest.TestCase): r = yield d1 self.assertEqual(r, "bum") self._check_test_key("one") + + +# a function which returns a deferred which has been "called", but +# which had a function which returned another incomplete deferred on +# its callback list, so won't yet call any other new callbacks. +def _chained_deferred_function(): + d = defer.succeed(None) + + def cb(res): + d2 = defer.Deferred() + reactor.callLater(0, d2.callback, res) + return d2 + d.addCallback(cb) + return d -- cgit 1.4.1 From 46beeb9a307febb679fc25565aca439f8af044ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 May 2018 15:46:22 +0100 Subject: Fix a couple of logcontext leaks in unit tests ... which were making other, innocent, tests, fail. Plus remove a spurious unittest.DEBUG which was making the output noisy. --- tests/appservice/test_scheduler.py | 11 +++++++++-- tests/storage/test_event_push_actions.py | 1 - 2 files changed, 9 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e5a902f734..9181692771 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -17,6 +17,8 @@ from synapse.appservice.scheduler import ( _ServiceQueuer, _TransactionController, _Recoverer ) from twisted.internet import defer + +from synapse.util.logcontext import make_deferred_yieldable from ..utils import MockClock from mock import Mock from tests import unittest @@ -204,7 +206,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(return_value=d) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y: make_deferred_yieldable(d), + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") @@ -235,7 +239,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0)) + + def do_send(x, y): + return make_deferred_yieldable(send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent self.queuer.enqueue(srv1, srv_1_event) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 575374c6a6..9962ce8a5d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -128,7 +128,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _rotate(10) yield _assert_counts(1, 1) - @tests.unittest.DEBUG @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): -- cgit 1.4.1 From 11607006d9570f76a46b44c9c8de13f13550dbcf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 May 2018 15:48:47 +0100 Subject: Remove spurious unittest.DEBUG --- tests/util/test_logcontext.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 7707fbb1f0..ad78d884e0 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -99,7 +99,6 @@ class LoggingContextTestCase(unittest.TestCase): return self._test_run_in_background(nonblocking_function) - @unittest.DEBUG def test_run_in_background_with_chained_deferred(self): # a function which returns a deferred which looks like it has been # called, but is actually paused -- cgit 1.4.1 From 32015e1109bc955697353d8f8088e3f6b538d12c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 May 2018 16:52:42 +0100 Subject: Escape label values in prometheus metrics --- synapse/metrics/metric.py | 22 ++++++++++++++++++++-- tests/metrics/test_metric.py | 21 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index 89bd47c3f7..1a09e417c9 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -16,6 +16,7 @@ from itertools import chain import logging +import re logger = logging.getLogger(__name__) @@ -56,8 +57,7 @@ class BaseMetric(object): return not len(self.labels) def _render_labelvalue(self, value): - # TODO: escape backslashes, quotes and newlines - return '"%s"' % (value) + return '"%s"' % (_escape_label_value(value),) def _render_key(self, values): if self.is_scalar(): @@ -299,3 +299,21 @@ class MemoryUsageMetric(object): "process_psutil_rss:total %d" % sum_rss, "process_psutil_rss:count %d" % len_rss, ] + + +def _escape_character(c): + """Replaces a single character with its escape sequence. + """ + if c == "\\": + return "\\\\" + elif c == "\"": + return "\\\"" + elif c == "\n": + return "\\n" + return c + + +def _escape_label_value(value): + """Takes a label value and escapes quotes, newlines and backslashes + """ + return re.sub(r"([\n\"\\])", lambda m: _escape_character(m.group(1)), value) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 39bde6e3f8..069c0be762 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -16,7 +16,8 @@ from tests import unittest from synapse.metrics.metric import ( - CounterMetric, CallbackMetric, DistributionMetric, CacheMetric + CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, + _escape_label_value, ) @@ -171,3 +172,21 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:size{name="cache_name"} 1', 'cache:evicted_size{name="cache_name"} 2', ]) + + +class LabelValueEscapeTestCase(unittest.TestCase): + def test_simple(self): + string = "safjhsdlifhyskljfksdfh" + self.assertEqual(string, _escape_label_value(string)) + + def test_escape(self): + self.assertEqual( + "abc\\\"def\\nghi\\\\", + _escape_label_value("abc\"def\nghi\\"), + ) + + def test_sequence_of_escapes(self): + self.assertEqual( + "abc\\\"def\\nghi\\\\\\n", + _escape_label_value("abc\"def\nghi\\\n"), + ) -- cgit 1.4.1