From d7275eecf3abcd5f7c645e5da47c7e7310350333 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jul 2018 12:37:12 +0100 Subject: Add a sleep to the Limiter to fix stack overflows. Fixes #3570 --- tests/util/test_limiter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py index a5a767b1ff..f7b942f5c1 100644 --- a/tests/util/test_limiter.py +++ b/tests/util/test_limiter.py @@ -48,21 +48,21 @@ class LimiterTestCase(unittest.TestCase): self.assertFalse(d4.called) self.assertFalse(d5.called) - self.assertTrue(d4.called) + cm4 = yield d4 self.assertFalse(d5.called) with cm3: self.assertFalse(d5.called) - self.assertTrue(d5.called) + cm5 = yield d5 with cm2: pass - with (yield d4): + with cm4: pass - with (yield d5): + with cm5: pass d6 = limiter.queue(key) -- cgit 1.5.1 From 7c712f95bbc7f405355d5714c92d65551f64fec2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jul 2018 13:11:43 +0100 Subject: Combine Limiter and Linearizer Linearizer was effectively a Limiter with max_count=1, so rather than maintaining two sets of code, let's combine them. --- synapse/handlers/message.py | 4 +- synapse/util/async.py | 99 +++++-------------------------------------- tests/util/test_limiter.py | 70 ------------------------------ tests/util/test_linearizer.py | 47 ++++++++++++++++++++ 4 files changed, 59 insertions(+), 161 deletions(-) delete mode 100644 tests/util/test_limiter.py (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8c12c6990f..abc07ea87c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -33,7 +33,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import send_event_to_master from synapse.types import RoomAlias, RoomStreamToken, UserID -from synapse.util.async import Limiter, ReadWriteLock +from synapse.util.async import Linearizer, ReadWriteLock from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func @@ -427,7 +427,7 @@ class EventCreationHandler(object): # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. - self.limiter = Limiter(max_count=5, name="room_event_creation_limit") + self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") self.action_generator = hs.get_action_generator() diff --git a/synapse/util/async.py b/synapse/util/async.py index 22071ddef7..5a50d9700f 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -157,91 +157,8 @@ def concurrently_execute(func, args, limit): class Linearizer(object): - """Linearizes access to resources based on a key. Useful to ensure only one - thing is happening at a time on a given resource. - - Example: - - with (yield linearizer.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None, clock=None): - if name is None: - self.name = id(self) - else: - self.name = name - self.key_to_defer = {} - - if not clock: - from twisted.internet import reactor - clock = Clock(reactor) - self._clock = clock - - @defer.inlineCallbacks - def queue(self, key): - # If there is already a deferred in the queue, we pull it out so that - # we can wait on it later. - # Then we replace it with a deferred that we resolve *after* the - # context manager has exited. - # We only return the context manager after the previous deferred has - # resolved. - # This all has the net effect of creating a chain of deferreds that - # wait for the previous deferred before starting their work. - current_defer = self.key_to_defer.get(key) - - new_defer = defer.Deferred() - self.key_to_defer[key] = new_defer - - if current_defer: - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key - ) - try: - with PreserveLoggingContext(): - yield current_defer - except Exception: - logger.exception("Unexpected exception in Linearizer") - - logger.info("Acquired linearizer lock %r for key %r", self.name, - key) - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (There's no particular need for it to happen before we return - # the context manager, but it needs to happen while we hold the - # lock, and the context manager's exit code must be synchronous, - # so actually this is the only sensible place. - yield self._clock.sleep(0) - - else: - logger.info("Acquired uncontended linearizer lock %r for key %r", - self.name, key) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - with PreserveLoggingContext(): - new_defer.callback(None) - current_d = self.key_to_defer.get(key) - if current_d is new_defer: - self.key_to_defer.pop(key, None) - - defer.returnValue(_ctx_manager()) - - -class Limiter(object): """Limits concurrent access to resources based on a key. Useful to ensure - only a few thing happen at a time on a given resource. + only a few things happen at a time on a given resource. Example: @@ -249,7 +166,7 @@ class Limiter(object): # do some work. """ - def __init__(self, max_count=1, name=None, clock=None): + def __init__(self, name=None, max_count=1, clock=None): """ Args: max_count(int): The maximum number of concurrent accesses @@ -283,10 +200,12 @@ class Limiter(object): new_defer = defer.Deferred() entry[1].append(new_defer) - logger.info("Waiting to acquire limiter lock %r for key %r", self.name, key) + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) yield make_deferred_yieldable(new_defer) - logger.info("Acquired limiter lock %r for key %r", self.name, key) + logger.info("Acquired linearizer lock %r for key %r", self.name, key) entry[0] += 1 # if the code holding the lock completes synchronously, then it @@ -302,7 +221,9 @@ class Limiter(object): yield self._clock.sleep(0) else: - logger.info("Acquired uncontended limiter lock %r for key %r", self.name, key) + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) entry[0] += 1 @contextmanager @@ -310,7 +231,7 @@ class Limiter(object): try: yield finally: - logger.info("Releasing limiter lock %r for key %r", self.name, key) + logger.info("Releasing linearizer lock %r for key %r", self.name, key) # We've finished executing so check if there are any things # blocked waiting to execute and start one of them diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py deleted file mode 100644 index f7b942f5c1..0000000000 --- a/tests/util/test_limiter.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.util.async import Limiter - -from tests import unittest - - -class LimiterTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def test_limiter(self): - limiter = Limiter(3) - - key = object() - - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) - - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) - - cm4 = yield d4 - self.assertFalse(d5.called) - - with cm3: - self.assertFalse(d5.called) - - cm5 = yield d5 - - with cm2: - pass - - with cm4: - pass - - with cm5: - pass - - d6 = limiter.queue(key) - with (yield d6): - pass diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index c95907b32c..c9563124f9 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,3 +66,49 @@ class LinearizerTestCase(unittest.TestCase): func(i) return func(1000) + + @defer.inlineCallbacks + def test_multiple_entries(self): + limiter = Linearizer(max_count=3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + cm4 = yield d4 + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + cm5 = yield d5 + + with cm2: + pass + + with cm4: + pass + + with cm5: + pass + + d6 = limiter.queue(key) + with (yield d6): + pass -- cgit 1.5.1 From e1a237eaabf0ba37f242897700f9bf00729976b8 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 20 Jul 2018 22:41:13 +1000 Subject: Admin API for creating new users (#3415) --- changelog.d/3415.misc | 0 docs/admin_api/register_api.rst | 63 ++++++++ scripts/register_new_matrix_user | 32 +++- synapse/rest/client/v1/admin.py | 122 +++++++++++++++ synapse/secrets.py | 42 +++++ synapse/server.py | 5 + tests/rest/client/v1/test_admin.py | 305 +++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 + 8 files changed, 569 insertions(+), 3 deletions(-) create mode 100644 changelog.d/3415.misc create mode 100644 docs/admin_api/register_api.rst create mode 100644 synapse/secrets.py create mode 100644 tests/rest/client/v1/test_admin.py (limited to 'tests') diff --git a/changelog.d/3415.misc b/changelog.d/3415.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/admin_api/register_api.rst b/docs/admin_api/register_api.rst new file mode 100644 index 0000000000..209cd140fd --- /dev/null +++ b/docs/admin_api/register_api.rst @@ -0,0 +1,63 @@ +Shared-Secret Registration +========================== + +This API allows for the creation of users in an administrative and +non-interactive way. This is generally used for bootstrapping a Synapse +instance with administrator accounts. + +To authenticate yourself to the server, you will need both the shared secret +(``registration_shared_secret`` in the homeserver configuration), and a +one-time nonce. If the registration shared secret is not configured, this API +is not enabled. + +To fetch the nonce, you need to request one from the API:: + + > GET /_matrix/client/r0/admin/register + + < {"nonce": "thisisanonce"} + +Once you have the nonce, you can make a ``POST`` to the same URL with a JSON +body containing the nonce, username, password, whether they are an admin +(optional, False by default), and a HMAC digest of the content. + +As an example:: + + > POST /_matrix/client/r0/admin/register + > { + "nonce": "thisisanonce", + "username": "pepper_roni", + "password": "pizza", + "admin": true, + "mac": "mac_digest_here" + } + + < { + "access_token": "token_here", + "user_id": "@pepper_roni@test", + "home_server": "test", + "device_id": "device_id_here" + } + +The MAC is the hex digest output of the HMAC-SHA1 algorithm, with the key being +the shared secret and the content being the nonce, user, password, and either +the string "admin" or "notadmin", each separated by NULs. For an example of +generation in Python:: + + import hmac, hashlib + + def generate_mac(nonce, user, password, admin=False): + + mac = hmac.new( + key=shared_secret, + digestmod=hashlib.sha1, + ) + + mac.update(nonce.encode('utf8')) + mac.update(b"\x00") + mac.update(user.encode('utf8')) + mac.update(b"\x00") + mac.update(password.encode('utf8')) + mac.update(b"\x00") + mac.update(b"admin" if admin else b"notadmin") + + return mac.hexdigest() diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 12ed20d623..8c3d429351 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -26,11 +26,37 @@ import yaml def request_registration(user, password, server_location, shared_secret, admin=False): + req = urllib2.Request( + "%s/_matrix/client/r0/admin/register" % (server_location,), + headers={'Content-Type': 'application/json'} + ) + + try: + if sys.version_info[:3] >= (2, 7, 9): + # As of version 2.7.9, urllib2 now checks SSL certs + import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + else: + f = urllib2.urlopen(req) + body = f.read() + f.close() + nonce = json.loads(body)["nonce"] + except urllib2.HTTPError as e: + print "ERROR! Received %d %s" % (e.code, e.reason,) + if 400 <= e.code < 500: + if e.info().type == "application/json": + resp = json.load(e) + if "error" in resp: + print resp["error"] + sys.exit(1) + mac = hmac.new( key=shared_secret, digestmod=hashlib.sha1, ) + mac.update(nonce) + mac.update("\x00") mac.update(user) mac.update("\x00") mac.update(password) @@ -40,10 +66,10 @@ def request_registration(user, password, server_location, shared_secret, admin=F mac = mac.hexdigest() data = { - "user": user, + "nonce": nonce, + "username": user, "password": password, "mac": mac, - "type": "org.matrix.login.shared_secret", "admin": admin, } @@ -52,7 +78,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F print "Sending registration request..." req = urllib2.Request( - "%s/_matrix/client/api/v1/register" % (server_location,), + "%s/_matrix/client/r0/admin/register" % (server_location,), data=json.dumps(data), headers={'Content-Type': 'application/json'} ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2dc50e582b..9e9c175970 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac import logging from six.moves import http_client @@ -63,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class UserRegisterServlet(ClientV1RestServlet): + """ + Attributes: + NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted + nonces (dict[str, int]): The nonces that we will accept. A dict of + nonce to the time it was generated, in int seconds. + """ + PATTERNS = client_path_patterns("/admin/register") + NONCE_TIMEOUT = 60 + + def __init__(self, hs): + super(UserRegisterServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + self.reactor = hs.get_reactor() + self.nonces = {} + self.hs = hs + + def _clear_old_nonces(self): + """ + Clear out old nonces that are older than NONCE_TIMEOUT. + """ + now = int(self.reactor.seconds()) + + for k, v in list(self.nonces.items()): + if now - v > self.NONCE_TIMEOUT: + del self.nonces[k] + + def on_GET(self, request): + """ + Generate a new nonce. + """ + self._clear_old_nonces() + + nonce = self.hs.get_secrets().token_hex(64) + self.nonces[nonce] = int(self.reactor.seconds()) + return (200, {"nonce": nonce.encode('ascii')}) + + @defer.inlineCallbacks + def on_POST(self, request): + self._clear_old_nonces() + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + body = parse_json_object_from_request(request) + + if "nonce" not in body: + raise SynapseError( + 400, "nonce must be specified", errcode=Codes.BAD_JSON, + ) + + nonce = body["nonce"] + + if nonce not in self.nonces: + raise SynapseError( + 400, "unrecognised nonce", + ) + + # Delete the nonce, so it can't be reused, even if it's invalid + del self.nonces[nonce] + + if "username" not in body: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['username'], str) or len(body['username']) > 512): + raise SynapseError(400, "Invalid username") + + username = body["username"].encode("utf-8") + if b"\x00" in username: + raise SynapseError(400, "Invalid username") + + if "password" not in body: + raise SynapseError( + 400, "password must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['password'], str) or len(body['password']) > 512): + raise SynapseError(400, "Invalid password") + + password = body["password"].encode("utf-8") + if b"\x00" in password: + raise SynapseError(400, "Invalid password") + + admin = body.get("admin", None) + got_mac = body["mac"] + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=hashlib.sha1, + ) + want_mac.update(nonce) + want_mac.update(b"\x00") + want_mac.update(username) + want_mac.update(b"\x00") + want_mac.update(password) + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") + want_mac = want_mac.hexdigest() + + if not hmac.compare_digest(want_mac, got_mac): + raise SynapseError( + 403, "HMAC incorrect", + ) + + # Reuse the parts of RegisterRestServlet to reduce code duplication + from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) + + (user_id, _) = yield register.registration_handler.register( + localpart=username.lower(), password=password, admin=bool(admin), + generate_token=False, + ) + + result = yield register._create_registration_details(user_id, body) + defer.returnValue((200, result)) + + class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") @@ -614,3 +735,4 @@ def register_servlets(hs, http_server): ShutdownRoomRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server) + UserRegisterServlet(hs).register(http_server) diff --git a/synapse/secrets.py b/synapse/secrets.py new file mode 100644 index 0000000000..f397daaa5e --- /dev/null +++ b/synapse/secrets.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Injectable secrets module for Synapse. + +See https://docs.python.org/3/library/secrets.html#module-secrets for the API +used in Python 3.6, and the API emulated in Python 2.7. +""" + +import six + +if six.PY3: + import secrets + + def Secrets(): + return secrets + + +else: + + import os + import binascii + + class Secrets(object): + def token_bytes(self, nbytes=32): + return os.urandom(nbytes) + + def token_hex(self, nbytes=32): + return binascii.hexlify(self.token_bytes(nbytes)) diff --git a/synapse/server.py b/synapse/server.py index 92bea96c5c..fd4f992258 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -74,6 +74,7 @@ from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, ) +from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender @@ -158,6 +159,7 @@ class HomeServer(object): 'groups_server_handler', 'groups_attestation_signing', 'groups_attestation_renewer', + 'secrets', 'spam_checker', 'room_member_handler', 'federation_registry', @@ -405,6 +407,9 @@ class HomeServer(object): def build_groups_attestation_renewer(self): return GroupAttestionRenewer(self) + def build_secrets(self): + return Secrets() + def build_spam_checker(self): return SpamChecker(self) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py new file mode 100644 index 0000000000..8c90145601 --- /dev/null +++ b/tests/rest/client/v1/test_admin.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import json + +from mock import Mock + +from synapse.http.server import JsonResource +from synapse.rest.client.v1.admin import register_servlets +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) + + +class UserRegisterTestCase(unittest.TestCase): + def setUp(self): + + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = "/_matrix/client/r0/admin/register" + + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) + + self.secrets = Mock() + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.hs.config.registration_shared_secret = u"shared" + + self.hs.get_media_repository = Mock() + self.hs.get_deactivate_account_handler = Mock() + + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) + + def test_disabled(self): + """ + If there is no shared secret, registration through this method will be + prevented. + """ + self.hs.config.registration_shared_secret = None + + request, channel = make_request("POST", self.url, b'{}') + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + 'Shared secret registration is not enabled', channel.json_body["error"] + ) + + def test_get_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, using the + homeserver's secrets provider. + """ + secrets = Mock() + secrets.token_hex = Mock(return_value="abcd") + + self.hs.get_secrets = Mock(return_value=secrets) + + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + + self.assertEqual(channel.json_body, {"nonce": "abcd"}) + + def test_expired_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, which will + only last for SALT_TIMEOUT (60s). + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + # 59 seconds + self.clock.advance(59) + + body = json.dumps({"nonce": nonce}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # 61 seconds + self.clock.advance(2) + + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_register_incorrect_nonce(self): + """ + Only the provided nonce can be used, as it's checked in the MAC. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("HMAC incorrect", channel.json_body["error"]) + + def test_register_correct_nonce(self): + """ + When the correct nonce is provided, and the right key is provided, the + user is registered. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + def test_nonce_reuse(self): + """ + A valid unrecognised nonce. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + # Now, try and reuse it + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_missing_parts(self): + """ + Synapse will complain if you don't give nonce, username, password, and + mac. Admin is optional. Additional checks are done for length and + type. + """ + def nonce(): + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + return channel.json_body["nonce"] + + # + # Nonce check + # + + # Must be present + body = json.dumps({}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('nonce must be specified', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce()}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce(), "username": "a"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('password must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Super long + body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) diff --git a/tests/utils.py b/tests/utils.py index e488238bb3..c3dbff8507 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None config.user_directory_search_all_users = False config.user_consent_server_notice_content = None config.block_events_without_consent_error = None + config.media_storage_providers = [] + config.auto_join_rooms = [] # disable user directory updates, because they get done in the # background, which upsets the test runner. @@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) -- cgit 1.5.1 From 3d6df846580ce6ef8769945e6990af2f44251e40 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jul 2018 13:59:55 +0100 Subject: Test and fix support for cancellation in Linearizer --- synapse/util/async.py | 28 ++++++++++++++++++++++------ tests/util/test_linearizer.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/util/async.py b/synapse/util/async.py index 5a50d9700f..a7094e2fb4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -184,13 +184,13 @@ class Linearizer(object): # key_to_defer is a map from the key to a 2 element list where # the first element is the number of things executing, and - # the second element is a deque of deferreds for the things blocked from - # executing. + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. self.key_to_defer = {} @defer.inlineCallbacks def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, collections.deque()]) + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items @@ -198,12 +198,28 @@ class Linearizer(object): # this item so that it can continue executing. if entry[0] >= self.max_count: new_defer = defer.Deferred() - entry[1].append(new_defer) + entry[1][new_defer] = 1 logger.info( "Waiting to acquire linearizer lock %r for key %r", self.name, key, ) - yield make_deferred_yieldable(new_defer) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + raise logger.info("Acquired linearizer lock %r for key %r", self.name, key) entry[0] += 1 @@ -238,7 +254,7 @@ class Linearizer(object): entry[0] -= 1 if entry[1]: - next_def = entry[1].popleft() + (next_def, _) = entry[1].popitem(last=False) # we need to run the next thing in the sentinel context. with PreserveLoggingContext(): diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index c9563124f9..4729bd5a0a 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -17,6 +17,7 @@ from six.moves import range from twisted.internet import defer, reactor +from twisted.internet.defer import CancelledError from synapse.util import Clock, logcontext from synapse.util.async import Linearizer @@ -112,3 +113,33 @@ class LinearizerTestCase(unittest.TestCase): d6 = limiter.queue(key) with (yield d6): pass + + @defer.inlineCallbacks + def test_cancellation(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + d3 = linearizer.queue(key) + self.assertFalse(d3.called) + + d2.cancel() + + with cm1: + pass + + self.assertTrue(d2.called) + try: + yield d2 + self.fail("Expected d2 to raise CancelledError") + except CancelledError: + pass + + with (yield d3): + pass -- cgit 1.5.1 From 3132b89f12f0386558045683ad198f090b0e2c90 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 21 Jul 2018 15:47:18 +1000 Subject: Make the rest of the .iterwhatever go away (#3562) --- changelog.d/3562.misc | 0 synapse/app/homeserver.py | 6 ++++-- synapse/app/synctl.py | 4 +++- synapse/events/snapshot.py | 4 +++- synapse/handlers/federation.py | 18 +++++++++--------- synapse/state.py | 6 +++--- synapse/visibility.py | 19 ++++++++++--------- tests/test_federation.py | 3 +-- 8 files changed, 33 insertions(+), 27 deletions(-) create mode 100644 changelog.d/3562.misc (limited to 'tests') diff --git a/changelog.d/3562.misc b/changelog.d/3562.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 14e6dca522..2ad1beb8d8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,6 +18,8 @@ import logging import os import sys +from six import iteritems + from twisted.application import service from twisted.internet import defer, reactor from twisted.web.resource import EncodingResourceWrapper, NoResource @@ -442,7 +444,7 @@ def run(hs): stats["total_nonbridged_users"] = total_nonbridged_users daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in daily_user_type_results.iteritems(): + for name, count in iteritems(daily_user_type_results): stats["daily_user_type_" + name] = count room_count = yield hs.get_datastore().get_room_count() @@ -453,7 +455,7 @@ def run(hs): stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() - for name, count in r30_results.iteritems(): + for name, count in iteritems(r30_results): stats["r30_users_" + name] = count daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 68acc15a9a..d658f967ba 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -25,6 +25,8 @@ import subprocess import sys import time +from six import iteritems + import yaml SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] @@ -173,7 +175,7 @@ def main(): os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in cache_factors.iteritems(): + for cache_name, factor in iteritems(cache_factors): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) worker_configfiles = [] diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index bcd9bb5946..f83a1581a6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + from frozendict import frozendict from twisted.internet import defer @@ -159,7 +161,7 @@ def _encode_state_dict(state_dict): return [ (etype, state_key, v) - for (etype, state_key), v in state_dict.iteritems() + for (etype, state_key), v in iteritems(state_dict) ] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 65f6041b10..a6d391c4e8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,8 +21,8 @@ import logging import sys import six -from six import iteritems -from six.moves import http_client +from six import iteritems, itervalues +from six.moves import http_client, zip from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json @@ -731,7 +731,7 @@ class FederationHandler(BaseHandler): """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in state.iteritems() + for (e_type, state_key), event in iteritems(state) if e_type == EventTypes.Member and event.membership == Membership.JOIN ] @@ -748,7 +748,7 @@ class FederationHandler(BaseHandler): except Exception: pass - return sorted(joined_domains.iteritems(), key=lambda d: d[1]) + return sorted(joined_domains.items(), key=lambda d: d[1]) curr_domains = get_domains_from_state(curr_state) @@ -811,7 +811,7 @@ class FederationHandler(BaseHandler): tried_domains = set(likely_domains) tried_domains.add(self.server_name) - event_ids = list(extremities.iterkeys()) + event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = logcontext.preserve_fn( @@ -827,15 +827,15 @@ class FederationHandler(BaseHandler): states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( - [e_id for ids in states.itervalues() for e_id in ids.itervalues()], + [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False ) states = { key: { k: state_map[e_id] - for k, e_id in state_dict.iteritems() + for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in states.iteritems() + } for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: @@ -1515,7 +1515,7 @@ class FederationHandler(BaseHandler): yield self.store.persist_events( [ (ev_info["event"], context) - for ev_info, context in itertools.izip(event_infos, contexts) + for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, ) diff --git a/synapse/state.py b/synapse/state.py index 15a593d41c..504caae2f7 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ import hashlib import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems, iterkeys, itervalues from frozendict import frozendict @@ -647,7 +647,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): for event_id in event_ids ) if event_map is not None: - needed_events -= set(event_map.iterkeys()) + needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d conflicted events", len(needed_events)) @@ -668,7 +668,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): new_needed_events = set(itervalues(auth_events)) new_needed_events -= needed_events if event_map is not None: - new_needed_events -= set(event_map.iterkeys()) + new_needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d auth events", len(new_needed_events)) diff --git a/synapse/visibility.py b/synapse/visibility.py index 9b97ea2b83..ba0499a022 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import operator -import six +from six import iteritems, itervalues +from six.moves import map from twisted.internet import defer @@ -221,7 +222,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, return event # check each event: gives an iterable[None|EventBase] - filtered_events = itertools.imap(allowed, events) + filtered_events = map(allowed, events) # remove the None entries filtered_events = filter(operator.truth, filtered_events) @@ -261,7 +262,7 @@ def filter_events_for_server(store, server_name, events): # membership states for the requesting server to determine # if the server is either in the room or has been invited # into the room. - for ev in state.itervalues(): + for ev in itervalues(state): if ev.type != EventTypes.Member: continue try: @@ -295,7 +296,7 @@ def filter_events_for_server(store, server_name, events): ) visibility_ids = set() - for sids in event_to_state_ids.itervalues(): + for sids in itervalues(event_to_state_ids): hist = sids.get((EventTypes.RoomHistoryVisibility, "")) if hist: visibility_ids.add(hist) @@ -308,7 +309,7 @@ def filter_events_for_server(store, server_name, events): event_map = yield store.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in event_map.itervalues() + for e in itervalues(event_map) ) if all_open: @@ -346,7 +347,7 @@ def filter_events_for_server(store, server_name, events): # state_key_to_event_id_set = { e - for key_to_eid in six.itervalues(event_to_state_ids) + for key_to_eid in itervalues(event_to_state_ids) for e in key_to_eid.items() } @@ -369,10 +370,10 @@ def filter_events_for_server(store, server_name, events): event_to_state = { e_id: { key: event_map[inner_e_id] - for key, inner_e_id in key_to_eid.iteritems() + for key, inner_e_id in iteritems(key_to_eid) if inner_e_id in event_map } - for e_id, key_to_eid in event_to_state_ids.iteritems() + for e_id, key_to_eid in iteritems(event_to_state_ids) } defer.returnValue([ diff --git a/tests/test_federation.py b/tests/test_federation.py index 159a136971..f40ff29b52 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase): ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") - @unittest.DEBUG def test_cant_hide_past_history(self): """ If you send a message, you must be able to provide the direct @@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase): for x, y in d.items() if x == ("m.room.member", "@us:test") ], - "auth_chain_ids": d.values(), + "auth_chain_ids": list(d.values()), } ) -- cgit 1.5.1 From 8fbe418777a62ea6dd5fd811d63e30c42589650b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:33:49 +0100 Subject: Fix unit tests --- tests/replication/slave/storage/test_events.py | 8 +++-- tests/test_state.py | 47 +++++++++++++++++++------- 2 files changed, 40 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cea01d93eb..f5b47f5ec0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_ids = { key: e.event_id for key, e in state.items() } - context = EventContext() - context.current_state_ids = state_ids - context.prev_state_ids = state_ids + context = EventContext.with_state( + state_group=None, + current_state_ids=state_ids, + prev_state_ids=state_ids + ) else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) diff --git a/tests/test_state.py b/tests/test_state.py index c0f2d1152d..429a18cbf7 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].prev_state_ids)) + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertEqual(2, len(prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].prev_state_ids.values()} + {e_id for e_id in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(current_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.prev_state_ids.values()) + set(e.event_id for e in old_state), set(prev_state_ids.values()) ) @defer.inlineCallbacks @@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.prev_state_ids.values()) + set(prev_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) - self.assertEqual(len(context.current_state_ids), 6) + current_state_ids = yield context.get_current_state_ids(self.store) + + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) - self.assertEqual(len(context.current_state_ids), 6) + current_state_ids = yield context.get_current_state_ids(self.store) + + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - old_state_2[3].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - old_state_1[3].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, current_state_ids[("test1", "1")] ) def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, -- cgit 1.5.1