From 91ea0202e6f4a519e332a6c456aedfe4b7d627c9 Mon Sep 17 00:00:00 2001 From: Krombel Date: Wed, 14 Mar 2018 16:45:37 +0100 Subject: move handling of auto_join_rooms to RegisterHandler Currently the handling of auto_join_rooms only works when a user registers itself via public register api. Registrations via registration_shared_secret and ModuleApi do not work This auto_joins the users in the registration handler which enables the auto join feature for all 3 registration paths. This is related to issue #2725 Signed-Off-by: Matthias Kesler --- tests/rest/client/v1/test_events.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 2b89c0a3c7..a8d09600bd 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -123,6 +123,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.ratelimiter.send_message.return_value = (True, 0) hs.config.enable_registration_captcha = False hs.config.enable_registration = True + hs.config.auto_join_rooms = [] hs.get_handlers().federation_handler = Mock() -- cgit 1.5.1 From 72251d1b979db0bc96e5d95ac70b8e1cd78cde7c Mon Sep 17 00:00:00 2001 From: Silke Date: Tue, 20 Mar 2018 10:40:16 +0100 Subject: Remove address resolution of hosts in SRV records Signed-off-by: Silke Hofstra --- synapse/http/endpoint.py | 103 ++++------------------------------------------- tests/test_dns.py | 29 +------------ 2 files changed, 10 insertions(+), 122 deletions(-) (limited to 'tests') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 87639b9151..00572c2897 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import socket - from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor from twisted.internet.error import ConnectError @@ -33,7 +31,7 @@ SERVER_CACHE = {} # our record of an individual server which can be tried to reach a destination. # -# "host" is actually a dotted-quad or ipv6 address string. Except when there's +# "host" is the hostname acquired from the SRV record. Except when there's # no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" @@ -297,20 +295,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t payload = answer.payload - hosts = yield _get_hosts_for_srv_record( - dns_client, str(payload.target) - ) - - for (ip, ttl) in hosts: - host_ttl = min(answer.ttl, ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) @@ -328,81 +319,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t raise e defer.returnValue(servers) - - -@defer.inlineCallbacks -def _get_hosts_for_srv_record(dns_client, host): - """Look up each of the hosts in a SRV record - - Args: - dns_client (twisted.names.dns.IResolver): - host (basestring): host to look up - - Returns: - Deferred[list[(str, int)]]: a list of (host, ttl) pairs - - """ - ip4_servers = [] - ip6_servers = [] - - def cb(res): - # lookupAddress and lookupIP6Address return a three-tuple - # giving the answer, authority, and additional sections of the - # response. - # - # we only care about the answers. - - return res[0] - - def eb(res, record_type): - if res.check(DNSNameError): - return [] - logger.warn("Error looking up %s for %s: %s", record_type, host, res) - return res - - # no logcontexts here, so we can safely fire these off and gatherResults - d1 = dns_client.lookupAddress(host).addCallbacks( - cb, eb, errbackArgs=("A", )) - d2 = dns_client.lookupIPV6Address(host).addCallbacks( - cb, eb, errbackArgs=("AAAA", )) - results = yield defer.DeferredList( - [d1, d2], consumeErrors=True) - - # if all of the lookups failed, raise an exception rather than blowing out - # the cache with an empty result. - if results and all(s == defer.FAILURE for (s, _) in results): - defer.returnValue(results[0][1]) - - for (success, result) in results: - if success == defer.FAILURE: - continue - - for answer in result: - if not answer.payload: - continue - - try: - if answer.type == dns.A: - ip = answer.payload.dottedQuad() - ip4_servers.append((ip, answer.ttl)) - elif answer.type == dns.AAAA: - ip = socket.inet_ntop( - socket.AF_INET6, answer.payload.address, - ) - ip6_servers.append((ip, answer.ttl)) - else: - # the most likely candidate here is a CNAME record. - # rfc2782 says srvs may not point to aliases. - logger.warn( - "Ignoring unexpected DNS record type %s for %s", - answer.type, host, - ) - continue - except Exception as e: - logger.warn("Ignoring invalid DNS response for %s: %s", - host, e) - continue - - # keep the ipv4 results before the ipv6 results, mostly to match historical - # behaviour. - defer.returnValue(ip4_servers + ip6_servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index d08b0f4333..af607d626f 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -33,8 +33,6 @@ class DnsTestCase(unittest.TestCase): service_name = "test_service.example.com" host_name = "example.com" - ip_address = "127.0.0.1" - ip6_address = "::1" answer_srv = dns.RRHeader( type=dns.SRV, @@ -43,29 +41,9 @@ class DnsTestCase(unittest.TestCase): ) ) - answer_a = dns.RRHeader( - type=dns.A, - payload=dns.Record_A( - address=ip_address, - ) - ) - - answer_aaaa = dns.RRHeader( - type=dns.AAAA, - payload=dns.Record_AAAA( - address=ip6_address, - ) - ) - dns_client_mock.lookupService.return_value = defer.succeed( ([answer_srv], None, None), ) - dns_client_mock.lookupAddress.return_value = defer.succeed( - ([answer_a], None, None), - ) - dns_client_mock.lookupIPV6Address.return_value = defer.succeed( - ([answer_aaaa], None, None), - ) cache = {} @@ -74,13 +52,10 @@ class DnsTestCase(unittest.TestCase): ) dns_client_mock.lookupService.assert_called_once_with(service_name) - dns_client_mock.lookupAddress.assert_called_once_with(host_name) - dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name) - self.assertEquals(len(servers), 2) + self.assertEquals(len(servers), 1) self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, ip_address) - self.assertEquals(servers[1].host, ip6_address) + self.assertEquals(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): -- cgit 1.5.1 From 616835187702a0c6f16042e3efb452e1ee3e7826 Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Tue, 3 Apr 2018 20:41:21 +0200 Subject: Add b prefixes to some strings that are bytes in py3 This has no effect on python2 Signed-off-by: Adrian Tschira --- synapse/api/auth.py | 10 +++++----- synapse/app/frontend_proxy.py | 2 +- synapse/http/server.py | 4 ++-- synapse/http/site.py | 6 +++--- synapse/rest/client/v1/register.py | 4 ++-- tests/utils.py | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ac0a3655a5..f17fda6315 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -204,8 +204,8 @@ class Auth(object): ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - "User-Agent", - default=[""] + b"User-Agent", + default=[b""] )[0] if user and access_token and ip_addr: self.store.insert_client_ip( @@ -672,7 +672,7 @@ def has_access_token(request): bool: False if no access_token was given, True otherwise. """ query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders("Authorization") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -692,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401): AuthError: If there isn't an access_token in the request. """ - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") if auth_headers: # Try the get the access_token from a "Authorization: Bearer" # header diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index de889357c3..b349e3e3ce 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -90,7 +90,7 @@ class KeyUploadServlet(RestServlet): # They're actually trying to upload something, proxy to main synapse. # Pass through the auth headers, if any, in case the access token # is there. - auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) headers = { "Authorization": auth_headers, } diff --git a/synapse/http/server.py b/synapse/http/server.py index f19c068ef6..d979e76639 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -324,7 +324,7 @@ class JsonResource(HttpServer, resource.Resource): register_paths, so will return (possibly via Deferred) either None, or a tuple of (http code, response body). """ - if request.method == "OPTIONS": + if request.method == b"OPTIONS": return _options_handler, {} # Loop through all the registered callbacks to check if the method @@ -536,7 +536,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( - "User-Agent", default=[] + b"User-Agent", default=[] ) for user_agent in user_agents: if "curl" in user_agent: diff --git a/synapse/http/site.py b/synapse/http/site.py index e422c8dfae..c8b46e1af2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ import logging import re import time -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') class SynapseRequest(Request): @@ -43,12 +43,12 @@ class SynapseRequest(Request): def get_redacted_uri(self): return ACCESS_TOKEN_RE.sub( - r'\1\3', + br'\1\3', self.uri ) def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def started_processing(self): self.site.access_logger.info( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 5c5fa8f7ab..8a82097178 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -348,9 +348,9 @@ class RegisterRestServlet(ClientV1RestServlet): admin = register_json.get("admin", None) # Its important to check as we use null bytes as HMAC field separators - if "\x00" in user: + if b"\x00" in user: raise SynapseError(400, "Invalid user") - if "\x00" in password: + if b"\x00" in password: raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not diff --git a/tests/utils.py b/tests/utils.py index 8efd3a3475..f15317d27b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -212,7 +212,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it -- cgit 1.5.1 From e939f3bca6d73bc4ec6450aeb41f5e121a29a662 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Apr 2018 14:37:11 +0100 Subject: Fix tests --- tests/handlers/test_appservice.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index a667fb6f0e..b753455943 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,6 +31,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore = Mock(return_value=self.mock_store) + self.mock_store.get_received_ts.return_value = 0 hs.get_application_service_api = Mock(return_value=self.mock_as_api) hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) hs.get_clock.return_value = MockClock() -- cgit 1.5.1 From 1515560f5ce8054a539dd0fed9e88232883f079f Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 17:20:37 +0200 Subject: Use str(e) instead of e.message Doing this I learned e.message was pretty shortlived, added in 2.6, they realized it was a bad idea and deprecated it in 2.7 Signed-off-by: Adrian Tschira --- synapse/crypto/keyring.py | 8 ++++---- tests/storage/test_appservice.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 35f810b07b..fce83d445f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -352,7 +352,7 @@ class Keyring(object): logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) defer.returnValue({}) @@ -384,7 +384,7 @@ class Keyring(object): logger.info( "Unable to get key %r for %r directly: %s %s", key_ids, server_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) if not keys: @@ -734,7 +734,7 @@ def _handle_key_deferred(verify_request): except IOError as e: logger.warn( "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 502, @@ -744,7 +744,7 @@ def _handle_key_deferred(verify_request): except Exception as e: logger.exception( "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 401, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c2e39a7288..00825498b1 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -480,9 +480,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("id", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("id", str(e)) @defer.inlineCallbacks def test_duplicate_as_tokens(self): @@ -504,6 +504,6 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("as_token", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("as_token", str(e)) -- cgit 1.5.1 From cb9cdfecd097c5a0c85840be9d37d590f0f4af2d Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 17:18:01 +0200 Subject: Add some more variables to the unittest config These worked accidentally before (python2 doesn't complain if you compare incompatible types) but under py3 this blows up spectacularly Signed-off-by: Adrian Tschira --- tests/utils.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index f15317d27b..0cd9f7eeee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -59,6 +59,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.email_enable_notifs = False config.block_non_admin_invites = False config.federation_domain_whitelist = None + config.federation_rc_reject_limit = 10 + config.federation_rc_sleep_limit = 10 + config.federation_rc_concurrent = 10 + config.filter_timeline_limit = 5000 config.user_directory_search_all_users = False # disable user directory updates, because they get done in the -- cgit 1.5.1 From a1a3c9660ffc1e99f656c1d39ebe57b8857b207d Mon Sep 17 00:00:00 2001 From: Adrian Tschira Date: Sun, 15 Apr 2018 21:46:23 +0200 Subject: Make tests py3 compatible This is a mixed commit that fixes various small issues * print parentheses * 01 is invalid syntax (it was octal in py2) * [x for i in 1, 2] is invalid syntax * six moves Signed-off-by: Adrian Tschira --- tests/config/test_load.py | 2 +- tests/crypto/test_keyring.py | 2 +- tests/test_distributor.py | 2 +- tests/util/test_file_consumer.py | 2 +- tests/util/test_linearizer.py | 3 ++- 5 files changed, 6 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 161a87d7e3..772afd2cf9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -24,7 +24,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print self.dir + print(self.dir) self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d4ec02ffc2..149e443022 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(01) + yield async.sleep(1) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index acebcf4a86..010aeaee7e 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -62,7 +62,7 @@ class DistributorTestCase(unittest.TestCase): def test_signal_catch(self): self.dist.declare("alarm") - observers = [Mock() for i in 1, 2] + observers = [Mock() for i in (1, 2)] for o in observers: self.dist.observe("alarm", o) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 76e2234255..d6e1082779 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -20,7 +20,7 @@ from mock import NonCallableMock from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest -from StringIO import StringIO +from six import StringIO import threading diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 793a88e462..4865eb4bc6 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -18,6 +18,7 @@ from tests import unittest from twisted.internet import defer from synapse.util.async import Linearizer +from six.moves import range class LinearizerTestCase(unittest.TestCase): @@ -58,7 +59,7 @@ class LinearizerTestCase(unittest.TestCase): logcontext.LoggingContext.current_context(), lc) func(0, sleep=True) - for i in xrange(1, 100): + for i in range(1, 100): func(i) return func(1000) -- cgit 1.5.1 From 639480e14a06723adf6817ddd2b2ff9e4f4cdf2a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 Apr 2018 18:41:37 +0100 Subject: Avoid creating events with huge numbers of prev_events In most cases, we limit the number of prev_events for a given event to 10 events. This fixes a particular code path which created events with huge numbers of prev_events. --- synapse/handlers/message.py | 78 +++++++++++++++++++--------------- synapse/handlers/room_member.py | 13 ++++-- synapse/storage/event_federation.py | 57 ++++++++++++++++++------- tests/storage/test_event_federation.py | 68 +++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 54 deletions(-) create mode 100644 tests/storage/test_event_federation.py (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54cd691f91..21628a8540 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -37,7 +37,6 @@ from ._base import BaseHandler from canonicaljson import encode_canonical_json import logging -import random import simplejson logger = logging.getLogger(__name__) @@ -433,7 +432,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_event_ids=None): + prev_events_and_hashes=None): """ Given a dict from a client, create a new event. @@ -447,7 +446,13 @@ class EventCreationHandler(object): event_dict (dict): An entire event token_id (str) txn_id (str) - prev_event_ids (list): The prev event ids to use when creating the event + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. Returns: Tuple of created event (FrozenEvent), Context @@ -485,7 +490,7 @@ class EventCreationHandler(object): event, context = yield self.create_new_client_event( builder=builder, requester=requester, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) defer.returnValue((event, context)) @@ -588,39 +593,44 @@ class EventCreationHandler(object): @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, prev_event_ids=None): - if prev_event_ids: - prev_events = yield self.store.add_event_hashes(prev_event_ids) - prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) - depth = prev_max_depth + 1 - else: - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, + def create_new_client_event(self, builder, requester=None, + prev_events_and_hashes=None): + """Create a new event for a local client + + Args: + builder (EventBuilder): + + requester (synapse.types.Requester|None): + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. + + Returns: + Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + """ + + if prev_events_and_hashes is not None: + assert len(prev_events_and_hashes) <= 10, \ + "Attempting to create an event with %i prev_events" % ( + len(prev_events_and_hashes), ) + else: + prev_events_and_hashes = \ + yield self.store.get_prev_events_for_room(builder.room_id) - # We want to limit the max number of prev events we point to in our - # new event - if len(latest_ret) > 10: - # Sort by reverse depth, so we point to the most recent. - latest_ret.sort(key=lambda a: -a[2]) - new_latest_ret = latest_ret[:5] - - # We also randomly point to some of the older events, to make - # sure that we don't completely ignore the older events. - if latest_ret[5:]: - sample_size = min(5, len(latest_ret[5:])) - new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) - latest_ret = new_latest_ret - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 + if prev_events_and_hashes: + depth = max([d for _, _, d in prev_events_and_hashes]) + 1 + else: + depth = 1 - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in prev_events_and_hashes + ] builder.prev_events = prev_events builder.depth = depth diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c45142d38d..714583f1d5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -149,7 +149,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, - prev_event_ids, + prev_events_and_hashes, txn_id=None, ratelimit=True, content=None, @@ -175,7 +175,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) # Check if this event matches the previous membership event for the user. @@ -314,7 +314,12 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", ) - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -403,7 +408,7 @@ class RoomMemberHandler(object): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_event_ids=latest_event_ids, + prev_events_and_hashes=prev_events_and_hashes, content=content, ) defer.returnValue(res) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 00ee82d300..a183fc6b50 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from twisted.internet import defer @@ -133,7 +134,47 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, retcol="event_id", ) + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -182,22 +223,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py new file mode 100644 index 0000000000..30683e7888 --- /dev/null +++ b/tests/storage/test_event_federation.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils + + +class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_prev_events_for_room(self): + room_id = '@ROOM:local' + + # add a bunch of events and hashes to act as forward extremities + def insert_event(txn, i): + event_id = '$event_%i:local' % i + + txn.execute(( + "INSERT INTO events (" + " room_id, event_id, type, depth, topological_ordering," + " content, processed, outlier) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + ), (room_id, event_id, i, i, True, False)) + + txn.execute(( + 'INSERT INTO event_forward_extremities (room_id, event_id) ' + 'VALUES (?, ?)' + ), (room_id, event_id)) + + txn.execute(( + 'INSERT INTO event_reference_hashes ' + '(event_id, algorithm, hash) ' + "VALUES (?, 'sha256', ?)" + ), (event_id, 'ffff')) + + for i in range(0, 11): + yield self.store.runInteraction("insert", insert_event, i) + + # this should get the last five and five others + r = yield self.store.get_prev_events_for_room(room_id) + self.assertEqual(10, len(r)) + for i in range(0, 5): + el = r[i] + depth = el[2] + self.assertEqual(10 - i, depth) + + for i in range(5, 5): + el = r[i] + depth = el[2] + self.assertLessEqual(5, depth) -- cgit 1.5.1