From bd348f0af66a82e91b71d4313a20798b8d33b832 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 7 Jun 2018 11:37:10 +0100 Subject: remove dead filter_events_for_clients This is only used by filter_events_for_client, so we can simplify the whole thing by just doing one user at a time, and removing a dead storage function to boot. --- tests/replication/slave/storage/test_account_data.py | 8 -------- 1 file changed, 8 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index da54d478ce..f47a42e45d 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -37,10 +37,6 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 1}} - ) yield self.master_store.add_account_data_for_user( USER_ID, TYPE, {"a": 2} @@ -50,7 +46,3 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 2}} - ) -- cgit 1.5.1 From a61738b316db70a4184d5c355696e0a039e7867f Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 14 Jun 2018 18:27:37 +1000 Subject: Remove run_on_reactor (#3395) --- synapse/federation/transaction_queue.py | 4 ---- synapse/handlers/auth.py | 8 ++------ synapse/handlers/federation.py | 6 +----- synapse/handlers/identity.py | 8 -------- synapse/handlers/message.py | 4 +--- synapse/handlers/register.py | 5 +---- synapse/push/pusherpool.py | 3 --- synapse/rest/client/v1/register.py | 7 ------- synapse/rest/client/v2_alpha/account.py | 7 ------- synapse/rest/client/v2_alpha/register.py | 3 --- synapse/util/async.py | 10 +--------- tests/test_distributor.py | 2 -- tests/util/caches/test_descriptors.py | 2 -- 13 files changed, 6 insertions(+), 63 deletions(-) (limited to 'tests') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f0aeb5a0d3..bcbce7f6eb 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -21,7 +21,6 @@ from .units import Transaction, Edu from synapse.api.errors import HttpResponseException, FederationDeniedError from synapse.util import logcontext, PreserveLoggingContext -from synapse.util.async import run_on_reactor from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func from synapse.handlers.presence import format_user_presence_state, get_interested_remotes @@ -451,9 +450,6 @@ class TransactionQueue(object): # hence why we throw the result away. yield get_retry_limiter(destination, self.clock, self.store) - # XXX: what's this for? - yield run_on_reactor() - pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 912136534d..dabc744890 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from twisted.internet import defer, threads from ._base import BaseHandler @@ -23,7 +24,6 @@ from synapse.api.errors import ( ) from synapse.module_api import ModuleApi from synapse.types import UserID -from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable @@ -423,15 +423,11 @@ class AuthHandler(BaseHandler): def _check_msisdn(self, authdict, _): return self._check_threepid('msisdn', authdict) - @defer.inlineCallbacks def _check_dummy_auth(self, authdict, _): - yield run_on_reactor() - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _check_threepid(self, medium, authdict): - yield run_on_reactor() - if 'threepid_creds' not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 495ac4c648..af94bf33bc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -39,7 +39,7 @@ from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError, logcontext from synapse.util.metrics import measure_func from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor, Linearizer +from synapse.util.async import Linearizer from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, @@ -1381,8 +1381,6 @@ class FederationHandler(BaseHandler): def get_state_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ - yield run_on_reactor() - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1425,8 +1423,6 @@ class FederationHandler(BaseHandler): def get_state_ids_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ - yield run_on_reactor() - state_groups = yield self.store.get_state_groups_ids( room_id, [event_id] ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 529400955d..f00dfe1d3e 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,7 +27,6 @@ from synapse.api.errors import ( MatrixCodeMessageException, CodeMessageException ) from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.api.errors import SynapseError, Codes logger = logging.getLogger(__name__) @@ -62,8 +61,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def threepid_from_creds(self, creds): - yield run_on_reactor() - if 'id_server' in creds: id_server = creds['id_server'] elif 'idServer' in creds: @@ -106,7 +103,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def bind_threepid(self, creds, mxid): - yield run_on_reactor() logger.debug("binding threepid %r to %s", creds, mxid) data = None @@ -188,8 +184,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): - yield run_on_reactor() - if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, @@ -224,8 +218,6 @@ class IdentityHandler(BaseHandler): self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): - yield run_on_reactor() - if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1cb81b6cf8..18dcc6d196 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -36,7 +36,7 @@ from synapse.events.validator import EventValidator from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) -from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter +from synapse.util.async import ReadWriteLock, Limiter from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func from synapse.util.frozenutils import frozendict_json_encoder @@ -959,9 +959,7 @@ class EventCreationHandler(object): event_stream_id, max_stream_id ) - @defer.inlineCallbacks def _notify(): - yield run_on_reactor() try: self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7e52adda3c..e76ef5426d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -24,7 +24,7 @@ from synapse.api.errors import ( from synapse.http.client import CaptchaServerHttpClient from synapse import types from synapse.types import UserID, create_requester, RoomID, RoomAlias -from synapse.util.async import run_on_reactor, Linearizer +from synapse.util.async import Linearizer from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler @@ -139,7 +139,6 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield run_on_reactor() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -431,8 +430,6 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield run_on_reactor() - if localpart is None: raise SynapseError(400, "Request must include user id") diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 750d11ca38..36bb5bbc65 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,7 +19,6 @@ import logging from twisted.internet import defer from synapse.push.pusher import PusherFactory -from synapse.util.async import run_on_reactor from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -125,7 +124,6 @@ class PusherPool: @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): - yield run_on_reactor() try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id @@ -151,7 +149,6 @@ class PusherPool: @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): - yield run_on_reactor() try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 9b3022e0b0..c10320dedf 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -24,8 +24,6 @@ import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request from synapse.types import create_requester -from synapse.util.async import run_on_reactor - from hashlib import sha1 import hmac import logging @@ -272,7 +270,6 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_password(self, request, register_json, session): - yield run_on_reactor() if (self.hs.config.enable_registration_captcha and not session[LoginType.RECAPTCHA]): # captcha should've been done by this stage! @@ -333,8 +330,6 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_shared_secret(self, request, register_json, session): - yield run_on_reactor() - if not isinstance(register_json.get("mac", None), string_types): raise SynapseError(400, "Expected mac.") if not isinstance(register_json.get("user", None), string_types): @@ -423,8 +418,6 @@ class CreateUserRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_create(self, requester, user_json): - yield run_on_reactor() - if "localpart" not in user_json: raise SynapseError(400, "Expected 'localpart' key.") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 0291fba9e7..e1281cfbb6 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -24,7 +24,6 @@ from synapse.http.servlet import ( RestServlet, assert_params_in_request, parse_json_object_from_request, ) -from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import check_3pid_allowed from ._base import client_v2_patterns, interactive_auth_handler @@ -300,8 +299,6 @@ class ThreepidRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield run_on_reactor() - requester = yield self.auth.get_user_by_req(request) threepids = yield self.datastore.user_get_threepids( @@ -312,8 +309,6 @@ class ThreepidRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) threePidCreds = body.get('threePidCreds') @@ -365,8 +360,6 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) required = ['medium', 'address'] diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5cab00aea9..97e7c0f7c6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -32,7 +32,6 @@ from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac from hashlib import sha1 -from synapse.util.async import run_on_reactor from synapse.util.ratelimitutils import FederationRateLimiter from six import string_types @@ -191,8 +190,6 @@ class RegisterRestServlet(RestServlet): @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) kind = "user" diff --git a/synapse/util/async.py b/synapse/util/async.py index 9dd4e6b5bc..b8e57efc54 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError from twisted.python import failure @@ -41,13 +40,6 @@ def sleep(seconds): defer.returnValue(res) -def run_on_reactor(): - """ This will cause the rest of the function to be invoked upon the next - iteration of the main loop - """ - return sleep(0) - - class ObservableDeferred(object): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original @@ -227,7 +219,7 @@ class Linearizer(object): # the context manager, but it needs to happen while we hold the # lock, and the context manager's exit code must be synchronous, # so actually this is the only sensible place. - yield run_on_reactor() + yield sleep(0) else: logger.info("Acquired uncontended linearizer lock %r for key %r", diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 010aeaee7e..c066381698 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -19,7 +19,6 @@ from twisted.internet import defer from mock import Mock, patch from synapse.util.distributor import Distributor -from synapse.util.async import run_on_reactor class DistributorTestCase(unittest.TestCase): @@ -95,7 +94,6 @@ class DistributorTestCase(unittest.TestCase): @defer.inlineCallbacks def observer(): - yield run_on_reactor() raise MyException("Oopsie") self.dist.observe("whail", observer) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2516fe40f4..24754591df 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -18,7 +18,6 @@ import logging import mock from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext from twisted.internet import defer from synapse.util.caches import descriptors @@ -195,7 +194,6 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() raise SynapseError(400, "blah") return inner_fn() -- cgit 1.5.1 From 5c9afd6f80cf04367fe9b02c396af9f85e02a611 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 14 Jun 2018 12:28:36 +0100 Subject: Make default state_default 50 Make it so that, before there is a power-levels event in the room, you need a power level of at least 50 to send state. Partially addresses https://github.com/matrix-org/matrix-doc/issues/1192 --- synapse/event_auth.py | 34 +++++------ tests/test_event_auth.py | 153 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 20 deletions(-) create mode 100644 tests/test_event_auth.py (limited to 'tests') diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 36a48870a0..f512d88145 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -370,28 +370,22 @@ def get_send_level(etype, state_key, power_levels_event): int: power level required to send this event. """ - send_level_event = power_levels_event # todo: rename refs below - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) + if power_levels_event: + power_levels_content = power_levels_event.content else: - send_level = 0 + power_levels_content = {} + + # see if we have a custom level for this event type + send_level = power_levels_content.get("events", {}).get(etype) + + # otherwise, fall back to the state_default/events_default. + if send_level is None: + if state_key is not None: + send_level = power_levels_content.get("state_default", 50) + else: + send_level = power_levels_content.get("events_default", 0) - return send_level + return int(send_level) def _can_send_event(event, auth_events): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py new file mode 100644 index 0000000000..3d756b0d85 --- /dev/null +++ b/tests/test_event_auth.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse import event_auth +from synapse.api.errors import AuthError +from synapse.events import FrozenEvent +import unittest + + +class EventAuthTestCase(unittest.TestCase): + @unittest.DEBUG + def test_random_users_cannot_send_state_before_first_pl(self): + """ + Check that, before the first PL lands, the creator is the only user + that can send a state event. + """ + creator = "@creator:example.com" + joiner = "@joiner:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.member", joiner): _join_event(joiner), + } + + # creator should be able to send state + event_auth.check( + _random_state_event(creator), auth_events, + do_sig_check=False, + ) + + # joiner should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(joiner), + auth_events, + do_sig_check=False, + ), + + @unittest.DEBUG + def test_state_default_level(self): + """ + Check that users above the state_default level can send state and + those below cannot + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + king = "@joiner2:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event(creator, { + "state_default": "30", + "users": { + pleb: "29", + king: "30", + }, + }), + ("m.room.member", pleb): _join_event(pleb), + ("m.room.member", king): _join_event(king), + } + + # pleb should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(pleb), + auth_events, + do_sig_check=False, + ), + + # king should be able to send state + event_auth.check( + _random_state_event(king), auth_events, + do_sig_check=False, + ) + + +# helpers for making events + +TEST_ROOM_ID = "!test:room" + + +def _create_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.create", + "sender": user_id, + "content": { + "creator": user_id, + }, + }) + + +def _join_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "content": { + "membership": "join", + }, + }) + + +def _power_levels_event(sender, content): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.power_levels", + "sender": sender, + "state_key": "", + "content": content, + }) + + +def _random_state_event(sender): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "test.state", + "sender": sender, + "state_key": "", + "content": { + "membership": "join", + }, + }) + + +event_count = 0 + + +def _get_event_id(): + global event_count + c = event_count + event_count += 1 + return "!%i:example.com" % (c, ) -- cgit 1.5.1 From a502cfec0054c91d314887771ec7d3aaa8491701 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 14 Jun 2018 14:20:53 +0100 Subject: remove spurious debug --- tests/test_event_auth.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 3d756b0d85..d08e19c53a 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -20,7 +20,6 @@ import unittest class EventAuthTestCase(unittest.TestCase): - @unittest.DEBUG def test_random_users_cannot_send_state_before_first_pl(self): """ Check that, before the first PL lands, the creator is the only user @@ -49,7 +48,6 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ), - @unittest.DEBUG def test_state_default_level(self): """ Check that users above the state_default level can send state and -- cgit 1.5.1 From 1e77ac66e38eca4b4da0c3f95efc3300504292e4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 14 Jun 2018 14:21:29 +0100 Subject: Fix broken unit test We need power levels for this test to do what it is supposed to do. --- tests/test_state.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_state.py b/tests/test_state.py index a5c5e55951..71c412faf4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase): } ) + power_levels = create_event( + type=EventTypes.PowerLevels, state_key="", + content={"users": { + "@foo:bar": "100", + "@user_id:example.com": "100", + }} + ) + creation = create_event( type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} @@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] @@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, context.current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] @@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, context.current_state_ids[("test1", "1")] ) def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, -- cgit 1.5.1 From 77ac14b960cb8daef76062ce85fc0427749b48af Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 22 Jun 2018 09:37:10 +0100 Subject: Pass around the reactor explicitly (#3385) --- synapse/handlers/auth.py | 30 +++++++++++++++++++---------- synapse/handlers/message.py | 1 + synapse/handlers/user_directory.py | 9 ++++----- synapse/http/client.py | 6 +++--- synapse/http/matrixfederationclient.py | 5 +++-- synapse/notifier.py | 3 +++ synapse/replication/http/send_event.py | 6 +++--- synapse/rest/media/v1/media_repository.py | 3 ++- synapse/rest/media/v1/media_storage.py | 7 +++++-- synapse/server.py | 19 ++++++++++++++---- synapse/storage/background_updates.py | 3 +-- synapse/storage/client_ips.py | 6 ++++-- synapse/storage/event_push_actions.py | 3 +-- synapse/storage/events_worker.py | 6 +++--- synapse/util/__init__.py | 32 +++++++++++++++++++------------ synapse/util/async.py | 25 +++++++++++------------- synapse/util/file_consumer.py | 16 +++++++++++----- synapse/util/ratelimitutils.py | 3 +-- tests/crypto/test_keyring.py | 9 +++++---- tests/rest/client/test_transactions.py | 6 +++--- tests/rest/media/v1/test_media_storage.py | 5 +++-- tests/util/test_file_consumer.py | 6 +++--- tests/util/test_linearizer.py | 7 ++++--- tests/util/test_logcontext.py | 11 ++++++----- tests/utils.py | 7 ++++++- 25 files changed, 141 insertions(+), 93 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dabc744890..a131b7f73f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -33,6 +33,7 @@ import logging import bcrypt import pymacaroons import simplejson +import attr import synapse.util.stringutils as stringutils @@ -854,7 +855,11 @@ class AuthHandler(BaseHandler): return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, bcrypt.gensalt(self.bcrypt_rounds)) - return make_deferred_yieldable(threads.deferToThread(_do_hash)) + return make_deferred_yieldable( + threads.deferToThreadPool( + self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash + ), + ) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -874,16 +879,21 @@ class AuthHandler(BaseHandler): ) if stored_hash: - return make_deferred_yieldable(threads.deferToThread(_do_validate_hash)) + return make_deferred_yieldable( + threads.deferToThreadPool( + self.hs.get_reactor(), + self.hs.get_reactor().getThreadPool(), + _do_validate_hash, + ), + ) else: return defer.succeed(False) -class MacaroonGeneartor(object): - def __init__(self, hs): - self.clock = hs.get_clock() - self.server_name = hs.config.server_name - self.macaroon_secret_key = hs.config.macaroon_secret_key +@attr.s +class MacaroonGenerator(object): + + hs = attr.ib() def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] @@ -901,7 +911,7 @@ class MacaroonGeneartor(object): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") - now = self.clock.time_msec() + now = self.hs.get_clock().time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() @@ -913,9 +923,9 @@ class MacaroonGeneartor(object): def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( - location=self.server_name, + location=self.hs.config.server_name, identifier="key", - key=self.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 18dcc6d196..7b9946ab91 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -806,6 +806,7 @@ class EventCreationHandler(object): # If we're a worker we need to hit out to the master. if self.config.worker_app: yield send_event_to_master( + self.hs.get_clock(), self.http_client, host=self.config.worker_replication_host, port=self.config.worker_replication_http_port, diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index a39f0f7343..7e4a114d4f 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,7 +19,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.storage.roommember import ProfileInfo from synapse.util.metrics import Measure -from synapse.util.async import sleep from synapse.types import get_localpart_from_id from six import iteritems @@ -174,7 +173,7 @@ class UserDirectoryHandler(object): logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) yield self._handle_initial_room(room_id) num_processed_rooms += 1 - yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) logger.info("Processed all rooms.") @@ -188,7 +187,7 @@ class UserDirectoryHandler(object): logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) yield self._handle_local_user(user_id) num_processed_users += 1 - yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.) logger.info("Processed all users") @@ -236,7 +235,7 @@ class UserDirectoryHandler(object): count = 0 for user_id in user_ids: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) if not self.is_mine_id(user_id): count += 1 @@ -251,7 +250,7 @@ class UserDirectoryHandler(object): continue if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) count += 1 user_set = (user_id, other_user_id) diff --git a/synapse/http/client.py b/synapse/http/client.py index 8064a84c5c..46ffb41de1 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -98,8 +98,8 @@ class SimpleHttpClient(object): method, uri, *args, **kwargs ) add_timeout_to_deferred( - request_deferred, - 60, cancelled_to_request_timed_out_error, + request_deferred, 60, self.hs.get_reactor(), + cancelled_to_request_timed_out_error, ) response = yield make_deferred_yieldable(request_deferred) @@ -115,7 +115,7 @@ class SimpleHttpClient(object): "Error sending request to %s %s: %s %s", method, redact_uri(uri), type(e).__name__, e.message ) - raise e + raise @defer.inlineCallbacks def post_urlencoded_get_json(self, uri, args={}, headers=None): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 993dc06e02..4e0399e762 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint import synapse.metrics -from synapse.util.async import sleep, add_timeout_to_deferred +from synapse.util.async import add_timeout_to_deferred from synapse.util import logcontext from synapse.util.logcontext import make_deferred_yieldable import synapse.util.retryutils @@ -193,6 +193,7 @@ class MatrixFederationHttpClient(object): add_timeout_to_deferred( request_deferred, timeout / 1000. if timeout else 60, + self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) response = yield make_deferred_yieldable( @@ -234,7 +235,7 @@ class MatrixFederationHttpClient(object): delay = min(delay, 2) delay *= random.uniform(0.8, 1.4) - yield sleep(delay) + yield self.clock.sleep(delay) retries_left -= 1 else: raise diff --git a/synapse/notifier.py b/synapse/notifier.py index 6dce20a284..3c0622a294 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -161,6 +161,7 @@ class Notifier(object): self.user_to_user_stream = {} self.room_to_user_streams = {} + self.hs = hs self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -340,6 +341,7 @@ class Notifier(object): add_timeout_to_deferred( listener.deferred, (end_time - now) / 1000., + self.hs.get_reactor(), ) with PreserveLoggingContext(): yield listener.deferred @@ -561,6 +563,7 @@ class Notifier(object): add_timeout_to_deferred( listener.deferred.addTimeout, (end_time - now) / 1000., + self.hs.get_reactor(), ) try: with PreserveLoggingContext(): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index a9baa2c1c3..f080f96cc1 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -21,7 +21,6 @@ from synapse.api.errors import ( from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.util.async import sleep from synapse.util.caches.response_cache import ResponseCache from synapse.util.metrics import Measure from synapse.types import Requester, UserID @@ -33,11 +32,12 @@ logger = logging.getLogger(__name__) @defer.inlineCallbacks -def send_event_to_master(client, host, port, requester, event, context, +def send_event_to_master(clock, client, host, port, requester, event, context, ratelimit, extra_users): """Send event to be handled on the master Args: + clock (synapse.util.Clock) client (SimpleHttpClient) host (str): host of master port (int): port on master listening for HTTP replication @@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context, # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. - yield sleep(1) + yield clock.sleep(1) except MatrixCodeMessageException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 2ac767d2dc..218ba7a083 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository(object): def __init__(self, hs): + self.hs = hs self.auth = hs.get_auth() self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() @@ -94,7 +95,7 @@ class MediaRepository(object): storage_providers.append(provider) self.media_storage = MediaStorage( - self.primary_base_path, self.filepaths, storage_providers, + self.hs, self.primary_base_path, self.filepaths, storage_providers, ) self.clock.looping_call( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index d23fe10b07..d6b8ebbedb 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -37,13 +37,15 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: + hs (synapse.server.Homeserver) local_media_directory (str): Base path where we store media on disk filepaths (MediaFilePaths) storage_providers ([StorageProvider]): List of StorageProvider that are used to fetch and store files. """ - def __init__(self, local_media_directory, filepaths, storage_providers): + def __init__(self, hs, local_media_directory, filepaths, storage_providers): + self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers @@ -175,7 +177,8 @@ class MediaStorage(object): res = yield provider.fetch(path, file_info) if res: with res: - consumer = BackgroundFileConsumer(open(local_path, "w")) + consumer = BackgroundFileConsumer( + open(local_path, "w"), self.hs.get_reactor()) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) diff --git a/synapse/server.py b/synapse/server.py index 58dbf78437..c29c19289a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler, MacaroonGeneartor +from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler @@ -165,15 +165,19 @@ class HomeServer(object): 'server_notices_sender', ] - def __init__(self, hostname, **kwargs): + def __init__(self, hostname, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. """ + if not reactor: + from twisted.internet import reactor + + self._reactor = reactor self.hostname = hostname self._building = {} - self.clock = Clock() + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -186,6 +190,12 @@ class HomeServer(object): self.datastore = DataStore(self.get_db_conn(), self) logger.info("Finished setting up.") + def get_reactor(self): + """ + Fetch the Twisted reactor in use by this HomeServer. + """ + return self._reactor + def get_ip_from_request(self, request): # X-Forwarded-For is handled by our custom request type. return request.getClientIP() @@ -261,7 +271,7 @@ class HomeServer(object): return AuthHandler(self) def build_macaroon_generator(self): - return MacaroonGeneartor(self) + return MacaroonGenerator(self) def build_device_handler(self): return DeviceHandler(self) @@ -328,6 +338,7 @@ class HomeServer(object): return adbapi.ConnectionPool( name, + cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) ) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 8af325a9f5..b7e9c716c8 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.async from ._base import SQLBaseStore from . import engines @@ -92,7 +91,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Starting background schema updates") while True: - yield synapse.util.async.sleep( + yield self.hs.get_clock().sleep( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) try: diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index ce338514e8..968d2fed22 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,7 +15,7 @@ import logging -from twisted.internet import defer, reactor +from twisted.internet import defer from ._base import Cache from . import background_updates @@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._client_ip_looper = self._clock.looping_call( self._update_client_ips_batch, 5 * 1000 ) - reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, now=None): diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d0350ee5fe..c4a0208ce4 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -16,7 +16,6 @@ from synapse.storage._base import SQLBaseStore, LoggingTransaction from twisted.internet import defer -from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks import logging @@ -800,7 +799,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) if caught_up: break - yield sleep(5) + yield self.hs.get_clock().sleep(5) finally: self._doing_notif_rotation = False diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 32d9d00ffb..38fcf7d444 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.events import FrozenEvent from synapse.events.utils import prune_event @@ -265,7 +265,7 @@ class EventsWorkerStore(SQLBaseStore): except Exception: logger.exception("Failed to callback") with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) + self.hs.get_reactor().callFromThread(fire, event_list, row_dict) except Exception as e: logger.exception("do_fetch") @@ -278,7 +278,7 @@ class EventsWorkerStore(SQLBaseStore): if event_list: with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index fc11e26623..2a3df7c71d 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import PreserveLoggingContext - -from twisted.internet import defer, reactor, task - -import time import logging - from itertools import islice +import attr +from twisted.internet import defer, task + +from synapse.util.logcontext import PreserveLoggingContext + logger = logging.getLogger(__name__) @@ -31,16 +30,24 @@ def unwrapFirstError(failure): return failure.value.subFailure +@attr.s class Clock(object): - """A small utility that obtains current time-of-day so that time may be - mocked during unit-tests. - - TODO(paul): Also move the sleep() functionality into it """ + A Clock wraps a Twisted reactor and provides utilities on top of it. + """ + _reactor = attr.ib() + + @defer.inlineCallbacks + def sleep(self, seconds): + d = defer.Deferred() + with PreserveLoggingContext(): + self._reactor.callLater(seconds, d.callback, seconds) + res = yield d + defer.returnValue(res) def time(self): """Returns the current system time in seconds since epoch.""" - return time.time() + return self._reactor.seconds() def time_msec(self): """Returns the current system time in miliseconds since epoch.""" @@ -56,6 +63,7 @@ class Clock(object): msec(float): How long to wait between calls in milliseconds. """ call = task.LoopingCall(f) + call.clock = self._reactor call.start(msec / 1000.0, now=False) return call @@ -73,7 +81,7 @@ class Clock(object): callback(*args, **kwargs) with PreserveLoggingContext(): - return reactor.callLater(delay, wrapped_callback, *args, **kwargs) + return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer, ignore_errs=False): try: diff --git a/synapse/util/async.py b/synapse/util/async.py index b8e57efc54..1668df4ce6 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor +from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure from .logcontext import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background ) -from synapse.util import logcontext, unwrapFirstError +from synapse.util import logcontext, unwrapFirstError, Clock from contextlib import contextmanager @@ -31,15 +31,6 @@ from six.moves import range logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def sleep(seconds): - d = defer.Deferred() - with PreserveLoggingContext(): - reactor.callLater(seconds, d.callback, seconds) - res = yield d - defer.returnValue(res) - - class ObservableDeferred(object): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original @@ -172,13 +163,18 @@ class Linearizer(object): # do some work. """ - def __init__(self, name=None): + def __init__(self, name=None, clock=None): if name is None: self.name = id(self) else: self.name = name self.key_to_defer = {} + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock + @defer.inlineCallbacks def queue(self, key): # If there is already a deferred in the queue, we pull it out so that @@ -219,7 +215,7 @@ class Linearizer(object): # the context manager, but it needs to happen while we hold the # lock, and the context manager's exit code must be synchronous, # so actually this is the only sensible place. - yield sleep(0) + yield self._clock.sleep(0) else: logger.info("Acquired uncontended linearizer lock %r for key %r", @@ -396,7 +392,7 @@ class DeferredTimeoutError(Exception): """ -def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None): +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): """ Add a timeout to a deferred by scheduling it to be cancelled after timeout seconds. @@ -411,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None): Args: deferred (defer.Deferred): deferred to be timed out timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 3380970e4e..c78801015b 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import threads, reactor +from twisted.internet import threads from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -27,6 +27,7 @@ class BackgroundFileConsumer(object): Args: file_obj (file): The file like object to write to. Closed when finished. + reactor (twisted.internet.reactor): the Twisted reactor to use """ # For PushProducers pause if we have this many unwritten slices @@ -34,9 +35,11 @@ class BackgroundFileConsumer(object): # And resume once the size of the queue is less than this _RESUME_ON_QUEUE_SIZE = 2 - def __init__(self, file_obj): + def __init__(self, file_obj, reactor): self._file_obj = file_obj + self._reactor = reactor + # Producer we're registered with self._producer = None @@ -71,7 +74,10 @@ class BackgroundFileConsumer(object): self._producer = producer self.streaming = streaming self._finished_deferred = run_in_background( - threads.deferToThread, self._writer + threads.deferToThreadPool, + self._reactor, + self._reactor.getThreadPool(), + self._writer, ) if not streaming: self._producer.resumeProducing() @@ -109,7 +115,7 @@ class BackgroundFileConsumer(object): # producer. if self._producer and self._paused_producer: if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: - reactor.callFromThread(self._resume_paused_producer) + self._reactor.callFromThread(self._resume_paused_producer) bytes = self._bytes_queue.get() @@ -121,7 +127,7 @@ class BackgroundFileConsumer(object): # If its a pull producer then we need to explicitly ask for # more stuff. if not self.streaming and self._producer: - reactor.callFromThread(self._producer.resumeProducing) + self._reactor.callFromThread(self._producer.resumeProducing) except Exception as e: self._write_exception = e raise diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 0ab63c3d7d..c5a45cef7c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError -from synapse.util.async import sleep from synapse.util.logcontext import ( run_in_background, make_deferred_yieldable, PreserveLoggingContext, @@ -153,7 +152,7 @@ class _PerHostRatelimiter(object): "Ratelimit [%s]: sleeping req", id(request_id), ) - ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0) + ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0) self.sleeping_requests.add(request_id) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 149e443022..cc1c862ba4 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,10 +19,10 @@ import signedjson.sign from mock import Mock from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.util import async, logcontext +from synapse.util import logcontext, Clock from synapse.util.logcontext import LoggingContext from tests import unittest, utils -from twisted.internet import defer +from twisted.internet import defer, reactor class MockPerspectiveServer(object): @@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_verify_json_objects_for_server_awaits_previous_requests(self): + clock = Clock(reactor) key1 = signedjson.key.generate_signing_key(1) kr = keyring.Keyring(self.hs) @@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase): # wait a tick for it to send the request to the perspectives server # (it first tries the datastore) - yield async.sleep(1) # XXX find out why this takes so long! + yield clock.sleep(1) # XXX find out why this takes so long! self.http_client.post_json.assert_called_once() self.assertIs(LoggingContext.current_context(), context_11) @@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(1) + yield clock.sleep(1) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index b5bc2fa255..6a757289db 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,9 +1,9 @@ from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import CLEANUP_PERIOD_MS -from twisted.internet import defer +from twisted.internet import defer, reactor from mock import Mock, call -from synapse.util import async +from synapse.util import Clock from synapse.util.logcontext import LoggingContext from tests import unittest from tests.utils import MockClock @@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test_logcontexts_with_async_result(self): @defer.inlineCallbacks def cb(): - yield async.sleep(0) + yield Clock(reactor).sleep(0) defer.returnValue("yay") @defer.inlineCallbacks diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index eef38b6781..c5e2f5549a 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -14,7 +14,7 @@ # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import MediaStorage @@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase): self.secondary_base_path = os.path.join(self.test_dir, "secondary") hs = Mock() + hs.get_reactor = Mock(return_value=reactor) hs.config.media_store_path = self.primary_base_path storage_providers = [FileStorageProviderBackend( @@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase): self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - self.primary_base_path, self.filepaths, storage_providers, + hs, self.primary_base_path, self.filepaths, storage_providers, ) def tearDown(self): diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index d6e1082779..c2aae8f54c 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_pull_consumer(self): string_file = StringIO() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = DummyPullProducer() @@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_consumer(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=[]) @@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_producer_feedback(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4865eb4bc6..bf7e3aa885 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import async, logcontext + +from synapse.util import logcontext, Clock from tests import unittest -from twisted.internet import defer +from twisted.internet import defer, reactor from synapse.util.async import Linearizer from six.moves import range @@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase): self.assertEqual( logcontext.LoggingContext.current_context(), lc) if sleep: - yield async.sleep(0) + yield Clock(reactor).sleep(0) self.assertEqual( logcontext.LoggingContext.current_context(), lc) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index ad78d884e0..9cf90fcfc4 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -3,8 +3,7 @@ from twisted.internet import defer from twisted.internet import reactor from .. import unittest -from synapse.util.async import sleep -from synapse.util import logcontext +from synapse.util import logcontext, Clock from synapse.util.logcontext import LoggingContext @@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_sleep(self): + clock = Clock(reactor) + @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: competing_context.request = "competing" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: context_one.request = "one" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("one") def _test_run_in_background(self, function): @@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_run_in_background_with_blocking_fn(self): @defer.inlineCallbacks def blocking_function(): - yield sleep(0) + yield Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) diff --git a/tests/utils.py b/tests/utils.py index 262c4a5714..189fd2711c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks -def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): +def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None, + **kargs): """Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. If no datastore is supplied a datastore backed by an in-memory sqlite db will be given to the HS. """ + if reactor is None: + from twisted.internet import reactor + if config is None: config = Mock() config.signing_key = [MockKey()] @@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) db_conn = hs.get_db_conn() -- cgit 1.5.1 From 43e02c409d163700a293ae67015584699d557c3c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 11 Jun 2018 23:13:06 +0100 Subject: Disable partial state group caching for wildcard lookups When _get_state_for_groups is given a wildcard filter, just do a complete lookup. Hopefully this will give us the best of both worlds by not filling up the ram if we only need one or two keys, but also making the cache still work for the federation reader usecase. --- synapse/storage/state.py | 56 +++++++++++++++++++++++++-------- synapse/util/caches/dictionary_cache.py | 25 +++++++-------- tests/util/test_dict_cache.py | 12 +++---- 3 files changed, 61 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 986a20400c..cd9821c270 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -526,10 +526,23 @@ class StateGroupWorkerStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): - """Given list of groups returns dict of group -> list of state events - with matching types. `types` is a list of `(type, state_key)`, where - a `state_key` of None matches all state_keys. If `types` is None then - all events are returned. + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. """ if types: types = frozenset(types) @@ -538,7 +551,7 @@ class StateGroupWorkerStore(SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, _, got_all = self._get_some_state_from_cache( - group, types + group, types, ) results[group] = state_dict_ids @@ -559,22 +572,40 @@ class StateGroupWorkerStore(SQLBaseStore): # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence + # the DictionaryCache knows if it has *all* the state, but + # does not know if it has all of the keys of a particular type, + # which makes wildcard lookups expensive unless we have a complete + # cache. Hence, if we are doing a wildcard lookup, populate the + # cache fully so that we can do an efficient lookup next time. + + if types and any(k is None for (t, k) in types): + types_to_fetch = None + else: + types_to_fetch = types + group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types + missing_groups, types_to_fetch, ) - # Now we want to update the cache with all the things we fetched - # from the database. for group, group_state_dict in iteritems(group_to_state_dict): state_dict = results[group] - state_dict.update(group_state_dict) + # update the result, filtering by `types`. + if types: + for k, v in iteritems(group_state_dict): + (typ, _) = k + if k in types or (typ, None) in types: + state_dict[k] = v + else: + state_dict.update(group_state_dict) + + # update the cache with all the things we fetched from the + # database. self._state_group_cache.update( cache_seq_num, key=group, - value=state_dict, - full=(types is None), - known_absent=types, + value=group_state_dict, + fetched_keys=types_to_fetch, ) defer.returnValue(results) @@ -681,7 +712,6 @@ class StateGroupWorkerStore(SQLBaseStore): self._state_group_cache.sequence, key=state_group, value=dict(current_state_ids), - full=True, ) return state_group diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index bdc21e348f..95793d466d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -107,29 +107,28 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False, known_absent=None): + def update(self, sequence, key, value, fetched_keys=None): """Updates the entry in the cache Args: sequence - key - value (dict): The value to update the cache with. - full (bool): Whether the given value is the full dict, or just a - partial subset there of. If not full then any existing entries - for the key will be updated. - known_absent (set): Set of keys that we know don't exist in the full - dict. + key (K) + value (dict[X,Y]): The value to update the cache with. + fetched_keys (None|set[X]): All of the dictionary keys which were + fetched from the database. + + If None, this is the complete value for key K. Otherwise, it + is used to infer a list of keys which we know don't exist in + the full dict. """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - if known_absent is None: - known_absent = set() - if full: - self._insert(key, value, known_absent) + if fetched_keys is None: + self._insert(key, value, set()) else: - self._update_or_insert(key, value, known_absent) + self._update_or_insert(key, value, fetched_keys) def _update_or_insert(self, key, value, known_absent): # We pop and reinsert as we need to tell the cache the size may have diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index bc92f85fa6..543ac5bed9 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase): seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key) self.assertEqual(test_value, c.value) @@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_hit_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) @@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_miss_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) @@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase): "test2": "test_simple_cache_hit_miss_partial2", "test3": "test_simple_cache_hit_miss_partial3", } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) @@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase): test_value_1 = { "test": "test_simple_cache_hit_miss_partial", } - self.cache.update(seq, key, test_value_1, full=False) + self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) seq = self.cache.sequence test_value_2 = { "test2": "test_simple_cache_hit_miss_partial2", } - self.cache.update(seq, key, test_value_2, full=False) + self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) c = self.cache.get(key) self.assertEqual( -- cgit 1.5.1 From cd6bcdaf87f6c68e0c95b789c8fcb144a0d64b1a Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 10:37:24 +0100 Subject: Better testing framework for homeserver-using things (#3446) --- changelog.d/3446.misc | 0 tests/server.py | 181 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_server.py | 128 +++++++++++++++++++++++++++++++++++ 3 files changed, 309 insertions(+) create mode 100644 changelog.d/3446.misc create mode 100644 tests/server.py create mode 100644 tests/test_server.py (limited to 'tests') diff --git a/changelog.d/3446.misc b/changelog.d/3446.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000000..73069dff52 --- /dev/null +++ b/tests/server.py @@ -0,0 +1,181 @@ +from io import BytesIO + +import attr +import json +from six import text_type + +from twisted.python.failure import Failure +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.site import SynapseRequest +from twisted.internet import threads +from tests.utils import setup_test_homeserver as _sth + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + result = attr.ib(factory=dict) + + @property + def json_body(self): + if not self.result: + raise Exception("No result yet.") + return json.loads(self.result["body"]) + + def writeHeaders(self, version, code, reason, headers): + self.result["version"] = version + self.result["code"] = code + self.result["reason"] = reason + self.result["headers"] = headers + + def write(self, content): + if "body" not in self.result: + self.result["body"] = b"" + + self.result["body"] += content + + def requestDone(self, _self): + self.result["done"] = True + + def getPeer(self): + return None + + def getHost(self): + return None + + @property + def transport(self): + return self + + +class FakeSite: + """ + A fake Twisted Web Site, with mocks of the extra things that + Synapse adds. + """ + + server_version_string = b"1" + site_tag = "test" + + @property + def access_logger(self): + class FakeLogger: + def info(self, *args, **kwargs): + pass + + return FakeLogger() + + +def make_request(method, path, content=b""): + """ + Make a web request using the given method and path, feed it the + content, and return the Request and the Channel underneath. + """ + + if isinstance(content, text_type): + content = content.encode('utf8') + + site = FakeSite() + channel = FakeChannel() + + req = SynapseRequest(site, channel) + req.process = lambda: b"" + req.content = BytesIO(content) + req.requestReceived(method, path, b"1.1") + + return req, channel + + +def wait_until_result(clock, channel, timeout=100): + """ + Wait until the channel has a result. + """ + clock.run() + x = 0 + + while not channel.result: + x += 1 + + if x > timeout: + raise Exception("Timed out waiting for request to finish.") + + clock.advance(0.1) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + return d + + +def setup_test_homeserver(*args, **kwargs): + """ + Set up a synchronous test server, driven by the reactor used by + the homeserver. + """ + d = _sth(*args, **kwargs).result + + # Make the thread pool synchronous. + clock = d.get_clock() + pool = d.get_db_pool() + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) + + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + + class ThreadPool: + """ + Threadless thread pool. + """ + def start(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + clock._reactor.callLater(0, d.callback, True) + return d + + clock.threadpool = ThreadPool() + pool.threadpool = ThreadPool() + return d diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000000..8ad822c43b --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,128 @@ +import json +import re + +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.util import Clock +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import JsonResource +from tests import unittest +from tests.server import make_request, setup_test_homeserver + + +class JsonResourceTests(unittest.TestCase): + def setUp(self): + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.reactor + ) + + def test_handler_for_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + got_kwargs = {} + + def _callback(request, **kwargs): + got_kwargs.update(kwargs) + return (200, kwargs) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo/(?P[^/]*)$")], _callback) + + request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") + request.render(res) + + self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) + self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + + def test_callback_direct_exception(self): + """ + If the web callback raises an uncaught exception, it will be translated + into a 500. + """ + + def _callback(request, **kwargs): + raise Exception("boo") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'500') + + def test_callback_indirect_exception(self): + """ + If the web callback raises an uncaught exception in a Deferred, it will + be translated into a 500. + """ + + def _throw(*args): + raise Exception("boo") + + def _callback(request, **kwargs): + d = Deferred() + d.addCallback(_throw) + self.reactor.callLater(1, d.callback, True) + return d + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + # No error has been raised yet + self.assertTrue("code" not in channel.result) + + # Advance time, now there's an error + self.reactor.advance(1) + self.assertEqual(channel.result["code"], b'500') + + def test_callback_synapseerror(self): + """ + If the web callback raises a SynapseError, it returns the appropriate + status code and message set in it. + """ + + def _callback(request, **kwargs): + raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'403') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Forbidden!!one!") + self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + + def test_no_handler(self): + """ + If there is no handler to process the request, Synapse will return 400. + """ + + def _callback(request, **kwargs): + """ + Not ever actually called! + """ + self.fail("shouldn't ever get here") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foobar") + request.render(res) + + self.assertEqual(channel.result["code"], b'400') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Unrecognized request") + self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") -- cgit 1.5.1 From 77078d6c8ef11e2401406edc1ca340e0d7779267 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 11:27:32 +0100 Subject: handle federation not telling us about prev_events --- synapse/federation/federation_server.py | 4 +- synapse/handlers/federation.py | 87 ++++++++---- tests/test_federation.py | 235 ++++++++++++++++++++++++++++++++ tests/unittest.py | 2 +- 4 files changed, 301 insertions(+), 27 deletions(-) create mode 100644 tests/test_federation.py (limited to 'tests') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d4dd967c60..4096093527 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -549,7 +549,9 @@ class FederationServer(FederationBase): affected=pdu.event_id, ) - yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) + yield self.handler.on_receive_pdu( + origin, pdu, get_missing=True, sent_to_us_directly=True, + ) def __str__(self): return "" % self.server_name diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b6f8d4cf82..250a5509d8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -44,6 +44,7 @@ from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, ) +from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.events.utils import prune_event @@ -89,7 +90,9 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_receive_pdu(self, origin, pdu, get_missing=True): + def on_receive_pdu( + self, origin, pdu, get_missing=True, sent_to_us_directly=False, + ): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -163,7 +166,7 @@ class FederationHandler(BaseHandler): "Ignoring PDU %s for room %s from %s as we've left the room!", pdu.event_id, pdu.room_id, origin, ) - return + defer.returnValue(None) state = None @@ -225,26 +228,54 @@ class FederationHandler(BaseHandler): list(prevs - seen)[:5], ) - if prevs - seen: - logger.info( - "Still missing %d events for room %r: %r...", - len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + if sent_to_us_directly and prevs - seen: + # If they have sent it to us directly, and the server + # isn't telling us about the auth events that it's + # made a message referencing, we explode + raise FederationError( + "ERROR", + 403, + ("Your server isn't divulging details about prev_events " + "referenced in this event."), + affected=pdu.event_id, ) - fetch_state = True + elif prevs - seen: + # If we're walking back up the chain to fetch it, then + # try and find the states. If we can't get the states, + # discard it. + state_groups = [] + auth_chains = set() + try: + # Get the ones we know about + ours = yield self.store.get_state_groups(pdu.room_id, list(seen)) + state_groups.append(ours) - if fetch_state: - # We need to get the state at this event, since we haven't - # processed all the prev events. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - try: - state, auth_chain = yield self.replication_layer.get_state_for_room( - origin, pdu.room_id, pdu.event_id, - ) - except Exception: - logger.exception("Failed to get state for event: %s", pdu.event_id) + for p in prevs - seen: + state, auth_chain = yield self.replication_layer.get_state_for_room( + origin, pdu.room_id, p + ) + auth_chains.update(auth_chain) + state_group = { + (x.type, x.state_key): x.event_id for x in state + } + state_groups.append(state_group) + + def fetch(ev_ids): + return self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + + state = yield resolve_events_with_factory(state_groups, {pdu.event_id: pdu}, fetch) + + state = yield self.store.get_events(state.values()) + state = state.values() + except Exception: + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=pdu.event_id, + ) yield self._process_received_pdu( origin, @@ -322,11 +353,17 @@ class FederationHandler(BaseHandler): for e in missing_events: logger.info("Handling found event %s", e.event_id) - yield self.on_receive_pdu( - origin, - e, - get_missing=False - ) + try: + yield self.on_receive_pdu( + origin, + e, + get_missing=False + ) + except FederationError as e: + if e.code == 403: + logger.warn("Event %s failed history check.") + else: + raise @log_function @defer.inlineCallbacks diff --git a/tests/test_federation.py b/tests/test_federation.py new file mode 100644 index 0000000000..12f4633cd5 --- /dev/null +++ b/tests/test_federation.py @@ -0,0 +1,235 @@ + +from twisted.internet.defer import Deferred, succeed, maybeDeferred + +from synapse.util import Clock +from synapse.events import FrozenEvent +from synapse.types import Requester, UserID +from synapse.replication.slave.storage.events import SlavedEventStore + +from tests import unittest +from tests.server import make_request, setup_test_homeserver, ThreadedMemoryReactorClock + +from mock import Mock + +from synapse.api.errors import CodeMessageException, HttpResponseException + + +class MessageAcceptTests(unittest.TestCase): + def setUp(self): + + self.http_client = Mock() + self.reactor = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + ) + + user_id = UserID("us", "test") + our_user = Requester(user_id, None, False, None, None) + room_creator = self.homeserver.get_room_creation_handler() + room = room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) + self.reactor.advance(0.1) + self.room_id = self.successResultOf(room)["room_id"] + + # Figure out what the most recent event is + most_recent = self.successResultOf( + self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + )[0] + + join_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$join:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "prev_state": [(most_recent, {})], + "prev_events": [(most_recent, {})], + } + ) + + self.handler = self.homeserver.get_handlers().federation_handler + self.handler.do_auth = lambda *a, **b: succeed(True) + self.client = self.homeserver.get_federation_client() + self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( + pdus + ) + + # Send the join, it should return None (which is not an error) + d = self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + # Make sure we actually joined the room + self.assertEqual( + self.successResultOf( + self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + )[0], + "$join:test.serv", + ) + + def test_cant_hide_direct_ancestors(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + # If it asks us for new missing events, give them NOTHING + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return {"events": []} + + self.http_client.post_json = post_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + )[0] + + # Now lie about an event + lying_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("two:test.serv", {}), (most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) + + # Step the reactor, so the database fetches come back + self.reactor.advance(1) + + # on_receive_pdu should throw an error + failure = self.failureResultOf(d) + self.assertEqual( + failure.value.args[0], + ( + "ERROR 403: Your server isn't divulging details about prev_events " + "referenced in this event." + ), + ) + + # Make sure the invalid event isn't there + extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + @unittest.DEBUG + def test_cant_hide_past_history(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return { + "events": [ + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "three:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("four:test.serv", {})], + } + ] + } + + self.http_client.post_json = post_json + + def get_json(destination, path, args, headers=None): + print(destination, path, args) + if path.startswith("/_matrix/federation/v1/state_ids/"): + d = self.successResultOf( + self.homeserver.datastore.get_state_ids_for_event("one:test.serv") + ) + + return succeed( + { + "pdu_ids": [ + y + for x, y in d.items() + if x == ("m.room.member", "@us:test") + ], + "auth_chain_ids": d.values(), + } + ) + + self.http_client.get_json = get_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + )[0] + + # Make a good event + good_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [(most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", good_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + bad_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "two:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", bad_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + + extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") + + state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) + self.reactor.advance(1) + self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/unittest.py b/tests/unittest.py index 184fe880f3..de24b1d2d4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -35,7 +35,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace('warning', 'warn') - self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry) + self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")")) handler = ToTwistedHandler() -- cgit 1.5.1 From 8d62baa48cb222c3010007fdd6e48673f5cd0519 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 11:31:48 +0100 Subject: cleanups --- tests/test_federation.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/test_federation.py b/tests/test_federation.py index 12f4633cd5..bc8b3af9b3 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -35,7 +35,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) )[0] join_event = FrozenEvent( @@ -72,7 +72,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -92,7 +92,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -129,7 +129,7 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") @unittest.DEBUG @@ -161,7 +161,6 @@ class MessageAcceptTests(unittest.TestCase): self.http_client.post_json = post_json def get_json(destination, path, args, headers=None): - print(destination, path, args) if path.startswith("/_matrix/federation/v1/state_ids/"): d = self.successResultOf( self.homeserver.datastore.get_state_ids_for_event("one:test.serv") @@ -182,7 +181,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) )[0] # Make a good event @@ -227,7 +226,7 @@ class MessageAcceptTests(unittest.TestCase): ) self.reactor.advance(1) - extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id) + extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) -- cgit 1.5.1 From 35cc3e8b143f69abadbd41f82e463fbcd3528346 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 11:32:09 +0100 Subject: stylistic cleanup --- tests/test_federation.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/test_federation.py b/tests/test_federation.py index bc8b3af9b3..95fa73723c 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -35,7 +35,9 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) )[0] join_event = FrozenEvent( @@ -72,7 +74,9 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) )[0], "$join:test.serv", ) @@ -92,7 +96,9 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) )[0] # Now lie about an event @@ -129,7 +135,9 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") @unittest.DEBUG @@ -181,7 +189,9 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) )[0] # Make a good event @@ -226,7 +236,9 @@ class MessageAcceptTests(unittest.TestCase): ) self.reactor.advance(1) - extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id) + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) -- cgit 1.5.1 From 94f09618e54e8ae0a30611f0da463d275768ab74 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 11:38:03 +0100 Subject: cleanups --- tests/unittest.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index de24b1d2d4..b25f2db5d5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace('warning', 'warn') - self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")")) + self.tx_log.emit( + twisted.logger.LogLevel.levelWithName(log_level), + log_entry.replace("{", r"(").replace("}", r")"), + ) handler = ToTwistedHandler() -- cgit 1.5.1 From f03a5d1a1759221bd94d75604f3e4e787cd4133e Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 27 Jun 2018 11:38:14 +0100 Subject: pep8 --- synapse/handlers/federation.py | 10 ++++------ tests/test_federation.py | 7 ++----- 2 files changed, 6 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e9c5d1026a..c4d96749a3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -169,11 +169,8 @@ class FederationHandler(BaseHandler): defer.returnValue(None) state = None - auth_chain = [] - fetch_state = False - # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. @@ -252,9 +249,10 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about for p in prevs - seen: - state, got_auth_chain = yield self.replication_layer.get_state_for_room( - origin, pdu.room_id, p - ) + state, got_auth_chain = ( + yield self.replication_layer.get_state_for_room( + origin, pdu.room_id, p + )) auth_chains.update(got_auth_chain) state_group = {(x.type, x.state_key): x.event_id for x in state} state_groups.append(state_group) diff --git a/tests/test_federation.py b/tests/test_federation.py index 95fa73723c..fc80a69369 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,18 +1,15 @@ -from twisted.internet.defer import Deferred, succeed, maybeDeferred +from twisted.internet.defer import succeed, maybeDeferred from synapse.util import Clock from synapse.events import FrozenEvent from synapse.types import Requester, UserID -from synapse.replication.slave.storage.events import SlavedEventStore from tests import unittest -from tests.server import make_request, setup_test_homeserver, ThreadedMemoryReactorClock +from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock from mock import Mock -from synapse.api.errors import CodeMessageException, HttpResponseException - class MessageAcceptTests(unittest.TestCase): def setUp(self): -- cgit 1.5.1 From e72234f6bda33d89dcca07751e34c62b88215e9d Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 28 Jun 2018 20:56:07 +0100 Subject: fix tests --- synapse/config/appservice.py | 1 + tests/api/test_auth.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 89c07f202f..0c27bb2fa7 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -157,6 +157,7 @@ def _load_appservice(hostname, as_info, config_filename): config_filename, ) + ip_range_whitelist = None if as_info.get('ip_range_whitelist'): ip_range_whitelist = IPSet( as_info.get('ip_range_whitelist') diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4575dd9834..48bd411e49 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -86,11 +86,15 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) @@ -119,12 +123,16 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -133,12 +141,16 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() -- cgit 1.5.1 From f82cf3c7dfdcdbcf076dde1835796f2b274077c5 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 28 Jun 2018 21:14:16 +0100 Subject: add test --- tests/api/test_auth.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) (limited to 'tests') diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 48bd411e49..aec3b62897 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -100,6 +100,39 @@ class AuthTestCase(unittest.TestCase): requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) + @defer.inlineCallbacks + def test_get_user_by_req_appservice_valid_token_good_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "192.168.10.10" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) + + def test_get_user_by_req_appservice_valid_token_bad_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "131.111.8.42" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + d = self.auth.get_user_by_req(request) + self.failureResultOf(d, AuthError) + def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) -- cgit 1.5.1 From 508196e08a834496daa1bfc5f561e69a430e270c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 3 Jul 2018 14:36:14 +0100 Subject: Reject invalid server names (#3480) Make sure that server_names used in auth headers are sane, and reject them with a sensible error code, before they disappear off into the depths of the system. --- changelog.d/3480.feature | 1 + synapse/federation/transport/server.py | 66 ++++++++++++++++++++++------------ synapse/http/endpoint.py | 34 ++++++++++++++++-- tests/http/__init__.py | 0 tests/http/test_endpoint.py | 46 ++++++++++++++++++++++++ 5 files changed, 122 insertions(+), 25 deletions(-) create mode 100644 changelog.d/3480.feature create mode 100644 tests/http/__init__.py create mode 100644 tests/http/test_endpoint.py (limited to 'tests') diff --git a/changelog.d/3480.feature b/changelog.d/3480.feature new file mode 100644 index 0000000000..a21580943d --- /dev/null +++ b/changelog.d/3480.feature @@ -0,0 +1 @@ +Reject invalid server names in federation requests diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 19d09f5422..1180d4b69d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError, FederationDeniedError +from synapse.http.endpoint import parse_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, @@ -99,26 +100,6 @@ class Authenticator(object): origin = None - def parse_auth_header(header_str): - try: - params = auth.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith("\""): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return (origin, key, sig) - except Exception: - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: @@ -127,8 +108,8 @@ class Authenticator(object): ) for auth in auth_headers: - if auth.startswith("X-Matrix"): - (origin, key, sig) = parse_auth_header(auth) + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig @@ -165,6 +146,47 @@ class Authenticator(object): logger.exception("Error resetting retry timings on %s", origin) +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode('utf-8') + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith(b"\""): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + # ensure that the origin is a valid server name + parse_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warn( + "Error parsing auth header '%s': %s", + header_bytes.decode('ascii', 'replace'), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + ) + + class BaseFederationServlet(object): REQUIRE_AUTH = True diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 80da870584..5a9cbb3324 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -38,6 +38,36 @@ _Server = collections.namedtuple( ) +def parse_server_name(server_name): + """Split a server name into host/port parts. + + Does some basic sanity checking of the + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == ']': + # ipv6 literal, hopefully + if server_name[0] != '[': + raise Exception() + + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout (int): connection timeout in seconds """ - domain_port = destination.split(":") - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None + domain, port = parse_server_name(destination) endpoint_kw_args = {} diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py new file mode 100644 index 0000000000..cd74825c85 --- /dev/null +++ b/tests/http/test_endpoint.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.http.endpoint import parse_server_name +from tests import unittest + + +class ServerNameTestCase(unittest.TestCase): + def test_parse_server_name(self): + test_data = { + 'localhost': ('localhost', None), + 'my-example.com:1234': ('my-example.com', 1234), + '1.2.3.4': ('1.2.3.4', None), + '[0abc:1def::1234]': ('[0abc:1def::1234]', None), + '1.2.3.4:1': ('1.2.3.4', 1), + '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + } + + for i, o in test_data.items(): + self.assertEqual(parse_server_name(i), o) + + def test_parse_bad_server_names(self): + test_data = [ + "", # empty + "localhost:http", # non-numeric port + "1234]", # smells like ipv6 literal but isn't + ] + for i in test_data: + try: + parse_server_name(i) + self.fail( + "Expected parse_server_name(\"%s\") to throw" % i, + ) + except ValueError: + pass -- cgit 1.5.1 From ea555d56331ad01edc9871ec7bf879df7d24dc7d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 4 Jul 2018 09:35:40 +0100 Subject: Reinstate lost run_on_reactor in unit test a61738b removed a call to run_on_reactor from a unit test, but that call was doing something useful, in making the function in question asynchronous. Reinstate the call and add a check that we are testing what we wanted to be testing. --- changelog.d/3385.misc | 1 + tests/util/caches/test_descriptors.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 changelog.d/3385.misc (limited to 'tests') diff --git a/changelog.d/3385.misc b/changelog.d/3385.misc new file mode 100644 index 0000000000..92a91a1ca5 --- /dev/null +++ b/changelog.d/3385.misc @@ -0,0 +1 @@ +Reinstate lost run_on_reactor in unit tests diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 24754591df..a94d566c96 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -19,13 +19,19 @@ import logging import mock from synapse.api.errors import SynapseError from synapse.util import logcontext -from twisted.internet import defer +from twisted.internet import defer, reactor from synapse.util.caches import descriptors from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + class CacheTestCase(unittest.TestCase): def test_invalidate_all(self): cache = descriptors.Cache("testcache") @@ -194,6 +200,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -203,7 +211,12 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass -- cgit 1.5.1 From 546bc9e28b3d7758c732df8e120639d58d455164 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 4 Jul 2018 18:15:03 +0100 Subject: More server_name validation We need to do a bit more validation when we get a server name, but don't want to be re-doing it all over the shop, so factor out a separate parse_and_validate_server_name, and do the extra validation. Also, use it to verify the server name in the config file. --- changelog.d/3483.feature | 1 + synapse/config/server.py | 11 ++++++-- synapse/federation/transport/server.py | 5 ++-- synapse/http/endpoint.py | 47 ++++++++++++++++++++++++++++++---- tests/http/test_endpoint.py | 17 +++++++++--- 5 files changed, 68 insertions(+), 13 deletions(-) create mode 100644 changelog.d/3483.feature (limited to 'tests') diff --git a/changelog.d/3483.feature b/changelog.d/3483.feature new file mode 100644 index 0000000000..afa2fbbcba --- /dev/null +++ b/changelog.d/3483.feature @@ -0,0 +1 @@ +Reject invalid server names in homeserver.yaml \ No newline at end of file diff --git a/synapse/config/server.py b/synapse/config/server.py index 968ecd9ea0..71fd51e4bc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -16,6 +16,7 @@ import logging +from synapse.http.endpoint import parse_and_validate_server_name from ._base import Config, ConfigError logger = logging.Logger(__name__) @@ -25,6 +26,12 @@ class ServerConfig(Config): def read_config(self, config): self.server_name = config["server_name"] + + try: + parse_and_validate_server_name(self.server_name) + except ValueError as e: + raise ConfigError(str(e)) + self.pid_file = self.abspath(config.get("pid_file")) self.web_client = config["web_client"] self.web_client_location = config.get("web_client_location", None) @@ -162,8 +169,8 @@ class ServerConfig(Config): }) def default_config(self, server_name, **kwargs): - if ":" in server_name: - bind_port = int(server_name.split(":")[1]) + _, bind_port = parse_and_validate_server_name(server_name) + if bind_port is not None: unsecure_port = bind_port - 400 else: bind_port = 8448 diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 1180d4b69d..e1fdcc89dc 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError, FederationDeniedError -from synapse.http.endpoint import parse_server_name +from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, @@ -170,8 +170,9 @@ def _parse_auth_header(header_bytes): return value origin = strip_quotes(param_dict["origin"]) + # ensure that the origin is a valid server name - parse_server_name(origin) + parse_and_validate_server_name(origin) key = strip_quotes(param_dict["key"]) sig = strip_quotes(param_dict["sig"]) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 5a9cbb3324..1b1123b292 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer from twisted.internet.error import ConnectError @@ -41,8 +43,6 @@ _Server = collections.namedtuple( def parse_server_name(server_name): """Split a server name into host/port parts. - Does some basic sanity checking of the - Args: server_name (str): server name to parse @@ -55,9 +55,6 @@ def parse_server_name(server_name): try: if server_name[-1] == ']': # ipv6 literal, hopefully - if server_name[0] != '[': - raise Exception() - return server_name, None domain_port = server_name.rsplit(":", 1) @@ -68,6 +65,46 @@ def parse_server_name(server_name): raise ValueError("Invalid server name '%s'" % server_name) +VALID_HOST_REGEX = re.compile( + "\\A[0-9a-zA-Z.-]+\\Z", +) + + +def parse_and_validate_server_name(server_name): + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == '[': + if host[-1] != ']': + raise ValueError("Mismatched [...] in server name '%s'" % ( + server_name, + )) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' contains invalid characters" % ( + server_name, + )) + + return host, port + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index cd74825c85..b8a48d20a4 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_server_name +from synapse.http.endpoint import ( + parse_server_name, + parse_and_validate_server_name, +) from tests import unittest @@ -30,17 +33,23 @@ class ServerNameTestCase(unittest.TestCase): for i, o in test_data.items(): self.assertEqual(parse_server_name(i), o) - def test_parse_bad_server_names(self): + def test_validate_bad_server_names(self): test_data = [ "", # empty "localhost:http", # non-numeric port "1234]", # smells like ipv6 literal but isn't + "[1234", + "underscore_.com", + "percent%65.com", + "1234:5678:80", # too many colons ] for i in test_data: try: - parse_server_name(i) + parse_and_validate_server_name(i) self.fail( - "Expected parse_server_name(\"%s\") to throw" % i, + "Expected parse_and_validate_server_name('%s') to throw" % ( + i, + ), ) except ValueError: pass -- cgit 1.5.1 From 3cf3e08a97f4617763ce10da4f127c0e21d7ff1d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 4 Jul 2018 15:31:00 +0100 Subject: Implementation of server_acls ... as described at https://docs.google.com/document/d/1EttUVzjc2DWe2ciw4XPtNpUpIl9lWXGEsy2ewDS7rtw. --- synapse/api/constants.py | 2 + synapse/federation/federation_server.py | 150 ++++++++++++++++++++++++++++- synapse/federation/transport/server.py | 8 +- tests/federation/__init__.py | 0 tests/federation/test_federation_server.py | 57 +++++++++++ 5 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 tests/federation/__init__.py create mode 100644 tests/federation/test_federation_server.py (limited to 'tests') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 5baba43966..4df930c8d1 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -76,6 +76,8 @@ class EventTypes(object): Topic = "m.room.topic" Name = "m.room.name" + ServerACL = "m.room.server_acl" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index fe51ba6806..591d0026bf 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,10 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re from canonicaljson import json +import six from twisted.internet import defer +from twisted.internet.abstract import isIPAddress +from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError from synapse.crypto.event_signing import compute_event_signature from synapse.federation.federation_base import ( @@ -27,6 +31,7 @@ from synapse.federation.federation_base import ( from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction +from synapse.http.endpoint import parse_server_name from synapse.types import get_domain_from_id from synapse.util import async from synapse.util.caches.response_cache import ResponseCache @@ -74,6 +79,9 @@ class FederationServer(FederationBase): @log_function def on_backfill_request(self, origin, room_id, versions, limit): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + pdus = yield self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -134,6 +142,8 @@ class FederationServer(FederationBase): received_pdus_counter.inc(len(transaction.pdus)) + origin_host, _ = parse_server_name(transaction.origin) + pdus_by_room = {} for p in transaction.pdus: @@ -154,9 +164,21 @@ class FederationServer(FederationBase): # we can process different rooms in parallel (which is useful if they # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. + @defer.inlineCallbacks def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) + try: + yield self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warn( + "Ignoring PDUs for room %s from banned server", room_id, + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + for pdu in pdus_by_room[room_id]: event_id = pdu.event_id try: @@ -211,6 +233,9 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -234,6 +259,9 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -298,7 +326,9 @@ class FederationServer(FederationBase): defer.returnValue((200, resp)) @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): + def on_make_join_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -306,6 +336,8 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_invite_request(self, origin, content): pdu = event_from_pdu_json(content) + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @@ -314,6 +346,10 @@ class FederationServer(FederationBase): def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", content) pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -325,7 +361,9 @@ class FederationServer(FederationBase): })) @defer.inlineCallbacks - def on_make_leave_request(self, room_id, user_id): + def on_make_leave_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_leave_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -334,6 +372,10 @@ class FederationServer(FederationBase): def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @@ -341,6 +383,9 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) res = { @@ -369,6 +414,9 @@ class FederationServer(FederationBase): Deferred: Results in `dict` with the same format as `content` """ with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + auth_chain = [ event_from_pdu_json(e) for e in content["auth_chain"] @@ -442,6 +490,9 @@ class FederationServer(FederationBase): def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d, min_depth: %d", @@ -579,6 +630,101 @@ class FederationServer(FederationBase): ) defer.returnValue(ret) + @defer.inlineCallbacks + def check_server_matches_acl(self, server_name, room_id): + """Check if the given server is allowed by the server ACLs in the room + + Args: + server_name (str): name of server, *without any port part* + room_id (str): ID of the room to check + + Raises: + AuthError if the server does not match the ACL + """ + state_ids = yield self.store.get_current_state_ids(room_id) + acl_event_id = state_ids.get((EventTypes.ServerACL, "")) + + if not acl_event_id: + return + + acl_event = yield self.store.get_event(acl_event_id) + if server_matches_acl_event(server_name, acl_event): + return + + raise AuthError(code=403, msg="Server is banned from room") + + +def server_matches_acl_event(server_name, acl_event): + """Check if the given server is allowed by the ACL event + + Args: + server_name (str): name of server, without any port part + acl_event (EventBase): m.room.server_acl event + + Returns: + bool: True if this server is allowed by the ACLs + """ + logger.debug("Checking %s against acl %s", server_name, acl_event.content) + + # first of all, check if literal IPs are blocked, and if so, whether the + # server name is a literal IP + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warn("Ignorning non-bool allow_ip_literals flag") + allow_ip_literals = True + if not allow_ip_literals: + # check for ipv6 literals. These start with '['. + if server_name[0] == '[': + return False + + # check for ipv4 literals. We can just lift the routine from twisted. + if isIPAddress(server_name): + return False + + # next, check the deny list + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warn("Ignorning non-list deny ACL %s", deny) + deny = [] + for e in deny: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched deny rule %s", server_name, e) + return False + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warn("Ignorning non-list allow ACL %s", allow) + allow = [] + for e in allow: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched allow rule %s", server_name, e) + return True + + # everything else should be rejected. + # logger.info("%s fell through", server_name) + return False + + +def _acl_entry_matches(server_name, acl_entry): + if not isinstance(acl_entry, six.string_types): + logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + return False + regex = _glob_to_regex(acl_entry) + return regex.match(server_name) + + +def _glob_to_regex(glob): + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + return re.compile(res + "\\Z", re.IGNORECASE) + class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index e1fdcc89dc..c6d98d35cb 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -385,7 +385,9 @@ class FederationMakeJoinServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_join_request(context, user_id) + content = yield self.handler.on_make_join_request( + origin, context, user_id, + ) defer.returnValue((200, content)) @@ -394,7 +396,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request(context, user_id) + content = yield self.handler.on_make_leave_request( + origin, context, user_id, + ) defer.returnValue((200, content)) diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py new file mode 100644 index 0000000000..4e8dc8fea0 --- /dev/null +++ b/tests/federation/test_federation_server.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.events import FrozenEvent +from synapse.federation.federation_server import server_matches_acl_event +from tests import unittest + + +@unittest.DEBUG +class ServerACLsTestCase(unittest.TestCase): + def test_blacklisted_server(self): + e = _create_acl_event({ + "allow": ["*"], + "deny": ["evil.com"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("evil.com", e)) + self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + + self.assertTrue(server_matches_acl_event("evil.com.au", e)) + self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + + def test_block_ip_literals(self): + e = _create_acl_event({ + "allow_ip_literals": False, + "allow": ["*"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("1.2.3.4", e)) + self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) + self.assertFalse(server_matches_acl_event("[1:2::]", e)) + self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + + +def _create_acl_event(content): + return FrozenEvent({ + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content + }) -- cgit 1.5.1 From 49af4020190eae6b0c65897d96cd2be286364d2b Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 9 Jul 2018 16:09:20 +1000 Subject: run isort --- synapse/api/auth.py | 7 +-- synapse/api/errors.py | 4 +- synapse/api/filtering.py | 14 ++--- synapse/api/urls.py | 2 +- synapse/app/__init__.py | 4 +- synapse/app/_base.py | 11 ++-- synapse/app/appservice.py | 5 +- synapse/app/client_reader.py | 5 +- synapse/app/event_creator.py | 9 ++-- synapse/app/federation_reader.py | 5 +- synapse/app/federation_sender.py | 5 +- synapse/app/frontend_proxy.py | 9 ++-- synapse/app/homeserver.py | 33 +++++++----- synapse/app/media_repository.py | 9 ++-- synapse/app/pusher.py | 5 +- synapse/app/synchrotron.py | 11 ++-- synapse/app/synctl.py | 5 +- synapse/app/user_dir.py | 5 +- synapse/appservice/__init__.py | 12 ++--- synapse/appservice/api.py | 14 ++--- synapse/appservice/scheduler.py | 4 +- synapse/config/_base.py | 3 +- synapse/config/api.py | 4 +- synapse/config/appservice.py | 14 ++--- synapse/config/homeserver.py | 36 ++++++------- synapse/config/jwt.py | 1 - synapse/config/key.py | 21 ++++---- synapse/config/logger.py | 6 ++- synapse/config/password_auth_providers.py | 4 +- synapse/config/registration.py | 4 +- synapse/config/repository.py | 2 +- synapse/config/server.py | 1 + synapse/config/server_notices_config.py | 3 +- synapse/config/tls.py | 11 ++-- synapse/crypto/context_factory.py | 6 +-- synapse/crypto/event_signing.py | 10 ++-- synapse/crypto/keyclient.py | 14 ++--- synapse/crypto/keyring.py | 44 ++++++++-------- synapse/event_auth.py | 6 +-- synapse/events/__init__.py | 3 +- synapse/events/builder.py | 5 +- synapse/events/snapshot.py | 4 +- synapse/events/utils.py | 9 ++-- synapse/events/validator.py | 8 +-- synapse/federation/federation_base.py | 7 +-- synapse/federation/federation_client.py | 14 ++--- synapse/federation/federation_server.py | 18 +++---- synapse/federation/persistence.py | 5 +- synapse/federation/send_queue.py | 14 ++--- synapse/federation/transaction_queue.py | 30 +++++------ synapse/federation/transport/client.py | 9 ++-- synapse/federation/transport/server.py | 21 ++++---- synapse/federation/units.py | 3 +- synapse/groups/attestations.py | 5 +- synapse/groups/groups_server.py | 7 +-- synapse/handlers/__init__.py | 10 ++-- synapse/handlers/_base.py | 3 +- synapse/handlers/admin.py | 4 +- synapse/handlers/appservice.py | 13 +++-- synapse/handlers/auth.py | 26 +++++----- synapse/handlers/deactivate_account.py | 7 +-- synapse/handlers/device.py | 16 +++--- synapse/handlers/devicemessage.py | 3 +- synapse/handlers/directory.py | 9 ++-- synapse/handlers/e2e_keys.py | 10 ++-- synapse/handlers/events.py | 15 +++--- synapse/handlers/federation.py | 40 ++++++++------- synapse/handlers/groups_local.py | 7 +-- synapse/handlers/identity.py | 7 ++- synapse/handlers/initial_sync.py | 9 ++-- synapse/handlers/message.py | 23 ++++----- synapse/handlers/presence.py | 21 ++++---- synapse/handlers/profile.py | 3 +- synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 10 ++-- synapse/handlers/register.py | 11 ++-- synapse/handlers/room.py | 19 +++---- synapse/handlers/room_list.py | 20 ++++---- synapse/handlers/room_member.py | 14 +++-- synapse/handlers/room_member_worker.py | 5 +- synapse/handlers/search.py | 17 +++---- synapse/handlers/set_password.py | 1 + synapse/handlers/sync.py | 24 ++++----- synapse/handlers/typing.py | 11 ++-- synapse/handlers/user_directory.py | 7 +-- synapse/http/additional_resource.py | 3 +- synapse/http/client.py | 48 +++++++++--------- synapse/http/endpoint.py | 12 ++--- synapse/http/matrixfederationclient.py | 48 +++++++++--------- synapse/http/request_metrics.py | 2 +- synapse/http/server.py | 38 +++++++------- synapse/http/servlet.py | 3 +- synapse/http/site.py | 2 +- synapse/metrics/__init__.py | 11 ++-- synapse/notifier.py | 21 ++++---- synapse/push/action_generator.py | 6 +-- synapse/push/baserules.py | 3 +- synapse/push/bulk_push_rule_evaluator.py | 17 ++++--- synapse/push/clientformat.py | 6 +-- synapse/push/emailpusher.py | 7 ++- synapse/push/httppusher.py | 6 +-- synapse/push/mailer.py | 33 ++++++------ synapse/push/presentable_names.py | 6 +-- synapse/push/push_rule_evaluator.py | 4 +- synapse/push/push_tools.py | 5 +- synapse/push/pusher.py | 3 +- synapse/replication/http/__init__.py | 1 - synapse/replication/http/membership.py | 4 +- synapse/replication/http/send_event.py | 12 +++-- synapse/replication/slave/storage/_base.py | 4 +- synapse/replication/slave/storage/appservice.py | 3 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/replication/slave/storage/deviceinbox.py | 7 +-- synapse/replication/slave/storage/devices.py | 5 +- synapse/replication/slave/storage/directory.py | 3 +- synapse/replication/slave/storage/events.py | 3 +- synapse/replication/slave/storage/filtering.py | 3 +- synapse/replication/slave/storage/groups.py | 5 +- synapse/replication/slave/storage/keys.py | 3 +- synapse/replication/slave/storage/presence.py | 8 +-- synapse/replication/slave/storage/push_rule.py | 5 +- synapse/replication/slave/storage/pushers.py | 4 +- synapse/replication/slave/storage/receipts.py | 4 +- synapse/replication/slave/storage/registration.py | 3 +- synapse/replication/slave/storage/room.py | 3 +- synapse/replication/slave/storage/transactions.py | 3 +- synapse/replication/tcp/client.py | 9 ++-- synapse/replication/tcp/protocol.py | 40 +++++++++------ synapse/replication/tcp/resource.py | 17 ++++--- synapse/replication/tcp/streams.py | 5 +- synapse/rest/__init__.py | 55 ++++++++------------ synapse/rest/client/v1/admin.py | 10 ++-- synapse/rest/client/v1/base.py | 10 ++-- synapse/rest/client/v1/directory.py | 9 ++-- synapse/rest/client/v1/events.py | 8 +-- synapse/rest/client/v1/initial_sync.py | 1 + synapse/rest/client/v1/login.py | 30 +++++------ synapse/rest/client/v1/logout.py | 5 +- synapse/rest/client/v1/presence.py | 13 ++--- synapse/rest/client/v1/profile.py | 5 +- synapse/rest/client/v1/push_rule.py | 16 +++--- synapse/rest/client/v1/pusher.py | 15 +++--- synapse/rest/client/v1/register.py | 19 +++---- synapse/rest/client/v1/room.py | 24 +++++---- synapse/rest/client/v1/voip.py | 9 ++-- synapse/rest/client/v2_alpha/account.py | 5 +- synapse/rest/client/v2_alpha/account_data.py | 10 ++-- synapse/rest/client/v2_alpha/auth.py | 5 +- synapse/rest/client/v2_alpha/devices.py | 1 + synapse/rest/client/v2_alpha/filter.py | 10 ++-- synapse/rest/client/v2_alpha/groups.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 7 ++- synapse/rest/client/v2_alpha/notifications.py | 11 ++-- synapse/rest/client/v2_alpha/openid.py | 10 ++-- synapse/rest/client/v2_alpha/read_marker.py | 6 +-- synapse/rest/client/v2_alpha/receipts.py | 6 +-- synapse/rest/client/v2_alpha/register.py | 22 ++++---- synapse/rest/client/v2_alpha/report_event.py | 6 +-- synapse/rest/client/v2_alpha/sync.py | 27 +++++----- synapse/rest/client/v2_alpha/tags.py | 10 ++-- synapse/rest/client/v2_alpha/thirdparty.py | 1 + synapse/rest/client/v2_alpha/user_directory.py | 1 + synapse/rest/client/versions.py | 4 +- synapse/rest/consent/consent_resource.py | 12 ++--- synapse/rest/key/v1/server_key_resource.py | 10 ++-- synapse/rest/key/v2/__init__.py | 1 + synapse/rest/key/v2/local_key_resource.py | 10 ++-- synapse/rest/key/v2/remote_key_resource.py | 16 +++--- synapse/rest/media/v0/content_repository.py | 19 +++---- synapse/rest/media/v1/_base.py | 19 +++---- synapse/rest/media/v1/download_resource.py | 6 +-- synapse/rest/media/v1/filepath.py | 2 +- synapse/rest/media/v1/identicon_resource.py | 1 + synapse/rest/media/v1/media_repository.py | 51 ++++++++++--------- synapse/rest/media/v1/media_storage.py | 17 +++---- synapse/rest/media/v1/preview_url_resource.py | 25 ++++----- synapse/rest/media/v1/storage_provider.py | 11 ++-- synapse/rest/media/v1/thumbnail_resource.py | 11 ++-- synapse/rest/media/v1/thumbnailer.py | 4 +- synapse/rest/media/v1/upload_resource.py | 5 +- synapse/server.py | 34 ++++++------- synapse/server_notices/consent_server_notices.py | 3 +- synapse/state.py | 22 ++++---- synapse/storage/__init__.py | 59 ++++++++++------------ synapse/storage/_base.py | 19 ++++--- synapse/storage/account_data.py | 13 +++-- synapse/storage/appservice.py | 6 ++- synapse/storage/background_updates.py | 10 ++-- synapse/storage/client_ips.py | 9 ++-- synapse/storage/deviceinbox.py | 3 +- synapse/storage/devices.py | 11 ++-- synapse/storage/directory.py | 9 ++-- synapse/storage/end_to_end_keys.py | 8 +-- synapse/storage/engines/__init__.py | 7 ++- synapse/storage/engines/sqlite3.py | 4 +- synapse/storage/event_federation.py | 16 +++--- synapse/storage/event_push_actions.py | 11 ++-- synapse/storage/events.py | 34 ++++++------- synapse/storage/events_worker.py | 25 +++++---- synapse/storage/filtering.py | 7 +-- synapse/storage/group_server.py | 5 +- synapse/storage/keys.py | 14 ++--- synapse/storage/prepare_database.py | 1 - synapse/storage/presence.py | 10 ++-- synapse/storage/profile.py | 2 +- synapse/storage/push_rule.py | 18 ++++--- synapse/storage/pusher.py | 9 ++-- synapse/storage/receipts.py | 16 +++--- synapse/storage/registration.py | 6 +-- synapse/storage/rejections.py | 4 +- synapse/storage/room.py | 12 ++--- synapse/storage/roommember.py | 19 ++++--- synapse/storage/schema/delta/25/fts.py | 6 +-- synapse/storage/schema/delta/27/ts.py | 4 +- synapse/storage/schema/delta/30/as_users.py | 2 +- synapse/storage/schema/delta/31/search_update.py | 7 +-- synapse/storage/schema/delta/33/event_fields.py | 5 +- synapse/storage/schema/delta/33/remote_media_ts.py | 1 - synapse/storage/schema/delta/34/cache_stream.py | 6 +-- .../storage/schema/delta/34/received_txn_purge.py | 4 +- synapse/storage/schema/delta/34/sent_txn_purge.py | 4 +- synapse/storage/schema/delta/37/remove_auth_idx.py | 6 +-- synapse/storage/schema/delta/42/user_dir.py | 2 +- synapse/storage/search.py | 8 +-- synapse/storage/signatures.py | 8 +-- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 16 +++--- synapse/storage/tags.py | 10 ++-- synapse/storage/transactions.py | 11 ++-- synapse/storage/user_directory.py | 14 ++--- synapse/storage/user_erasure_store.py | 2 +- synapse/storage/util/id_generators.py | 2 +- synapse/streams/config.py | 5 +- synapse/streams/events.py | 7 ++- synapse/types.py | 3 +- synapse/util/__init__.py | 1 + synapse/util/async.py | 18 ++++--- synapse/util/caches/__init__.py | 6 +-- synapse/util/caches/descriptors.py | 22 ++++---- synapse/util/caches/dictionary_cache.py | 9 ++-- synapse/util/caches/expiringcache.py | 5 +- synapse/util/caches/lrucache.py | 2 +- synapse/util/caches/stream_change_cache.py | 5 +- synapse/util/file_consumer.py | 4 +- synapse/util/frozenutils.py | 6 +-- synapse/util/httpresourcetree.py | 4 +- synapse/util/logcontext.py | 6 +-- synapse/util/logformatter.py | 3 +- synapse/util/logutils.py | 8 ++- synapse/util/manhole.py | 6 +-- synapse/util/metrics.py | 8 +-- synapse/util/msisdn.py | 1 + synapse/util/ratelimitutils.py | 13 +++-- synapse/util/retryutils.py | 9 ++-- synapse/util/rlimit.py | 3 +- synapse/util/stringutils.py | 1 + synapse/util/versionstring.py | 4 +- synapse/visibility.py | 4 +- tests/__init__.py | 1 + tests/api/test_auth.py | 7 ++- tests/api/test_filtering.py | 15 +++--- tests/appservice/test_appservice.py | 9 ++-- tests/appservice/test_scheduler.py | 15 ++++-- tests/config/test_generate.py | 1 + tests/config/test_load.py | 3 ++ tests/crypto/test_event_signing.py | 10 ++-- tests/crypto/test_keyring.py | 10 ++-- tests/events/test_utils.py | 4 +- tests/federation/test_federation_server.py | 1 + tests/handlers/test_appservice.py | 8 +-- tests/handlers/test_auth.py | 2 + tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 6 +-- tests/handlers/test_e2e_keys.py | 5 +- tests/handlers/test_presence.py | 12 +++-- tests/handlers/test_profile.py | 6 +-- tests/handlers/test_register.py | 5 +- tests/handlers/test_typing.py | 19 ++++--- tests/http/test_endpoint.py | 6 +-- tests/replication/slave/storage/_base.py | 15 +++--- .../replication/slave/storage/test_account_data.py | 4 +- tests/replication/slave/storage/test_events.py | 5 +- tests/replication/slave/storage/test_receipts.py | 4 +- tests/rest/client/test_transactions.py | 7 +-- tests/rest/client/v1/test_events.py | 6 +-- tests/rest/client/v1/test_profile.py | 5 +- tests/rest/client/v1/test_register.py | 10 ++-- tests/rest/client/v1/test_rooms.py | 11 ++-- tests/rest/client/v1/test_typing.py | 7 ++- tests/rest/client/v1/utils.py | 10 ++-- tests/rest/client/v2_alpha/__init__.py | 7 ++- tests/rest/client/v2_alpha/test_filter.py | 11 ++-- tests/rest/client/v2_alpha/test_register.py | 11 ++-- tests/rest/media/v1/test_media_storage.py | 14 ++--- tests/server.py | 10 ++-- tests/storage/test__base.py | 8 +-- tests/storage/test_appservice.py | 20 +++++--- tests/storage/test_background_update.py | 6 +-- tests/storage/test_base.py | 8 +-- tests/storage/test_devices.py | 1 + tests/storage/test_directory.py | 4 +- tests/storage/test_event_push_actions.py | 3 +- tests/storage/test_keys.py | 1 + tests/storage/test_presence.py | 4 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 8 +-- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_roommember.py | 8 +-- tests/storage/test_user_directory.py | 1 + tests/test_distributor.py | 7 +-- tests/test_dns.py | 7 +-- tests/test_event_auth.py | 3 +- tests/test_federation.py | 10 ++-- tests/test_preview.py | 7 +-- tests/test_server.py | 3 +- tests/test_state.py | 10 ++-- tests/test_test_utils.py | 1 - tests/test_types.py | 6 +-- tests/util/caches/test_descriptors.py | 7 ++- tests/util/test_dict_cache.py | 4 +- tests/util/test_expiring_cache.py | 4 +- tests/util/test_file_consumer.py | 9 ++-- tests/util/test_limiter.py | 4 +- tests/util/test_linearizer.py | 7 +-- tests/util/test_logcontext.py | 8 +-- tests/util/test_logformatter.py | 1 + tests/util/test_lrucache.py | 4 +- tests/util/test_rwlock.py | 4 +- tests/util/test_snapshot_cache.py | 5 +- tests/util/test_stream_change_cache.py | 3 +- tests/util/test_treecache.py | 4 +- tests/util/test_wheel_timer.py | 4 +- tests/utils.py | 3 +- 334 files changed, 1609 insertions(+), 1528 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 088b4e8b6d..6dec862fec 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -18,15 +18,16 @@ import logging from six import itervalues import pymacaroons -from twisted.internet import defer from netaddr import IPAddress +from twisted.internet import defer + import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes from synapse.types import UserID -from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR +from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 227a0713b2..6074df292f 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,11 +17,11 @@ import logging -from canonicaljson import json - from six import iteritems from six.moves import http_client +from canonicaljson import json + logger = logging.getLogger(__name__) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index aae25e7a47..25346baa87 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, RoomID -from twisted.internet import defer - -from canonicaljson import json - import jsonschema +from canonicaljson import json from jsonschema import FormatChecker +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage.presence import UserPresenceState +from synapse.types import RoomID, UserID + FILTER_SCHEMA = { "additionalProperties": False, "type": "object", diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bb46b5da8a..71347912f1 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -15,8 +15,8 @@ # limitations under the License. """Contains the URL paths to prefix various aspects of the server with. """ -from hashlib import sha256 import hmac +from hashlib import sha256 from six.moves.urllib.parse import urlencode diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 9c2b627590..3b6b9368b8 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -14,9 +14,11 @@ # limitations under the License. import sys + +from synapse import python_dependencies # noqa: E402 + sys.dont_write_bytecode = True -from synapse import python_dependencies # noqa: E402 try: python_dependencies.check_requirements() diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a6925ab139..391bd14c5c 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -17,15 +17,18 @@ import gc import logging import sys +from daemonize import Daemonize + +from twisted.internet import error, reactor + +from synapse.util import PreserveLoggingContext +from synapse.util.rlimit import change_resource_limit + try: import affinity except Exception: affinity = None -from daemonize import Daemonize -from synapse.util import PreserveLoggingContext -from synapse.util.rlimit import change_resource_limit -from twisted.internet import error, reactor logger = logging.getLogger(__name__) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 4319ddce03..9a37384fb7 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -36,8 +39,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor, defer -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.appservice") diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 654ddb8414..b0ea26dcb4 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -44,8 +47,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.client_reader") diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 441467093a..374f115644 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -43,8 +46,10 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import ( - RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet, JoinRoomAliasServlet, + RoomMembershipRestServlet, + RoomSendEventRestServlet, + RoomStateEventRestServlet, ) from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -52,8 +57,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.event_creator") diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index b2415cc671..7af00b8bcf 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.api.urls import FEDERATION_PREFIX @@ -41,8 +44,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.federation_reader") diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 13d2b70053..18469013fa 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -42,8 +45,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.federation_sender") diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index d2bae4ad03..b5f78f4640 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.api.errors import SynapseError @@ -25,9 +28,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.crypto import context_factory from synapse.http.server import JsonResource -from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, -) +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource @@ -44,8 +45,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.frontend_proxy") diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index ae5fc751d5..14e6dca522 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,27 +18,39 @@ import logging import os import sys +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.web.resource import EncodingResourceWrapper, NoResource +from twisted.web.server import GzipEncoderFactory +from twisted.web.static import File + import synapse import synapse.config.logger from synapse import events -from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \ - LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ - STATIC_PREFIX, WEB_CLIENT_PREFIX +from synapse.api.urls import ( + CONTENT_REPO_PREFIX, + FEDERATION_PREFIX, + LEGACY_MEDIA_PREFIX, + MEDIA_PREFIX, + SERVER_KEY_PREFIX, + SERVER_KEY_V2_PREFIX, + STATIC_PREFIX, + WEB_CLIENT_PREFIX, +) from synapse.app import _base -from synapse.app._base import quit_with_error, listen_ssl, listen_tcp +from synapse.app._base import listen_ssl, listen_tcp, quit_with_error from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.federation.transport.server import TransportLayerServer -from synapse.module_api import ModuleApi from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource -from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \ - check_requirements -from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX +from synapse.module_api import ModuleApi +from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements +from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.key.v1.server_key_resource import LocalKey @@ -55,11 +67,6 @@ from synapse.util.manhole import manhole from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from twisted.application import service -from twisted.internet import defer, reactor -from twisted.web.resource import EncodingResourceWrapper, NoResource -from twisted.web.server import GzipEncoderFactory -from twisted.web.static import File logger = logging.getLogger("synapse.app.homeserver") diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 19a682cce3..749bbf37d0 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -16,11 +16,12 @@ import logging import sys +from twisted.internet import reactor +from twisted.web.resource import NoResource + import synapse from synapse import events -from synapse.api.urls import ( - CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX -) +from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig @@ -43,8 +44,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.media_repository") diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 13cfbd08b0..9295a51d5b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -16,6 +16,9 @@ import logging import sys +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -37,8 +40,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.pusher") diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 82f06ea185..26b9ec85f2 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -17,6 +17,11 @@ import contextlib import logging import sys +from six import iteritems + +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse.api.constants import EventTypes from synapse.app import _base @@ -36,12 +41,12 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -56,10 +61,6 @@ from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -from six import iteritems logger = logging.getLogger("synapse.app.synchrotron") diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 56ae086128..68acc15a9a 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -16,16 +16,17 @@ import argparse import collections +import errno import glob import os import os.path import signal import subprocess import sys -import yaml -import errno import time +import yaml + SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] GREEN = "\x1b[1;32m" diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index f5726e3df6..637a89530a 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -17,6 +17,9 @@ import logging import sys +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + import synapse from synapse import events from synapse.app import _base @@ -43,8 +46,6 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor, defer -from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.user_dir") diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 328cbfa284..57ed8a3ca2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -12,17 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes -from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.types import GroupID, get_domain_from_id - -from twisted.internet import defer - import logging import re from six import string_types +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.types import GroupID, get_domain_from_id +from synapse.util.caches.descriptors import cachedInlineCallbacks + logger = logging.getLogger(__name__) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 47251fb6ad..6980e5890e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -12,19 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import urllib + +from prometheus_client import Counter + from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException -from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event -from synapse.util.caches.response_cache import ResponseCache +from synapse.http.client import SimpleHttpClient from synapse.types import ThirdPartyInstanceID - -import logging -import urllib - -from prometheus_client import Counter +from synapse.util.caches.response_cache import ResponseCache logger = logging.getLogger(__name__) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6eddbc0828..2430814796 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,14 +48,14 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. """ +import logging + from twisted.internet import defer from synapse.appservice import ApplicationServiceState from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index b748ed2b0a..3d2e90dd5b 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -16,11 +16,12 @@ import argparse import errno import os -import yaml from textwrap import dedent from six import integer_types +import yaml + class ConfigError(Exception): pass diff --git a/synapse/config/api.py b/synapse/config/api.py index 20ba33226a..403d96ba76 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - from synapse.api.constants import EventTypes +from ._base import Config + class ApiConfig(Config): diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 0c27bb2fa7..3b161d708a 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +import logging -from synapse.appservice import ApplicationService -from synapse.types import UserID +from six import string_types +from six.moves.urllib import parse as urlparse +import yaml from netaddr import IPSet -import yaml -import logging +from synapse.appservice import ApplicationService +from synapse.types import UserID -from six import string_types -from six.moves.urllib import parse as urlparse +from ._base import Config, ConfigError logger = logging.getLogger(__name__) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 1dea2ad024..2fd9c48abf 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -13,32 +13,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .tls import TlsConfig -from .server import ServerConfig -from .logger import LoggingConfig -from .database import DatabaseConfig -from .ratelimiting import RatelimitConfig -from .repository import ContentRepositoryConfig -from .captcha import CaptchaConfig -from .voip import VoipConfig -from .registration import RegistrationConfig -from .metrics import MetricsConfig from .api import ApiConfig from .appservice import AppServiceConfig -from .key import KeyConfig -from .saml2 import SAML2Config +from .captcha import CaptchaConfig from .cas import CasConfig -from .password import PasswordConfig +from .consent_config import ConsentConfig +from .database import DatabaseConfig +from .emailconfig import EmailConfig +from .groups import GroupsConfig from .jwt import JWTConfig +from .key import KeyConfig +from .logger import LoggingConfig +from .metrics import MetricsConfig +from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig -from .emailconfig import EmailConfig -from .workers import WorkerConfig from .push import PushConfig +from .ratelimiting import RatelimitConfig +from .registration import RegistrationConfig +from .repository import ContentRepositoryConfig +from .saml2 import SAML2Config +from .server import ServerConfig +from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig -from .groups import GroupsConfig +from .tls import TlsConfig from .user_directory import UserDirectoryConfig -from .consent_config import ConsentConfig -from .server_notices_config import ServerNoticesConfig +from .voip import VoipConfig +from .workers import WorkerConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 47f145c589..51e7f7e003 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -15,7 +15,6 @@ from ._base import Config, ConfigError - MISSING_JWT = ( """Missing jwt library. This is required for jwt login. diff --git a/synapse/config/key.py b/synapse/config/key.py index d1382ad9ac..279c47bb48 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -13,21 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +import hashlib +import logging +import os -from synapse.util.stringutils import random_string from signedjson.key import ( - generate_signing_key, is_signing_algorithm_supported, - decode_signing_key_base64, decode_verify_key_bytes, - read_signing_keys, write_signing_keys, NACL_ED25519 + NACL_ED25519, + decode_signing_key_base64, + decode_verify_key_bytes, + generate_signing_key, + is_signing_algorithm_supported, + read_signing_keys, + write_signing_keys, ) from unpaddedbase64 import decode_base64 -from synapse.util.stringutils import random_string_with_symbols -import os -import hashlib -import logging +from synapse.util.stringutils import random_string, random_string_with_symbols +from ._base import Config, ConfigError logger = logging.getLogger(__name__) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 557c270fbe..a87b11a1df 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -16,15 +16,17 @@ import logging import logging.config import os import signal -from string import Template import sys +from string import Template -from twisted.logger import STDLibLogObserver, globalLogBeginner import yaml +from twisted.logger import STDLibLogObserver, globalLogBeginner + import synapse from synapse.util.logcontext import LoggingContextFilter from synapse.util.versionstring import get_version_string + from ._base import Config DEFAULT_LOG_CONFIG = Template(""" diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 6602c5b4c7..f4066abc28 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - from synapse.util.module_loader import load_module +from ._base import Config + LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' diff --git a/synapse/config/registration.py b/synapse/config/registration.py index c5384b3ad4..0fb964eb67 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from distutils.util import strtobool from synapse.util.stringutils import random_string_with_symbols -from distutils.util import strtobool +from ._base import Config class RegistrationConfig(Config): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 81ecf9778c..fc909c1fac 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError from collections import namedtuple from synapse.util.module_loader import load_module +from ._base import Config, ConfigError MISSING_NETADDR = ( "Missing netaddr library. This is required for URL preview API." diff --git a/synapse/config/server.py b/synapse/config/server.py index 71fd51e4bc..18102656b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -17,6 +17,7 @@ import logging from synapse.http.endpoint import parse_and_validate_server_name + from ._base import Config, ConfigError logger = logging.Logger(__name__) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index be1d1f762c..3c39850ac6 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config from synapse.types import UserID +from ._base import Config + DEFAULT_CONFIG = """\ # Server Notices room configuration # diff --git a/synapse/config/tls.py b/synapse/config/tls.py index b66154bc7c..fef1ea99cb 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - -from OpenSSL import crypto -import subprocess import os - +import subprocess from hashlib import sha256 + from unpaddedbase64 import encode_base64 +from OpenSSL import crypto + +from ._base import Config + GENERATE_DH_PARAMS = False diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 0397f73ab4..a1e1d0d33a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import ssl +import logging + from OpenSSL import SSL, crypto +from twisted.internet import ssl from twisted.internet._sslverify import _defaultCurveName -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index aaa3efaca3..8774b28967 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -15,15 +15,15 @@ # limitations under the License. -from synapse.api.errors import SynapseError, Codes -from synapse.events.utils import prune_event +import hashlib +import logging from canonicaljson import encode_canonical_json -from unpaddedbase64 import encode_base64, decode_base64 from signedjson.sign import sign_json +from unpaddedbase64 import decode_base64, encode_base64 -import hashlib -import logging +from synapse.api.errors import Codes, SynapseError +from synapse.events.utils import prune_event logger = logging.getLogger(__name__) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 2a0eddbea1..668b4f517d 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -13,14 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import logcontext -from twisted.web.http import HTTPClient -from twisted.internet.protocol import Factory -from twisted.internet import defer, reactor -from synapse.http.endpoint import matrix_federation_endpoint -from canonicaljson import json import logging +from canonicaljson import json + +from twisted.internet import defer, reactor +from twisted.internet.protocol import Factory +from twisted.web.http import HTTPClient + +from synapse.http.endpoint import matrix_federation_endpoint +from synapse.util import logcontext logger = logging.getLogger(__name__) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9b17ef0a08..e95b9fb43e 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,35 +14,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.crypto.keyclient import fetch_server_key -from synapse.api.errors import SynapseError, Codes -from synapse.util import unwrapFirstError, logcontext -from synapse.util.logcontext import ( - PreserveLoggingContext, - preserve_fn, - run_in_background, -) -from synapse.util.metrics import Measure - -from twisted.internet import defer +import hashlib +import logging +import urllib +from collections import namedtuple -from signedjson.sign import ( - verify_signed_json, signature_ids, sign_json, encode_canonical_json, - SignatureVerifyException, -) from signedjson.key import ( - is_signing_algorithm_supported, decode_verify_key_bytes, + decode_verify_key_bytes, encode_verify_key_base64, + is_signing_algorithm_supported, +) +from signedjson.sign import ( + SignatureVerifyException, + encode_canonical_json, + sign_json, + signature_ids, + verify_signed_json, ) from unpaddedbase64 import decode_base64, encode_base64 from OpenSSL import crypto +from twisted.internet import defer -from collections import namedtuple -import urllib -import hashlib -import logging - +from synapse.api.errors import Codes, SynapseError +from synapse.crypto.keyclient import fetch_server_key +from synapse.util import logcontext, unwrapFirstError +from synapse.util.logcontext import ( + PreserveLoggingContext, + preserve_fn, + run_in_background, +) +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index cdf99fd140..b32f64e729 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -17,11 +17,11 @@ import logging from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json, SignatureVerifyException +from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index cb08da4984..51f9084b90 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.frozenutils import freeze from synapse.util.caches import intern_dict - +from synapse.util.frozenutils import freeze # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 13fbba68c0..e662eaef10 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import EventBase, FrozenEvent, _event_dict_property +import copy from synapse.types import EventID - from synapse.util.stringutils import random_string -import copy +from . import EventBase, FrozenEvent, _event_dict_property class EventBuilder(EventBase): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 8e684d91b5..bcd9bb5946 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from frozendict import frozendict +from twisted.internet import defer + class EventContext(object): """ diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 29ae086786..652941ca0d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes -from . import EventBase +import re + +from six import string_types from frozendict import frozendict -import re +from synapse.api.constants import EventTypes -from six import string_types +from . import EventBase # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? """ +import logging + +from six import string_types + from twisted.internet import defer -from synapse.api.errors import SynapseError, AuthError -from synapse.types import UserID +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import parse_json_object_from_request -from .base import ClientV1RestServlet, client_path_patterns - -from six import string_types +from synapse.types import UserID -import logging +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e4e3611a14..a23edd8fe5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,9 +16,10 @@ """ This module contains REST servlets to do with profile: /profile/ """ from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns -from synapse.types import UserID from synapse.http.servlet import parse_json_object_from_request +from synapse.types import UserID + +from .base import ClientV1RestServlet, client_path_patterns class ProfileDisplaynameRestServlet(ClientV1RestServlet): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6bb4821ec6..0df7ce570f 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -16,16 +16,18 @@ from twisted.internet import defer from synapse.api.errors import ( - SynapseError, UnrecognizedRequestError, NotFoundError, StoreError + NotFoundError, + StoreError, + SynapseError, + UnrecognizedRequestError, ) -from .base import ClientV1RestServlet, client_path_patterns -from synapse.storage.push_rule import ( - InconsistentRuleException, RuleNotFoundException -) -from synapse.push.clientformat import format_push_rules_for_user +from synapse.http.servlet import parse_json_value_from_request from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.http.servlet import parse_json_value_from_request +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException + +from .base import ClientV1RestServlet, client_path_patterns class PushRuleRestServlet(ClientV1RestServlet): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 40e523cc5f..1581f88db5 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -13,20 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import SynapseError, Codes -from synapse.push import PusherConfigException +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.http.server import finish_request from synapse.http.servlet import ( - parse_json_object_from_request, parse_string, RestServlet + RestServlet, + parse_json_object_from_request, + parse_string, ) -from synapse.http.server import finish_request -from synapse.api.errors import StoreError +from synapse.push import PusherConfigException from .base import ClientV1RestServlet, client_path_patterns -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index c10320dedf..3ce5f8b726 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -14,21 +14,22 @@ # limitations under the License. """This module contains REST servlets to do with registration: /register""" +import hmac +import logging +from hashlib import sha1 + +from six import string_types + from twisted.internet import defer -from synapse.api.errors import SynapseError, Codes -from synapse.api.constants import LoginType -from synapse.api.auth import get_access_token_from_request -from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils +from synapse.api.auth import get_access_token_from_request +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import parse_json_object_from_request from synapse.types import create_requester -from hashlib import sha1 -import hmac -import logging - -from six import string_types +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6ae5db79b..2470db52ba 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -15,23 +15,27 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/ """ +import logging + +from six.moves.urllib import parse as urlparse + +from canonicaljson import json + from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns -from synapse.api.errors import SynapseError, Codes, AuthError -from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter -from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID -from synapse.events.utils import serialize_event, format_event_for_client_v2 +from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.http.servlet import ( - parse_json_object_from_request, parse_string, parse_integer + parse_integer, + parse_json_object_from_request, + parse_string, ) +from synapse.streams.config import PaginationConfig +from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID -from six.moves.urllib import parse as urlparse - -import logging -from canonicaljson import json +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index c43b30b73a..62f4c3d93e 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +import hashlib +import hmac + from twisted.internet import defer from .base import ClientV1RestServlet, client_path_patterns -import hmac -import hashlib -import base64 - - class VoipRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/voip/turnServer$") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 80dbc3c92e..528c1f43f9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,17 +17,20 @@ import logging from six.moves import http_client + from twisted.internet import defer from synapse.api.auth import has_access_token from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( - RestServlet, assert_params_in_request, + RestServlet, + assert_params_in_request, parse_json_object_from_request, ) from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import check_3pid_allowed + from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 0e0a187efd..371e9aa354 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_patterns - -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError, SynapseError +import logging from twisted.internet import defer -import logging +from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index d6f3a19648..bd8b5f4afa 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.api.constants import LoginType @@ -23,9 +25,6 @@ from synapse.http.servlet import RestServlet from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) RECAPTCHA_TEMPLATE = """ diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 35d58b367a..09f6a8efe3 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api import errors from synapse.http import servlet + from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 1b9dc4528d..ae86728879 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -13,17 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError, StoreError, Codes +from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID -from ._base import client_v2_patterns -from ._base import set_timeline_upper_limit - -import logging - +from ._base import client_v2_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 3bb1ec2af6..21e02c07c0 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -21,8 +23,6 @@ from synapse.types import GroupID from ._base import client_v2_patterns -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 3cc87ea63f..8486086b51 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -19,10 +19,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, parse_integer + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, ) -from synapse.http.servlet import parse_string from synapse.types import StreamToken + from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 66583d6778..2a6ea3df5f 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer -) from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_room_id, + format_event_for_client_v2_without_room_id, + serialize_event, ) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from ._base import client_v2_patterns -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index aa1cae8e1e..01c90aa2a3 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -14,15 +14,15 @@ # limitations under the License. -from ._base import client_v2_patterns +import logging + +from twisted.internet import defer -from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string -from twisted.internet import defer - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 2f8784fe06..a6e582a5ae 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 1fbff2edd8..de370cac45 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 97e7c0f7c6..896650d5a5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -14,29 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hmac +import logging +from hashlib import sha1 + +from six import string_types + from twisted.internet import defer import synapse import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError +from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string + RestServlet, + assert_params_in_request, + parse_json_object_from_request, + parse_string, ) from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.threepids import check_3pid_allowed from ._base import client_v2_patterns, interactive_auth_handler -import logging -import hmac -from hashlib import sha1 -from synapse.util.ratelimitutils import FederationRateLimiter - -from six import string_types - - # We ought to be using hmac.compare_digest() but on older pythons it doesn't # exist. It's a _really minor_ security flaw to use plain string comparison # because the timing attack is so obscured by all the other code here it's diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 8903e12405..08bb8e04fd 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index d2aa47b326..8aa06faf23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -13,27 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools +import logging + +from canonicaljson import json + from twisted.internet import defer -from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer, parse_boolean +from synapse.api.constants import PresenceState +from synapse.api.errors import SynapseError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.events.utils import ( + format_event_for_client_v2_without_room_id, + serialize_event, ) from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken -from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_room_id, -) -from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION -from synapse.api.errors import SynapseError -from synapse.api.constants import PresenceState -from ._base import client_v2_patterns -from ._base import set_timeline_upper_limit -import itertools -import logging - -from canonicaljson import json +from ._base import client_v2_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index dac8603b07..4fea614e95 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_patterns - -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError +import logging from twisted.internet import defer -import logging +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 6773b9ba60..d9d379182e 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet + from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2d4a43c353..cac0624ba7 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request + from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2ecb15deee..6ac2987b98 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.servlet import RestServlet - import logging import re +from synapse.http.servlet import RestServlet + logger = logging.getLogger(__name__) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 724911d1e6..147ff7d79b 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -13,28 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hashlib import sha256 import hmac import logging +from hashlib import sha256 from os import path + from six.moves import http_client import jinja2 from jinja2 import TemplateNotFound + from twisted.internet import defer from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.api.errors import NotFoundError, SynapseError, StoreError +from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError -from synapse.http.server import ( - finish_request, - wrap_html_request_handler, -) +from synapse.http.server import finish_request, wrap_html_request_handler from synapse.http.servlet import parse_string from synapse.types import UserID - # language to use for the templates. TODO: figure this out from Accept-Language TEMPLATE_LANGUAGE = "en" diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index 1498d188c1..b9ee6e1c13 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -14,14 +14,16 @@ # limitations under the License. -from twisted.web.resource import Resource -from synapse.http.server import respond_with_json_bytes +import logging + +from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from canonicaljson import encode_canonical_json + from OpenSSL import crypto -import logging +from twisted.web.resource import Resource +from synapse.http.server import respond_with_json_bytes logger = logging.getLogger(__name__) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index a07224148c..3491fd2118 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.web.resource import Resource + from .local_key_resource import LocalKey from .remote_key_resource import RemoteKey diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 04775b3c45..ec0ec7b431 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -14,13 +14,15 @@ # limitations under the License. -from twisted.web.resource import Resource -from synapse.http.server import respond_with_json_bytes +import logging + +from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from canonicaljson import encode_canonical_json -import logging +from twisted.web.resource import Resource + +from synapse.http.server import respond_with_json_bytes logger = logging.getLogger(__name__) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 21b4c1175e..7d67e4b064 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import ( - respond_with_json_bytes, wrap_json_request_handler, -) -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.api.errors import SynapseError, Codes -from synapse.crypto.keyring import KeyLookupError +import logging +from io import BytesIO +from twisted.internet import defer from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer +from synapse.api.errors import Codes, SynapseError +from synapse.crypto.keyring import KeyLookupError +from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler +from synapse.http.servlet import parse_integer, parse_json_object_from_request -from io import BytesIO -import logging logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index e44d4276d2..f255f2883f 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -13,22 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json_bytes, finish_request - -from synapse.api.errors import ( - Codes, cs_error -) - -from twisted.protocols.basic import FileSender -from twisted.web import server, resource - -from canonicaljson import json - import base64 import logging import os import re +from canonicaljson import json + +from twisted.protocols.basic import FileSender +from twisted.web import resource, server + +from synapse.api.errors import Codes, cs_error +from synapse.http.server import finish_request, respond_with_json_bytes + logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index c0d2f06855..65f4bd2910 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -13,23 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json, finish_request -from synapse.api.errors import ( - cs_error, Codes, SynapseError -) -from synapse.util import logcontext +import logging +import os +import urllib + +from six.moves.urllib import parse as urlparse from twisted.internet import defer from twisted.protocols.basic import FileSender +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.util import logcontext from synapse.util.stringutils import is_ascii -import os - -import logging -import urllib -from six.moves.urllib import parse as urlparse - logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 8cf8820c31..fbfa85f74f 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -18,11 +18,9 @@ from twisted.internet import defer from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.http.server import ( - set_cors_headers, - wrap_json_request_handler, -) import synapse.http.servlet +from synapse.http.server import set_cors_headers, wrap_json_request_handler + from ._base import parse_media_id, respond_404 logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d5164e47e0..c8586fa280 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import os import re -import functools NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py index 66f2b6bd30..a2e391415f 100644 --- a/synapse/rest/media/v1/identicon_resource.py +++ b/synapse/rest/media/v1/identicon_resource.py @@ -13,6 +13,7 @@ # limitations under the License. from pydenticon import Generator + from twisted.web.resource import Resource FOREGROUND = [ diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 218ba7a083..30242c525a 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -14,41 +14,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads +import cgi +import errno +import logging +import os +import shutil + +from six import iteritems +from six.moves.urllib import parse as urlparse + import twisted.internet.error import twisted.web.http +from twisted.internet import defer, threads from twisted.web.resource import Resource -from ._base import respond_404, FileInfo, respond_with_responder -from .upload_resource import UploadResource -from .download_resource import DownloadResource -from .thumbnail_resource import ThumbnailResource -from .identicon_resource import IdenticonResource -from .preview_url_resource import PreviewUrlResource -from .filepath import MediaFilePaths -from .thumbnailer import Thumbnailer -from .storage_provider import StorageProviderWrapper -from .media_storage import MediaStorage - -from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.util.stringutils import random_string from synapse.api.errors import ( - SynapseError, HttpResponseException, NotFoundError, FederationDeniedError, + FederationDeniedError, + HttpResponseException, + NotFoundError, + SynapseError, ) - +from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.util.async import Linearizer -from synapse.util.stringutils import is_ascii from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import is_ascii, random_string -import os -import errno -import shutil - -import cgi -import logging -from six.moves.urllib import parse as urlparse -from six import iteritems +from ._base import FileInfo, respond_404, respond_with_responder +from .download_resource import DownloadResource +from .filepath import MediaFilePaths +from .identicon_resource import IdenticonResource +from .media_storage import MediaStorage +from .preview_url_resource import PreviewUrlResource +from .storage_provider import StorageProviderWrapper +from .thumbnail_resource import ThumbnailResource +from .thumbnailer import Thumbnailer +from .upload_resource import UploadResource logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index d6b8ebbedb..b25993fcb5 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -13,22 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads -from twisted.protocols.basic import FileSender +import contextlib +import logging +import os +import shutil +import sys import six -from ._base import Responder +from twisted.internet import defer, threads +from twisted.protocols.basic import FileSender from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.logcontext import make_deferred_yieldable -import contextlib -import os -import logging -import shutil -import sys - +from ._base import Responder logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index adca490640..4e3a18ce08 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -24,31 +24,28 @@ import shutil import sys import traceback -from canonicaljson import json - -from six.moves import urllib_parse as urlparse from six import string_types +from six.moves import urllib_parse as urlparse + +from canonicaljson import json -from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET -from ._base import FileInfo - -from synapse.api.errors import ( - SynapseError, Codes, -) -from synapse.util.logcontext import make_deferred_yieldable, run_in_background -from synapse.util.stringutils import random_string -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.api.errors import Codes, SynapseError from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - respond_with_json_bytes, respond_with_json, + respond_with_json_bytes, wrap_json_request_handler, ) from synapse.util.async import ObservableDeferred -from synapse.util.stringutils import is_ascii +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.util.stringutils import is_ascii, random_string + +from ._base import FileInfo logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 0252afd9d3..7b9f8b4d79 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -13,17 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads +import logging +import os +import shutil -from .media_storage import FileResponder +from twisted.internet import defer, threads from synapse.config._base import Config from synapse.util.logcontext import run_in_background -import logging -import os -import shutil - +from .media_storage import FileResponder logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index aae6e464e8..5305e9175f 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -20,13 +20,14 @@ from twisted.internet import defer from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.http.server import ( - set_cors_headers, - wrap_json_request_handler, -) +from synapse.http.server import set_cors_headers, wrap_json_request_handler from synapse.http.servlet import parse_integer, parse_string + from ._base import ( - FileInfo, parse_media_id, respond_404, respond_with_file, + FileInfo, + parse_media_id, + respond_404, + respond_with_file, respond_with_responder, ) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index e1ee535b9a..a4b26c2587 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import PIL.Image as Image +import logging from io import BytesIO -import logging +import PIL.Image as Image logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 7567476fce..1a98120e1d 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -20,10 +20,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import SynapseError -from synapse.http.server import ( - respond_with_json, - wrap_json_request_handler, -) +from synapse.http.server import respond_with_json, wrap_json_request_handler logger = logging.getLogger(__name__) diff --git a/synapse/server.py b/synapse/server.py index c29c19289a..92bea96c5c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,19 +33,30 @@ from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker from synapse.federation.federation_client import FederationClient -from synapse.federation.federation_server import FederationServer +from synapse.federation.federation_server import ( + FederationHandlerRegistry, + FederationServer, +) from synapse.federation.send_queue import FederationRemoteSendQueue -from synapse.federation.federation_server import FederationHandlerRegistry -from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue +from synapse.federation.transport.client import TransportLayerClient +from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer +from synapse.groups.groups_server import GroupsServerHandler from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.deactivate_account import DeactivateAccountHandler -from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler +from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.events import EventHandler, EventStreamHandler +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.initial_sync import InitialSyncHandler +from synapse.handlers.message import EventCreationHandler from synapse.handlers.presence import PresenceHandler +from synapse.handlers.profile import ProfileHandler +from synapse.handlers.read_marker import ReadMarkerHandler +from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.room import RoomCreationHandler from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler @@ -53,17 +64,8 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler -from synapse.handlers.events import EventHandler, EventStreamHandler -from synapse.handlers.initial_sync import InitialSyncHandler -from synapse.handlers.receipts import ReceiptsHandler -from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.user_directory import UserDirectoryHandler -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.handlers.profile import ProfileHandler -from synapse.handlers.message import EventCreationHandler -from synapse.groups.groups_server import GroupsServerHandler -from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning -from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory +from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator @@ -74,9 +76,7 @@ from synapse.rest.media.v1.media_repository import ( ) from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender -from synapse.server_notices.worker_server_notices_sender import ( - WorkerServerNoticesSender, -) +from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStore from synapse.streams.events import EventSources diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index bb74af1af5..5e3044d164 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -14,7 +14,8 @@ # limitations under the License. import logging -from six import (iteritems, string_types) +from six import iteritems, string_types + from twisted.internet import defer from synapse.api.errors import SynapseError diff --git a/synapse/state.py b/synapse/state.py index 8098db94b4..15a593d41c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -14,25 +14,25 @@ # limitations under the License. +import hashlib +import logging +from collections import namedtuple + +from six import iteritems, itervalues + +from frozendict import frozendict + from twisted.internet import defer from synapse import event_auth -from synapse.util.logutils import log_function -from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR - -from collections import namedtuple -from frozendict import frozendict - -import logging -import hashlib - -from six import iteritems, itervalues +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logutils import log_function +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e843b702b9..ba88a54979 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -15,51 +15,48 @@ # limitations under the License. import datetime -from dateutil import tz -import time import logging +import time +from dateutil import tz + +from synapse.api.constants import PresenceState from synapse.storage.devices import DeviceStore from synapse.storage.user_erasure_store import UserErasureStore -from .appservice import ( - ApplicationServiceStore, ApplicationServiceTransactionStore -) +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .end_to_end_keys import EndToEndKeyStore +from .engines import PostgresEngine +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore from .events import EventsStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore from .registration import RegistrationStore +from .rejections import RejectionsStore from .room import RoomStore from .roommember import RoomMemberStore -from .stream import StreamStore -from .transactions import TransactionStore -from .keys import KeyStore -from .event_federation import EventFederationStore -from .pusher import PusherStore -from .push_rule import PushRuleStore -from .media_repository import MediaRepositoryStore -from .rejections import RejectionsStore -from .event_push_actions import EventPushActionsStore -from .deviceinbox import DeviceInboxStore -from .group_server import GroupServerStore -from .state import StateStore -from .signatures import SignatureStore -from .filtering import FilteringStore -from .end_to_end_keys import EndToEndKeyStore - -from .receipts import ReceiptsStore from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stream import StreamStore from .tags import TagsStore -from .account_data import AccountDataStore -from .openid import OpenIdStore -from .client_ips import ClientIpStore +from .transactions import TransactionStore from .user_directory import UserDirectoryStore - -from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator -from .engines import PostgresEngine - -from synapse.api.constants import PresenceState -from synapse.util.caches.stream_change_cache import StreamChangeCache +from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator logger = logging.getLogger(__name__) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 22d6257a9f..1fd5d8f162 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -13,22 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys +import threading +import time -from synapse.api.errors import StoreError -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches.descriptors import Cache -from synapse.storage.engines import PostgresEngine +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range from prometheus_client import Histogram from twisted.internet import defer -import sys -import time -import threading - -from six import itervalues, iterkeys, iteritems -from six.moves import intern, range +from synapse.api.errors import StoreError +from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import Cache +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 7034a61399..bbc3355c73 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -14,18 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc +import logging + +from canonicaljson import json + from twisted.internet import defer from synapse.storage._base import SQLBaseStore from synapse.storage.util.id_generators import StreamIdGenerator - -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -from canonicaljson import json - -import abc -import logging +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 4d32d0bdf6..9f12b360bc 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -15,14 +15,16 @@ # limitations under the License. import logging import re -from twisted.internet import defer + from canonicaljson import json +from twisted.internet import defer + from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage.events import EventsWorkerStore -from ._base import SQLBaseStore +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index af18964510..dc9eca7d15 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from . import engines - -from twisted.internet import defer +import logging from canonicaljson import json -import logging +from twisted.internet import defer + +from . import engines +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 968d2fed22..b78eda3413 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,15 +15,14 @@ import logging -from twisted.internet import defer +from six import iteritems -from ._base import Cache -from . import background_updates +from twisted.internet import defer from synapse.util.caches import CACHE_SIZE_FACTOR -from six import iteritems - +from . import background_updates +from ._base import Cache logger = logging.getLogger(__name__) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 38addbf9c0..73646da025 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -19,10 +19,9 @@ from canonicaljson import json from twisted.internet import defer -from .background_updates import BackgroundUpdateStore - from synapse.util.caches.expiringcache import ExpiringCache +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 2ed9ada783..ec68e39f1e 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -14,15 +14,16 @@ # limitations under the License. import logging +from six import iteritems, itervalues + +from canonicaljson import json + from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, Cache -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks - -from canonicaljson import json +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from six import itervalues, iteritems +from ._base import Cache, SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index d0c0059757..808194236a 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached - -from synapse.api.errors import SynapseError +from collections import namedtuple from twisted.internet import defer -from collections import namedtuple +from synapse.api.errors import SynapseError +from synapse.util.caches.descriptors import cached +from ._base import SQLBaseStore RoomAliasMapping = namedtuple( "RoomAliasMapping", diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 181047c8b7..7ae5c65482 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer from synapse.util.caches.descriptors import cached -from canonicaljson import encode_canonical_json, json - from ._base import SQLBaseStore -from six import iteritems - class EndToEndKeyStore(SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 8c868ece75..e2f9de8451 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import IncorrectDatabaseSetup -from .postgres import PostgresEngine -from .sqlite3 import Sqlite3Engine - import importlib import platform +from ._base import IncorrectDatabaseSetup +from .postgres import PostgresEngine +from .sqlite3 import Sqlite3Engine SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 60f0fa7fb3..19949fc474 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - import struct import threading +from synapse.storage.prepare_database import prepare_database + class Sqlite3Engine(object): single_threaded = True diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 8fbf7ffba7..8d366d1b91 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,23 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import random +from six.moves import range +from six.moves.queue import Empty, PriorityQueue + +from unpaddedbase64 import encode_base64 + from twisted.internet import defer +from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.events import EventsWorkerStore from synapse.storage.signatures import SignatureWorkerStore - -from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached -from unpaddedbase64 import encode_base64 - -import logging -from six.moves.queue import PriorityQueue, Empty - -from six.moves import range - logger = logging.getLogger(__name__) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 05cb3f61ce..29b511ae5e 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -14,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._base import SQLBaseStore, LoggingTransaction -from twisted.internet import defer -from synapse.util.caches.descriptors import cachedInlineCallbacks - import logging +from six import iteritems + from canonicaljson import json -from six import iteritems +from twisted.internet import defer + +from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a54abb9edd..2aaab0d02c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -14,37 +14,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict, deque, namedtuple -from functools import wraps import itertools import logging +from collections import OrderedDict, deque, namedtuple +from functools import wraps + +from six import iteritems, itervalues +from six.moves import range from canonicaljson import json +from prometheus_client import Counter from twisted.internet import defer +import synapse.metrics +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 from synapse.storage.events_worker import EventsWorkerStore +from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util.async import ObservableDeferred +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.logcontext import ( - PreserveLoggingContext, make_deferred_yieldable, -) +from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.types import get_domain_from_id, RoomStreamToken -import synapse.metrics - -# these are only included to make the type annotations work -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 - -from six.moves import range -from six import itervalues, iteritems - -from prometheus_client import Counter logger = logging.getLogger(__name__) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 896225aab9..5fe1fd13e5 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,29 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +import logging +from collections import namedtuple + +from canonicaljson import json from twisted.internet import defer +from synapse.api.errors import SynapseError +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event - from synapse.util.logcontext import ( - PreserveLoggingContext, make_deferred_yieldable, run_in_background, LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, ) from synapse.util.metrics import Measure -from synapse.api.errors import SynapseError -from collections import namedtuple - -import logging - -from canonicaljson import json - -# these are only included to make the type annotations work -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index eae6027cee..2d5896c5b4 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer -from ._base import SQLBaseStore -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from canonicaljson import encode_canonical_json, json +from ._base import SQLBaseStore class FilteringStore(SQLBaseStore): diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index b77402d295..592d1b4c2a 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -14,15 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from canonicaljson import json + from twisted.internet import defer from synapse.api.errors import SynapseError from ._base import SQLBaseStore -from canonicaljson import json - - # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null _DEFAULT_CATEGORY_ID = "" diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 0f13b61da8..f547977600 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +import hashlib +import logging -from twisted.internet import defer import six -import OpenSSL from signedjson.key import decode_verify_key_bytes -import hashlib -import logging +import OpenSSL +from twisted.internet import defer + +from synapse.util.caches.descriptors import cachedInlineCallbacks + +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index cf2aae0468..b290f834b3 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -20,7 +20,6 @@ import logging import os import re - logger = logging.getLogger(__name__) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index f05d91cc58..a0c7a0dc87 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from collections import namedtuple + +from twisted.internet import defer + from synapse.api.constants import PresenceState -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from collections import namedtuple -from twisted.internet import defer +from ._base import SQLBaseStore class UserPresenceState(namedtuple("UserPresenceState", diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 8612bd5ecc..60295da254 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.storage.roommember import ProfileInfo from synapse.api.errors import StoreError +from synapse.storage.roommember import ProfileInfo from ._base import SQLBaseStore diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9e52e992b3..be655d287b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,21 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +import abc +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.push.baserules import list_with_base_rules from synapse.storage.appservice import ApplicationServiceWorkerStore from synapse.storage.pusher import PusherWorkerStore from synapse.storage.receipts import ReceiptsWorkerStore from synapse.storage.roommember import RoomMemberWorkerStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.push.baserules import list_with_base_rules -from synapse.api.constants import EventTypes -from twisted.internet import defer -from canonicaljson import json - -import abc -import logging +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index c6def861cf..cc273a57b2 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -14,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer +import logging +import types from canonicaljson import encode_canonical_json, json +from twisted.internet import defer + from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -import logging -import types +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f230a3bab7..3738901ea4 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -14,18 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from .util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from twisted.internet import defer +import abc +import logging from canonicaljson import json -import abc -import logging +from twisted.internet import defer +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator logger = logging.getLogger(__name__) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0d18f6d869..07333f777d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -15,15 +15,15 @@ import re +from six.moves import range + from twisted.internet import defer -from synapse.api.errors import StoreError, Codes +from synapse.api.errors import Codes, StoreError from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from six.moves import range - class RegistrationWorkerStore(SQLBaseStore): @cached() diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py index 40acb5c4ed..880f047adb 100644 --- a/synapse/storage/rejections.py +++ b/synapse/storage/rejections.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore - import logging +from ._base import SQLBaseStore + logger = logging.getLogger(__name__) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index ca0eb187e5..3147fb6827 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -13,6 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections +import logging +import re + +from canonicaljson import json + from twisted.internet import defer from synapse.api.errors import StoreError @@ -20,12 +26,6 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.search import SearchStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from canonicaljson import json - -import collections -import logging -import re - logger = logging.getLogger(__name__) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8fc9549a75..02a802bed9 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -14,24 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - +import logging from collections import namedtuple +from six import iteritems, itervalues + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership from synapse.storage.events import EventsWorkerStore +from synapse.types import get_domain_from_id from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii -from synapse.api.constants import Membership, EventTypes -from synapse.types import get_domain_from_id - -import logging -from canonicaljson import json - -from six import itervalues, iteritems - logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index e7351c3ae6..4b2ffd35fd 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -14,11 +14,11 @@ import logging -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine, Sqlite3Engine - import simplejson +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index 6df57b5206..414f9f5aa0 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -14,10 +14,10 @@ import logging -from synapse.storage.prepare_database import get_statements - import simplejson +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 85bd1a2006..ef7ec34346 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from synapse.config.appservice import load_appservices from six.moves import range +from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py index fe6b7d196d..7d8ca5f93f 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/schema/delta/31/search_update.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - import logging + import simplejson +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py index 1e002f9db2..bff1256a7b 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements - import logging + import simplejson +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py index 55ae43f395..9754d3ccfb 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -14,7 +14,6 @@ import time - ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py index 3b63a1562d..cf09e43e2b 100644 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py index 033144341c..67d505e68b 100644 --- a/synapse/storage/schema/delta/34/received_txn_purge.py +++ b/synapse/storage/schema/delta/34/received_txn_purge.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py index 81948e3431..0ffab10b6f 100644 --- a/synapse/storage/schema/delta/34/sent_txn_purge.py +++ b/synapse/storage/schema/delta/34/sent_txn_purge.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py index 20ad8bd5a6..a377884169 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) DROP_INDICES = """ diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py index ea6a18196d..506f326f4d 100644 --- a/synapse/storage/schema/delta/42/user_dir.py +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -14,8 +14,8 @@ import logging -from synapse.storage.prepare_database import get_statements from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 9b77c45318..d5b5df93e6 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -13,19 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import logging import re -from canonicaljson import json +from collections import namedtuple from six import string_types +from canonicaljson import json + from twisted.internet import defer -from .background_updates import BackgroundUpdateStore from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from .background_updates import BackgroundUpdateStore + logger = logging.getLogger(__name__) SearchEntry = namedtuple('SearchEntry', [ diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 25922e5a9c..470212aa2a 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer import six -from ._base import SQLBaseStore - from unpaddedbase64 import encode_base64 + +from twisted.internet import defer + from synapse.crypto.event_signing import compute_event_reference_hash from synapse.util.caches.descriptors import cached, cachedList +from ._base import SQLBaseStore + # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cd9821c270..89a05c4618 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import logging +from collections import namedtuple from six import iteritems, itervalues from six.moves import range @@ -23,10 +23,11 @@ from twisted.internet import defer from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine -from synapse.util.caches import intern_string, get_cache_factor_for +from synapse.util.caches import get_cache_factor_for, intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii + from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index fb463c525a..66856342f0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -33,22 +33,20 @@ what sort order was used: and stream ordering columns respectively. """ +import abc +import logging +from collections import namedtuple + +from six.moves import range + from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from synapse.storage.events import EventsWorkerStore - from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background -from synapse.storage.engines import PostgresEngine - -import abc -import logging - -from six.moves import range -from collections import namedtuple - logger = logging.getLogger(__name__) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 04d123ed95..0f657b2bd3 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -14,16 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.account_data import AccountDataWorkerStore +import logging -from synapse.util.caches.descriptors import cached -from twisted.internet import defer +from six.moves import range from canonicaljson import json -import logging +from twisted.internet import defer -from six.moves import range +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index acbc03446e..c3bc94f56d 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,17 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached +import logging +from collections import namedtuple -from twisted.internet import defer import six from canonicaljson import encode_canonical_json, json -from collections import namedtuple +from twisted.internet import defer -import logging +from synapse.util.caches.descriptors import cached + +from ._base import SQLBaseStore # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 275c299998..ce59e70d0e 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging +import re -from ._base import SQLBaseStore +from six import iteritems + +from twisted.internet import defer -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from six import iteritems - -import re -import logging +from ._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py index 47bfc01e84..be013f4427 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/user_erasure_store.py @@ -17,7 +17,7 @@ import operator from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedList, cached +from synapse.util.caches.descriptors import cached, cachedList class UserErasureWorkerStore(SQLBaseStore): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 95031dc9ec..d6160d5e4d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import deque import contextlib import threading +from collections import deque class IdGenerator(object): diff --git a/synapse/streams/config.py b/synapse/streams/config.py index ca78e551cb..46ccbbda7d 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.types import StreamToken - import logging +from synapse.api.errors import SynapseError +from synapse.types import StreamToken logger = logging.getLogger(__name__) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f03ad99118..e5220132a3 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -15,13 +15,12 @@ from twisted.internet import defer -from synapse.types import StreamToken - +from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.presence import PresenceEventSource +from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource -from synapse.handlers.receipts import ReceiptEventSource -from synapse.handlers.account_data import AccountDataEventSource +from synapse.types import StreamToken class EventSources(object): diff --git a/synapse/types.py b/synapse/types.py index cc7c182a78..08f058f714 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import string +from collections import namedtuple from synapse.api.errors import SynapseError -from collections import namedtuple - class Requester(namedtuple("Requester", [ "user", "access_token_id", "is_guest", "device_id", "app_service", diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index e9886ef299..680ea928c7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -17,6 +17,7 @@ import logging from itertools import islice import attr + from twisted.internet import defer, task from synapse.util.logcontext import PreserveLoggingContext diff --git a/synapse/util/async.py b/synapse/util/async.py index 1668df4ce6..5d0fb39130 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from contextlib import contextmanager + +from six.moves import range + from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure +from synapse.util import Clock, logcontext, unwrapFirstError + from .logcontext import ( - PreserveLoggingContext, make_deferred_yieldable, run_in_background + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, ) -from synapse.util import logcontext, unwrapFirstError, Clock - -from contextlib import contextmanager - -import logging - -from six.moves import range logger = logging.getLogger(__name__) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 900575eb3c..7b065b195e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily - import os -from six.moves import intern import six +from six.moves import intern + +from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 65a1042de1..f8a07df6b8 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,10 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import inspect import logging +import threading +from collections import namedtuple +import six +from six import itervalues, string_types + +from twisted.internet import defer + +from synapse.util import logcontext, unwrapFirstError from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError, logcontext from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -24,17 +33,6 @@ from synapse.util.stringutils import to_ascii from . import register_cache -from twisted.internet import defer -from collections import namedtuple - -import functools -import inspect -import threading - -from six import string_types, itervalues -import six - - logger = logging.getLogger(__name__) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 95793d466d..6c0b5a4094 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches.lrucache import LruCache -from collections import namedtuple -from . import register_cache -import threading import logging +import threading +from collections import namedtuple + +from synapse.util.caches.lrucache import LruCache +from . import register_cache logger = logging.getLogger(__name__) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ff04c91955..4abca91f6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import register_cache - -from collections import OrderedDict import logging +from collections import OrderedDict +from synapse.util.caches import register_cache logger = logging.getLogger(__name__) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1c5a982094..b684f24e7b 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -14,8 +14,8 @@ # limitations under the License. -from functools import wraps import threading +from functools import wraps from synapse.util.caches.treecache import TreeCache diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 0fb8620001..8637867c6d 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import caches - +import logging from sortedcontainers import SortedDict -import logging +from synapse.util import caches logger = logging.getLogger(__name__) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index c78801015b..629ed44149 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import queue + from twisted.internet import threads from synapse.util.logcontext import make_deferred_yieldable, run_in_background -from six.moves import queue - class BackgroundFileConsumer(object): """A consumer that writes to a file like object. Supports both push diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 535e7d0e7a..581c6052ac 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from frozendict import frozendict -from canonicaljson import json - from six import string_types +from canonicaljson import json +from frozendict import frozendict + def freeze(o): if isinstance(o, dict): diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index e9f0f292ee..2d7ddc1cbe 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import NoResource - import logging +from twisted.web.resource import NoResource + logger = logging.getLogger(__name__) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index df2b71b791..fe9288b031 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -22,10 +22,10 @@ them. See doc/log_contexts.rst for details on how this works. """ -from twisted.internet import defer - -import threading import logging +import threading + +from twisted.internet import defer logger = logging.getLogger(__name__) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index 3e42868ea9..a46bc47ce3 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -14,10 +14,11 @@ # limitations under the License. -from six import StringIO import logging import traceback +from six import StringIO + class LogFormatter(logging.Formatter): """Log formatter which gives more detail for exceptions diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index 03249c5dc8..62a00189cc 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -14,13 +14,11 @@ # limitations under the License. -from inspect import getcallargs -from functools import wraps - -import logging import inspect +import logging import time - +from functools import wraps +from inspect import getcallargs _TIME_FUNC_ID = 0 diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 97e0f00b67..14be3c7396 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.conch.manhole import ColoredManhole -from twisted.conch.insults import insults from twisted.conch import manhole_ssh -from twisted.cred import checkers, portal +from twisted.conch.insults import insults +from twisted.conch.manhole import ColoredManhole from twisted.conch.ssh.keys import Key +from twisted.cred import checkers, portal PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 1ba7d65c7c..63bc64c642 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging +from functools import wraps from prometheus_client import Counter -from synapse.util.logcontext import LoggingContext -from functools import wraps -import logging +from twisted.internet import defer +from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index 607161e7f0..a6c30e5265 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -14,6 +14,7 @@ # limitations under the License. import phonenumbers + from synapse.api.errors import SynapseError diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index c5a45cef7c..5ac33b2132 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections +import contextlib +import logging + from twisted.internet import defer from synapse.api.errors import LimitExceededError - from synapse.util.logcontext import ( - run_in_background, make_deferred_yieldable, PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, ) -import collections -import contextlib -import logging - - logger = logging.getLogger(__name__) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 4e93f69d3a..8a3a06fd74 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.logcontext -from twisted.internet import defer - -from synapse.api.errors import CodeMessageException - import logging import random +from twisted.internet import defer + +import synapse.util.logcontext +from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index f4a9abf83f..6c0f2bb0cf 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import resource import logging - +import resource logger = logging.getLogger("synapse.app.homeserver") diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b98b9dc6e4..43d9db67ec 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -15,6 +15,7 @@ import random import string + from six.moves import range _string_with_symbols = ( diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 52086df465..1fbcd41115 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import os import logging +import os +import subprocess logger = logging.getLogger(__name__) diff --git a/synapse/visibility.py b/synapse/visibility.py index 65d79cf0d0..015c2bab37 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,9 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event -from synapse.util.logcontext import ( - make_deferred_yieldable, preserve_fn, -) +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn logger = logging.getLogger(__name__) diff --git a/tests/__init__.py b/tests/__init__.py index aab20e8e02..24006c949e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,4 +14,5 @@ # limitations under the License. from twisted.trial import util + util.DEFAULT_TIMEOUT_DURATION = 10 diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index aec3b62897..5f158ec4b9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pymacaroons from mock import Mock + +import pymacaroons + from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID + from tests import unittest -from tests.utils import setup_test_homeserver, mock_getRawHeaders +from tests.utils import mock_getRawHeaders, setup_test_homeserver class TestHandlers(object): diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcceca7f3e..836a23fb54 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from tests.utils import ( - MockHttpResource, DeferredMockCallable, setup_test_homeserver -) +import jsonschema + +from twisted.internet import defer + +from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent -from synapse.api.errors import SynapseError -import jsonschema +from tests import unittest +from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver user_localpart = "test_user" diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 5b2b95860a..891e0cc973 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,14 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice import ApplicationService +import re + +from mock import Mock from twisted.internet import defer -from mock import Mock -from tests import unittest +from synapse.appservice import ApplicationService -import re +from tests import unittest def _regex(regex, exclusive=True): diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 9181692771..b9f4863e9a 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + +from twisted.internet import defer + from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( - _ServiceQueuer, _TransactionController, _Recoverer + _Recoverer, + _ServiceQueuer, + _TransactionController, ) -from twisted.internet import defer - from synapse.util.logcontext import make_deferred_yieldable -from ..utils import MockClock -from mock import Mock + from tests import unittest +from ..utils import MockClock + class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 879159ccea..eb7f0ab12a 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -19,6 +19,7 @@ import shutil import tempfile from synapse.config.homeserver import HomeServerConfig + from tests import unittest diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 772afd2cf9..5c422eff38 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -15,8 +15,11 @@ import os.path import shutil import tempfile + import yaml + from synapse.config.homeserver import HomeServerConfig + from tests import unittest diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 47cb328a01..cd11871b80 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -14,15 +14,13 @@ # limitations under the License. -from tests import unittest - -from synapse.events.builder import EventBuilder -from synapse.crypto.event_signing import add_hashes_and_signatures - +import nacl.signing from unpaddedbase64 import decode_base64 -import nacl.signing +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events.builder import EventBuilder +from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index cc1c862ba4..a9d37fe084 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -14,15 +14,19 @@ # limitations under the License. import time +from mock import Mock + import signedjson.key import signedjson.sign -from mock import Mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.util import logcontext, Clock +from synapse.util import Clock, logcontext from synapse.util.logcontext import LoggingContext + from tests import unittest, utils -from twisted.internet import defer, reactor class MockPerspectiveServer(object): diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index dfc870066e..f51d99419e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -14,11 +14,11 @@ # limitations under the License. -from .. import unittest - from synapse.events import FrozenEvent from synapse.events.utils import prune_event, serialize_event +from .. import unittest + def MockEvent(**kwargs): if "event_id" not in kwargs: diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 4e8dc8fea0..c91e25f54f 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -16,6 +16,7 @@ import logging from synapse.events import FrozenEvent from synapse.federation.federation_server import server_matches_acl_event + from tests import unittest diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b753455943..57c0771cf3 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest -from tests.utils import MockClock from synapse.handlers.appservice import ApplicationServicesHandler -from mock import Mock +from tests.utils import MockClock + +from .. import unittest class AppServiceHandlerTestCase(unittest.TestCase): diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1822dcf1e0..2e5e8e4dec 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,11 +14,13 @@ # limitations under the License. import pymacaroons + from twisted.internet import defer import synapse import synapse.api.errors from synapse.handlers.auth import AuthHandler + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 778ff2f6e9..633a0b7f36 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -17,8 +17,8 @@ from twisted.internet import defer import synapse.api.errors import synapse.handlers.device - import synapse.storage + from tests import unittest, utils user1 = "@boris:aaa" diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7e5332e272..a353070316 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -14,14 +14,14 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer + from synapse.handlers.directory import DirectoryHandler from synapse.types import RoomAlias +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index d1bd87b898..ca1542236d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,13 +14,14 @@ # limitations under the License. import mock -from synapse.api import errors + from twisted.internet import defer import synapse.api.errors import synapse.handlers.e2e_keys - import synapse.storage +from synapse.api import errors + from tests import unittest, utils diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index de06a6ad30..121ce78634 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -14,18 +14,22 @@ # limitations under the License. -from tests import unittest - from mock import Mock, call from synapse.api.constants import PresenceState from synapse.handlers.presence import ( - handle_update, handle_timeout, - IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT, FEDERATION_PING_INTERVAL, + FEDERATION_TIMEOUT, + IDLE_TIMER, + LAST_ACTIVE_GRANULARITY, + SYNC_ONLINE_TIMEOUT, + handle_timeout, + handle_update, ) from synapse.storage.presence import UserPresenceState +from tests import unittest + class PresenceUpdateTestCase(unittest.TestCase): def test_offline_to_online(self): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 458296ee4c..dc17918a3d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock, NonCallableMock +from twisted.internet import defer + import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e990e45220..025fa1be81 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver -from mock import Mock +from .. import unittest class RegistrationHandlers(object): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index a433bbfa8a..b08856f763 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,19 +14,24 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - -from mock import Mock, call, ANY import json -from ..utils import ( - MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver -) +from mock import ANY, Mock, call + +from twisted.internet import defer from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest + +from ..utils import ( + DeferredMockCallable, + MockClock, + MockHttpResource, + setup_test_homeserver, +) + def _expect_edu(destination, edu_type, content, origin="test"): return { diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b8a48d20a4..60e6a75953 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import ( - parse_server_name, - parse_and_validate_server_name, -) +from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name + from tests import unittest diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 64e07a8c93..8708c8a196 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from tests import unittest - import tempfile from mock import Mock, NonCallableMock -from tests.utils import setup_test_homeserver -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from twisted.internet import defer, reactor + from synapse.replication.tcp.client import ( - ReplicationClientHandler, ReplicationClientFactory, + ReplicationClientFactory, + ReplicationClientHandler, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.utils import setup_test_homeserver class BaseSlavedStoreTestCase(unittest.TestCase): diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index f47a42e45d..adf226404e 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -13,11 +13,11 @@ # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" TYPE = "my.type" diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cb058d3142..cea01d93eb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser -from twisted.internet import defer - +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 6624fe4eea..e6d670cc1f 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" ROOM_ID = "!room:blue" diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 6a757289db..eee99ca2e0 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,10 +1,11 @@ -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.transactions import CLEANUP_PERIOD_MS -from twisted.internet import defer, reactor from mock import Mock, call +from twisted.internet import defer, reactor + +from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock from synapse.util.logcontext import LoggingContext + from tests import unittest from tests.utils import MockClock diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index f5a7258e68..a5af36a99c 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -14,7 +14,7 @@ # limitations under the License. """ Tests REST events for /events paths.""" -from tests import unittest +from mock import Mock, NonCallableMock # twisted imports from twisted.internet import defer @@ -23,13 +23,11 @@ import synapse.rest.client.v1.events import synapse.rest.client.v1.register import synapse.rest.client.v1.room +from tests import unittest from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index dc94b8bd19..d71cc8e0db 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -15,12 +15,15 @@ """Tests REST events for /profile paths.""" from mock import Mock + from twisted.internet import defer import synapse.types -from synapse.api.errors import SynapseError, AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.rest.client.v1 import profile + from tests import unittest + from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index a6a4e2ffe0..f596acb85f 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.client.v1.register import CreateUserRestServlet -from twisted.internet import defer +import json + from mock import Mock + +from twisted.internet import defer + +from synapse.rest.client.v1.register import CreateUserRestServlet + from tests import unittest from tests.utils import mock_getRawHeaders -import json class CreateUserServletTestCase(unittest.TestCase): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 61d737725b..895dffa095 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -15,22 +15,21 @@ """Tests REST events for /rooms paths.""" +import json + +from mock import Mock, NonCallableMock +from six.moves.urllib import parse as urlparse + # twisted imports from twisted.internet import defer import synapse.rest.client.v1.room from synapse.api.constants import Membership - from synapse.types import UserID -import json -from six.moves.urllib import parse as urlparse - from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - PATH_PREFIX = "/_matrix/client/api/v1" diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index fe161ee5cb..bddb3302e4 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -15,18 +15,17 @@ """Tests REST events for /rooms paths.""" +from mock import Mock, NonCallableMock + # twisted imports from twisted.internet import defer import synapse.rest.client.v1.room from synapse.types import UserID -from ....utils import MockHttpResource, MockClock, setup_test_homeserver +from ....utils import MockClock, MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 3bb1dd003a..54d7ba380d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import time + # twisted imports from twisted.internet import defer -# trial imports -from tests import unittest - from synapse.api.constants import Membership -import json -import time +# trial imports +from tests import unittest class RestTestCase(unittest.TestCase): diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 5170217d9e..f18a8a6027 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest - from mock import Mock -from ....utils import MockHttpResource, setup_test_homeserver +from twisted.internet import defer from synapse.types import UserID -from twisted.internet import defer +from tests import unittest +from ....utils import MockHttpResource, setup_test_homeserver PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 76b833e119..bb0b2f94ea 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,16 +15,13 @@ from twisted.internet import defer -from tests import unittest - -from synapse.rest.client.v2_alpha import filter - -from synapse.api.errors import Codes - import synapse.types - +from synapse.api.errors import Codes +from synapse.rest.client.v2_alpha import filter from synapse.types import UserID +from tests import unittest + from ....utils import MockHttpResource, setup_test_homeserver PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8aba456510..9b57a56070 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,12 +1,15 @@ +import json + +from mock import Mock + +from twisted.internet import defer from twisted.python import failure +from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError -from twisted.internet import defer -from mock import Mock + from tests import unittest from tests.utils import mock_getRawHeaders -import json class RegisterRestServletTestCase(unittest.TestCase): diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index c5e2f5549a..bf254a260d 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -14,21 +14,21 @@ # limitations under the License. +import os +import shutil +import tempfile + +from mock import Mock + from twisted.internet import defer, reactor from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend -from mock import Mock - from tests import unittest -import os -import shutil -import tempfile - class MediaStorageTests(unittest.TestCase): def setUp(self): diff --git a/tests/server.py b/tests/server.py index 73069dff52..46223ccf05 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,15 +1,17 @@ +import json from io import BytesIO -import attr -import json from six import text_type -from twisted.python.failure import Failure +import attr + +from twisted.internet import threads from twisted.internet.defer import Deferred +from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock from synapse.http.site import SynapseRequest -from twisted.internet import threads + from tests.utils import setup_test_homeserver as _sth diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 3cfa21c9f8..6d6f00c5c5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,15 +14,15 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from synapse.util.async import ObservableDeferred +from twisted.internet import defer +from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached +from tests import unittest + class CacheTestCase(unittest.TestCase): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 00825498b1..099861b27c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -12,21 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +import os import tempfile -from synapse.config._base import ConfigError -from tests import unittest + +from mock import Mock + +import yaml + from twisted.internet import defer -from tests.utils import setup_test_homeserver from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.config._base import ConfigError from synapse.storage.appservice import ( - ApplicationServiceStore, ApplicationServiceTransactionStore + ApplicationServiceStore, + ApplicationServiceTransactionStore, ) -import json -import os -import yaml -from mock import Mock +from tests import unittest +from tests.utils import setup_test_homeserver class ApplicationServiceStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 1286b4ce2d..ab1f310572 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,10 +1,10 @@ -from tests import unittest +from mock import Mock + from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class BackgroundUpdateTestCase(unittest.TestCase): diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 0ac910e76f..1d1234ee39 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -14,18 +14,18 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer +from collections import OrderedDict from mock import Mock -from collections import OrderedDict +from twisted.internet import defer from synapse.server import HomeServer - from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine +from tests import unittest + class SQLBaseStoreTestCase(unittest.TestCase): """ Test the "simple" SQL generating methods in SQLBaseStore. """ diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f8725acea0..a54cc6bc32 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -16,6 +16,7 @@ from twisted.internet import defer import synapse.api.errors + import tests.unittest import tests.utils diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 95709cd50a..129ebaf343 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.directory import DirectoryStore -from synapse.types import RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 3cbf9a78b1..8430fc7ba6 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer import tests.unittest import tests.utils -from mock import Mock USER_ID = "@user:example.com" diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 0be790d8f8..3a3d002782 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key + from twisted.internet import defer import tests.unittest diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index f5fcb611d4..3276b39504 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -14,13 +14,13 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.presence import PresenceStore from synapse.types import UserID -from tests.utils import setup_test_homeserver, MockClock +from tests import unittest +from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 423710c9c1..2c95e5e95a 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.profile import ProfileStore from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 888ddfaddd..475ec900c4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RedactionTestCase(unittest.TestCase): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index f863b75846..7821ea3fa3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -14,9 +14,9 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ef8a4d234f..ae8ae94b6d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.types import UserID, RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 657b279e5d..c5fd54f67e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RoomMemberStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 0891308f25..23fad12bca 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.storage import UserDirectoryStore from synapse.storage.roommember import ProfileInfo + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/test_distributor.py b/tests/test_distributor.py index c066381698..04a88056f1 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest -from twisted.internet import defer - from mock import Mock, patch +from twisted.internet import defer + from synapse.util.distributor import Distributor +from . import unittest + class DistributorTestCase(unittest.TestCase): diff --git a/tests/test_dns.py b/tests/test_dns.py index 3b360a0fc7..b647d92697 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest +from mock import Mock + from twisted.internet import defer from twisted.names import dns, error -from mock import Mock - from synapse.http.endpoint import resolve_service from tests.utils import MockClock +from . import unittest + @unittest.DEBUG class DnsTestCase(unittest.TestCase): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index d08e19c53a..06112430e5 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + from synapse import event_auth from synapse.api.errors import AuthError from synapse.events import FrozenEvent -import unittest class EventAuthTestCase(unittest.TestCase): diff --git a/tests/test_federation.py b/tests/test_federation.py index fc80a69369..159a136971 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,14 +1,14 @@ -from twisted.internet.defer import succeed, maybeDeferred +from mock import Mock + +from twisted.internet.defer import maybeDeferred, succeed -from synapse.util import Clock from synapse.events import FrozenEvent from synapse.types import Requester, UserID +from synapse.util import Clock from tests import unittest -from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock - -from mock import Mock +from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver class MessageAcceptTests(unittest.TestCase): diff --git a/tests/test_preview.py b/tests/test_preview.py index 5bd36c74aa..446843367e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest - from synapse.rest.media.v1.preview_url_resource import ( - summarize_paragraphs, decode_and_calc_og + decode_and_calc_og, + summarize_paragraphs, ) +from . import unittest + class PreviewTestCase(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py index 8ad822c43b..4192013f6d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,9 +4,10 @@ import re from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactorClock -from synapse.util import Clock from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource +from synapse.util import Clock + from tests import unittest from tests.server import make_request, setup_test_homeserver diff --git a/tests/test_state.py b/tests/test_state.py index 71c412faf4..c0f2d1152d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,18 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer -from synapse.events import FrozenEvent from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership +from synapse.events import FrozenEvent from synapse.state import StateHandler, StateResolutionHandler -from .utils import MockClock - -from mock import Mock +from tests import unittest +from .utils import MockClock _next_event_id = 1000 diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index d28bb726bb..bc97c12245 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -14,7 +14,6 @@ # limitations under the License. from tests import unittest - from tests.utils import MockClock diff --git a/tests/test_types.py b/tests/test_types.py index 115def2287..729bd676c1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest - from synapse.api.errors import SynapseError from synapse.server import HomeServer -from synapse.types import UserID, RoomAlias, GroupID +from synapse.types import GroupID, RoomAlias, UserID + +from tests import unittest mock_homeserver = HomeServer(hostname="my.domain") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index a94d566c96..8176a7dabd 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,14 +13,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging +from functools import partial import mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError from synapse.util import logcontext -from twisted.internet import defer, reactor from synapse.util.caches import descriptors + from tests import unittest logger = logging.getLogger(__name__) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 543ac5bed9..26f2fa5800 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.caches.dictionary_cache import DictionaryCache +from tests import unittest + class DictCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 31d24adb8b..d12b5e838b 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock +from .. import unittest + class ExpiringCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index c2aae8f54c..7ce5f8c258 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -14,15 +14,16 @@ # limitations under the License. -from twisted.internet import defer, reactor +import threading + from mock import NonCallableMock +from six import StringIO + +from twisted.internet import defer, reactor from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest -from six import StringIO - -import threading class FileConsumerTests(unittest.TestCase): diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py index 9c795d9fdb..a5a767b1ff 100644 --- a/tests/util/test_limiter.py +++ b/tests/util/test_limiter.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest - from twisted.internet import defer from synapse.util.async import Limiter +from tests import unittest + class LimiterTestCase(unittest.TestCase): diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index bf7e3aa885..c95907b32c 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import logcontext, Clock -from tests import unittest +from six.moves import range from twisted.internet import defer, reactor +from synapse.util import Clock, logcontext from synapse.util.async import Linearizer -from six.moves import range + +from tests import unittest class LinearizerTestCase(unittest.TestCase): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 9cf90fcfc4..c54001f7a4 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,11 +1,11 @@ import twisted.python.failure -from twisted.internet import defer -from twisted.internet import reactor -from .. import unittest +from twisted.internet import defer, reactor -from synapse.util import logcontext, Clock +from synapse.util import Clock, logcontext from synapse.util.logcontext import LoggingContext +from .. import unittest + class LoggingContextTestCase(unittest.TestCase): diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py index 1a1a8412f2..297aebbfbe 100644 --- a/tests/util/test_logformatter.py +++ b/tests/util/test_logformatter.py @@ -15,6 +15,7 @@ import sys from synapse.util.logformatter import LogFormatter + from tests import unittest diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index dfb78cb8bd..9b36ef4482 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest +from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from mock import Mock +from .. import unittest class LruCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 1d745ae1a7..24194e3b25 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.async import ReadWriteLock +from tests import unittest + class ReadWriteLockTestCase(unittest.TestCase): diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index d3a8630c2f..0f5b32fcc0 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -14,10 +14,11 @@ # limitations under the License. -from .. import unittest +from twisted.internet.defer import Deferred from synapse.util.caches.snapshot_cache import SnapshotCache -from twisted.internet.defer import Deferred + +from .. import unittest class SnapshotCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 67ece166c7..e3897c0d19 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,8 +1,9 @@ -from tests import unittest from mock import patch from synapse.util.caches.stream_change_cache import StreamChangeCache +from tests import unittest + class StreamChangeCacheTests(unittest.TestCase): """ diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 7ab578a185..a5f2261208 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -14,10 +14,10 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.treecache import TreeCache +from .. import unittest + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index fdb24a48b0..03201a4d9b 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .. import unittest - from synapse.util.wheel_timer import WheelTimer +from .. import unittest + class WheelTimerTestCase(unittest.TestCase): def test_single_insert_fetch(self): diff --git a/tests/utils.py b/tests/utils.py index 189fd2711c..6adbdbfca1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,9 +15,10 @@ import hashlib from inspect import getcallargs -from six.moves.urllib import parse as urlparse from mock import Mock, patch +from six.moves.urllib import parse as urlparse + from twisted.internet import defer, reactor from synapse.api.errors import CodeMessageException, cs_error -- cgit 1.5.1 From 6c1ec5a1bdfaed2d13a501fd2441d99c73711adb Mon Sep 17 00:00:00 2001 From: Oleg Girko Date: Tue, 10 Jul 2018 00:11:39 +0100 Subject: Use more portable syntax using attrs package. Newer syntax attr.ib(factory=dict) is just a syntactic sugar for attr.ib(default=attr.Factory(dict)) It was introduced in newest version of attrs package (18.1.0) and doesn't work with older versions. We should either require minimum version of attrs to be 18.1.0, or use older (slightly more verbose) syntax. Requiring newest version is not a good solution because Linux distributions may have older version of attrs (17.4.0 in Fedora 28), and requiring to build (and package) newer version just to use newer syntactic sugar in only one test is just too much. It's much better to fix that test to use older syntax. Signed-off-by: Oleg Girko --- tests/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/server.py b/tests/server.py index 46223ccf05..e93f5a7f94 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,7 +22,7 @@ class FakeChannel(object): wire). """ - result = attr.ib(factory=dict) + result = attr.ib(default=attr.Factory(dict)) @property def json_body(self): -- cgit 1.5.1 From 3b391d9c459789b2d136d2738507b210e8275d1b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 13 Jul 2018 16:28:04 +0100 Subject: Fix unit tests --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 6adbdbfca1..e488238bb3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -65,6 +65,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None config.federation_domain_whitelist = None config.federation_rc_reject_limit = 10 config.federation_rc_sleep_limit = 10 + config.federation_rc_sleep_delay = 100 config.federation_rc_concurrent = 10 config.filter_timeline_limit = 5000 config.user_directory_search_all_users = False -- cgit 1.5.1 From bc832f822fc17fd28ab3609d1231a9e821be78de Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jul 2018 17:03:04 +0100 Subject: Fixup unit test --- tests/util/test_stream_change_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index e3897c0d19..fc45baaaa0 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -141,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase): ) # Query all the entries mid-way through the stream, but include one - # that doesn't exist in it. We should get back the one that doesn't - # exist, too. + # that doesn't exist in it. We shouldn't get back the one that doesn't + # exist. self.assertEqual( cache.get_entities_changed( [ @@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]), + set(["bar@baz.net", "user@elsewhere.org"]), ) # Query all the entries, but before the first known point. We will get -- cgit 1.5.1 From 33b60c01b5e60dc38b02b6f0fd018be2d2d07cf7 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 14 Jul 2018 07:34:49 +1000 Subject: Make auth & transactions more testable (#3499) --- changelog.d/3499.misc | 0 synapse/api/auth.py | 124 +++++++++++++-------------- synapse/rest/client/transactions.py | 45 +++++----- synapse/rest/client/v1/base.py | 2 +- synapse/rest/client/v1/logout.py | 3 +- synapse/rest/client/v1/register.py | 6 +- synapse/rest/client/v2_alpha/account.py | 3 +- synapse/rest/client/v2_alpha/register.py | 5 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- tests/rest/client/test_transactions.py | 5 +- 10 files changed, 97 insertions(+), 98 deletions(-) create mode 100644 changelog.d/3499.misc (limited to 'tests') diff --git a/changelog.d/3499.misc b/changelog.d/3499.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6dec862fec..bc629832d9 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -193,7 +193,7 @@ class Auth(object): synapse.types.create_requester(user_id, app_service=app_service) ) - access_token = get_access_token_from_request( + access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) @@ -239,7 +239,7 @@ class Auth(object): @defer.inlineCallbacks def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( - get_access_token_from_request( + self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) ) @@ -513,7 +513,7 @@ class Auth(object): def get_appservice_by_req(self, request): try: - token = get_access_token_from_request( + token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) service = self.store.get_app_service_by_token(token) @@ -673,67 +673,67 @@ class Auth(object): " edit its room list entry" ) + @staticmethod + def has_access_token(request): + """Checks if the request has an access_token. -def has_access_token(request): - """Checks if the request has an access_token. + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + return bool(query_params) or bool(auth_headers) - Returns: - bool: False if no access_token was given, True otherwise. - """ - query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - return bool(query_params) or bool(auth_headers) - - -def get_access_token_from_request(request, token_not_found_http_status=401): - """Extracts the access_token from the request. - - Args: - request: The http request. - token_not_found_http_status(int): The HTTP status code to set in the - AuthError if the token isn't found. This is used in some of the - legacy APIs to change the status code to 403 from the default of - 401 since some of the old clients depended on auth errors returning - 403. - Returns: - str: The access_token - Raises: - AuthError: If there isn't an access_token in the request. - """ + @staticmethod + def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - query_params = request.args.get(b"access_token") - if auth_headers: - # Try the get the access_token from a "Authorization: Bearer" - # header - if query_params is not None: - raise AuthError( - token_not_found_http_status, - "Mixing Authorization headers and access_token query parameters.", - errcode=Codes.MISSING_TOKEN, - ) - if len(auth_headers) > 1: - raise AuthError( - token_not_found_http_status, - "Too many Authorization headers.", - errcode=Codes.MISSING_TOKEN, - ) - parts = auth_headers[0].split(" ") - if parts[0] == "Bearer" and len(parts) == 2: - return parts[1] + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") + if auth_headers: + # Try the get the access_token from a "Authorization: Bearer" + # header + if query_params is not None: + raise AuthError( + token_not_found_http_status, + "Mixing Authorization headers and access_token query parameters.", + errcode=Codes.MISSING_TOKEN, + ) + if len(auth_headers) > 1: + raise AuthError( + token_not_found_http_status, + "Too many Authorization headers.", + errcode=Codes.MISSING_TOKEN, + ) + parts = auth_headers[0].split(" ") + if parts[0] == "Bearer" and len(parts) == 2: + return parts[1] + else: + raise AuthError( + token_not_found_http_status, + "Invalid Authorization header.", + errcode=Codes.MISSING_TOKEN, + ) else: - raise AuthError( - token_not_found_http_status, - "Invalid Authorization header.", - errcode=Codes.MISSING_TOKEN, - ) - else: - # Try to get the access_token from the query params. - if not query_params: - raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) - return query_params[0] + return query_params[0] diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 7c01b438cb..00b1b3066e 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,38 +17,20 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.api.auth import get_access_token_from_request from synapse.util.async import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) - -def get_transaction_key(request): - """A helper function which returns a transaction key that can be used - with TransactionCache for idempotent requests. - - Idempotency is based on the returned key being the same for separate - requests to the same endpoint. The key is formed from the HTTP request - path and the access_token for the requesting user. - - Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. - Returns: - str: A transaction key - """ - token = get_access_token_from_request(request) - return request.path + "/" + token - - CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, clock): - self.clock = clock + def __init__(self, hs): + self.hs = hs + self.auth = self.hs.get_auth() + self.clock = self.hs.get_clock() self.transactions = { # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) } @@ -56,6 +38,23 @@ class HttpTransactionCache(object): # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + def _get_transaction_key(self, request): + """A helper function which returns a transaction key that can be used + with TransactionCache for idempotent requests. + + Idempotency is based on the returned key being the same for separate + requests to the same endpoint. The key is formed from the HTTP request + path and the access_token for the requesting user. + + Args: + request (twisted.web.http.Request): The incoming request. Must + contain an access_token. + Returns: + str: A transaction key + """ + token = self.auth.get_access_token_from_request(request) + return request.path + "/" + token + def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -64,7 +63,7 @@ class HttpTransactionCache(object): fetch_or_execute """ return self.fetch_or_execute( - get_transaction_key(request), fn, *args, **kwargs + self._get_transaction_key(request), fn, *args, **kwargs ) def fetch_or_execute(self, txn_key, fn, *args, **kwargs): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index dde02328c3..c77d7aba68 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet): self.hs = hs self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 05a8ecfcd8..430c692336 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -17,7 +17,6 @@ import logging from twisted.internet import defer -from synapse.api.auth import get_access_token_from_request from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns @@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet): if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token - access_token = get_access_token_from_request(request) + access_token = self._auth.get_access_token_from_request(request) yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 3ce5f8b726..9203e60615 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -23,7 +23,6 @@ from six import string_types from twisted.internet import defer import synapse.util.stringutils as stringutils -from synapse.api.auth import get_access_token_from_request from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import parse_json_object_from_request @@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() @@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - as_token = get_access_token_from_request(request) + as_token = self.auth.get_access_token_from_request(request) if "user" not in register_json: raise SynapseError(400, "Expected 'user' key.") @@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) app_service = self.store.get_app_service_by_token( access_token ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 528c1f43f9..1263abb02c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,7 +20,6 @@ from six.moves import http_client from twisted.internet import defer -from synapse.api.auth import has_access_token from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. - if has_access_token(request): + if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) params = yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 896650d5a5..d305afcaef 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -24,7 +24,6 @@ from twisted.internet import defer import synapse import synapse.types -from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.http.servlet import ( @@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet): desired_username = body['username'] appservice = None - if has_access_token(request): + if self.auth.has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet): # because the IRC bridges rely on being able to register stupid # IDs. - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 90bdb1db15..a9e9a47a0b 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): super(SendToDeviceRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() def on_PUT(self, request, message_type, txn_id): diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index eee99ca2e0..34e68ae82f 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self): self.clock = MockClock() - self.cache = HttpTransactionCache(self.clock) + self.hs = Mock() + self.hs.get_clock = Mock(return_value=self.clock) + self.hs.get_auth = Mock() + self.cache = HttpTransactionCache(self.hs) self.mock_http_response = (200, "GOOD JOB!") self.mock_key = "foo" -- cgit 1.5.1 From ea69d3565110a20f5abe00f1a5c2357b483140fb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 Jul 2018 11:38:45 +0100 Subject: Move filter_events_for_server out of FederationHandler for easier unit testing. --- synapse/handlers/federation.py | 144 ++--------------------------------------- synapse/visibility.py | 132 +++++++++++++++++++++++++++++++++++++ tests/test_visibility.py | 107 ++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 138 deletions(-) create mode 100644 tests/test_visibility.py (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3ecebd29f..20fb46fc89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -43,7 +43,6 @@ from synapse.crypto.event_signing import ( add_hashes_and_signatures, compute_event_signature, ) -from synapse.events.utils import prune_event from synapse.events.validator import EventValidator from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id @@ -52,8 +51,8 @@ from synapse.util.async import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function -from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination +from synapse.visibility import filter_events_for_server from ._base import BaseHandler @@ -501,137 +500,6 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - @measure_func("_filter_events_for_server") - @defer.inlineCallbacks - def _filter_events_for_server(self, server_name, room_id, events): - """Filter the given events for the given server, redacting those the - server can't see. - - Assumes the server is currently in the room. - - Returns - list[FrozenEvent] - """ - # First lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't - # need to check membership (as we know the server is in the room). - event_to_state_ids = yield self.store.get_state_ids_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - ) - ) - - visibility_ids = set() - for sids in event_to_state_ids.itervalues(): - hist = sids.get((EventTypes.RoomHistoryVisibility, "")) - if hist: - visibility_ids.add(hist) - - # If we failed to find any history visibility events then the default - # is "shared" visiblity. - if not visibility_ids: - defer.returnValue(events) - - event_map = yield self.store.get_events(visibility_ids) - all_open = all( - e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in event_map.itervalues() - ) - - if all_open: - defer.returnValue(events) - - # Ok, so we're dealing with events that have non-trivial visibility - # rules, so we need to also get the memberships of the room. - - event_to_state_ids = yield self.store.get_state_ids_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ) - ) - - # We only want to pull out member events that correspond to the - # server's domain. - - def check_match(id): - try: - return server_name == get_domain_from_id(id) - except Exception: - return False - - # Parses mapping `event_id -> (type, state_key) -> state event_id` - # to get all state ids that we're interested in. - event_map = yield self.store.get_events([ - e_id - for key_to_eid in list(event_to_state_ids.values()) - for key, e_id in key_to_eid.items() - if key[0] != EventTypes.Member or check_match(key[1]) - ]) - - event_to_state = { - e_id: { - key: event_map[inner_e_id] - for key, inner_e_id in key_to_eid.iteritems() - if inner_e_id in event_map - } - for e_id, key_to_eid in event_to_state_ids.iteritems() - } - - erased_senders = yield self.store.are_users_erased( - e.sender for e in events, - ) - - def redact_disallowed(event, state): - # if the sender has been gdpr17ed, always return a redacted - # copy of the event. - if erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) - return prune_event(event) - - if not state: - return event - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - if visibility in ["invited", "joined"]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.itervalues(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except Exception: - continue - - if domain != server_name: - continue - - memtype = ev.membership - if memtype == Membership.JOIN: - return event - elif memtype == Membership.INVITE: - if visibility == "invited": - return event - else: - return prune_event(event) - - return event - - defer.returnValue([ - redact_disallowed(e, event_to_state[e.event_id]) - for e in events - ]) - @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit, extremities): @@ -1558,7 +1426,7 @@ class FederationHandler(BaseHandler): limit ) - events = yield self._filter_events_for_server(origin, room_id, events) + events = yield filter_events_for_server(self.store, origin, events) defer.returnValue(events) @@ -1605,8 +1473,8 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self._filter_events_for_server( - origin, event.room_id, [event] + events = yield filter_events_for_server( + self.store, origin, [event], ) event = events[0] defer.returnValue(event) @@ -1896,8 +1764,8 @@ class FederationHandler(BaseHandler): min_depth=min_depth, ) - missing_events = yield self._filter_events_for_server( - origin, room_id, missing_events, + missing_events = yield filter_events_for_server( + self.store, origin, missing_events, ) defer.returnValue(missing_events) diff --git a/synapse/visibility.py b/synapse/visibility.py index 015c2bab37..ddce41efaf 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.types import get_domain_from_id from synapse.util.logcontext import make_deferred_yieldable, preserve_fn logger = logging.getLogger(__name__) @@ -225,3 +226,134 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, # we turn it into a list before returning it. defer.returnValue(list(filtered_events)) + + +@defer.inlineCallbacks +def filter_events_for_server(store, server_name, events): + """Filter the given events for the given server, redacting those the + server can't see. + + Assumes the server is currently in the room. + + Returns + list[FrozenEvent] + """ + # First lets check to see if all the events have a history visibility + # of "shared" or "world_readable". If thats the case then we don't + # need to check membership (as we know the server is in the room). + event_to_state_ids = yield store.get_state_ids_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + ) + ) + + visibility_ids = set() + for sids in event_to_state_ids.itervalues(): + hist = sids.get((EventTypes.RoomHistoryVisibility, "")) + if hist: + visibility_ids.add(hist) + + # If we failed to find any history visibility events then the default + # is "shared" visiblity. + if not visibility_ids: + defer.returnValue(events) + + event_map = yield store.get_events(visibility_ids) + all_open = all( + e.content.get("history_visibility") in (None, "shared", "world_readable") + for e in event_map.itervalues() + ) + + if all_open: + defer.returnValue(events) + + # Ok, so we're dealing with events that have non-trivial visibility + # rules, so we need to also get the memberships of the room. + + event_to_state_ids = yield store.get_state_ids_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) + ) + + # We only want to pull out member events that correspond to the + # server's domain. + + def check_match(id): + try: + return server_name == get_domain_from_id(id) + except Exception: + return False + + # Parses mapping `event_id -> (type, state_key) -> state event_id` + # to get all state ids that we're interested in. + event_map = yield store.get_events([ + e_id + for key_to_eid in list(event_to_state_ids.values()) + for key, e_id in key_to_eid.items() + if key[0] != EventTypes.Member or check_match(key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.iteritems() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.iteritems() + } + + erased_senders = yield store.are_users_erased( + e.sender for e in events, + ) + + def redact_disallowed(event, state): + # if the sender has been gdpr17ed, always return a redacted + # copy of the event. + if erased_senders[event.sender]: + logger.info( + "Sender of %s has been erased, redacting", + event.event_id, + ) + return prune_event(event) + + if not state: + return event + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + if visibility in ["invited", "joined"]: + # We now loop through all state events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in state.itervalues(): + if ev.type != EventTypes.Member: + continue + try: + domain = get_domain_from_id(ev.state_key) + except Exception: + continue + + if domain != server_name: + continue + + memtype = ev.membership + if memtype == Membership.JOIN: + return event + elif memtype == Membership.INVITE: + if visibility == "invited": + return event + else: + return prune_event(event) + + return event + + defer.returnValue([ + redact_disallowed(e, event_to_state[e.event_id]) + for e in events + ]) diff --git a/tests/test_visibility.py b/tests/test_visibility.py new file mode 100644 index 0000000000..86981958cb --- /dev/null +++ b/tests/test_visibility.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.visibility import filter_events_for_server +from tests import unittest +from tests.utils import setup_test_homeserver + +logger = logging.getLogger(__name__) + +TEST_ROOM_ID = "!TEST:ROOM" + + +class FilterEventsForServerTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.event_creation_handler = self.hs.get_event_creation_handler() + self.event_builder_factory = self.hs.get_event_builder_factory() + self.store = self.hs.get_datastore() + + @defer.inlineCallbacks + def test_filtering(self): + # + # The events to be filtered consist of 10 membership events (it doesn't + # really matter if they are joins or leaves, so let's make them joins). + # One of those membership events is going to be for a user on the + # server we are filtering for (so we can check the filtering is doing + # the right thing). + # + + # before we do that, we persist some other events to act as state. + self.inject_visibility("@admin:hs", "joined") + for i in range(0, 10): + yield self.inject_room_member("@resident%i:hs" % i) + + events_to_filter = [] + + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = yield self.inject_room_member(user, extra_content={"a": "b"}) + events_to_filter.append(evt) + + filtered = yield filter_events_for_server( + self.store, "test_server", events_to_filter, + ) + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("a", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["a"], "b") + + @defer.inlineCallbacks + def inject_visibility(self, user_id, visibility): + content = {"history_visibility": visibility} + builder = self.event_builder_factory.new({ + "type": "m.room.history_visibility", + "sender": user_id, + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_room_member(self, user_id, membership="join", extra_content={}): + content = {"membership": membership} + content.update(extra_content) + builder = self.event_builder_factory.new({ + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) -- cgit 1.5.1 From 15b13b537f40ba0318b104e28231e95dfc0ccf93 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 Jul 2018 14:05:31 +0100 Subject: Add a test which profiles filter_events_for_server in a large room --- tests/test_visibility.py | 157 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 86981958cb..1401d831d2 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -15,9 +15,11 @@ import logging from twisted.internet import defer +from twisted.internet.defer import succeed +from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server -from tests import unittest +import tests.unittest from tests.utils import setup_test_homeserver logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ logger = logging.getLogger(__name__) TEST_ROOM_ID = "!TEST:ROOM" -class FilterEventsForServerTestCase(unittest.TestCase): +class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver() @@ -105,3 +107,154 @@ class FilterEventsForServerTestCase(unittest.TestCase): yield self.hs.get_datastore().persist_event(event, context) defer.returnValue(event) + + @defer.inlineCallbacks + def test_large_room(self): + # see what happens when we have a large room with hundreds of thousands + # of membership events + + # As above, the events to be filtered consist of 10 membership events, + # where one of them is for a user on the server we are filtering for. + + import cProfile + import pstats + import time + + # we stub out the store, because building up all that state the normal + # way is very slow. + test_store = _TestStore() + + # our initial state is 100000 membership events and one + # history_visibility event. + room_state = [] + + history_visibility_evt = FrozenEvent({ + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + }) + room_state.append(history_visibility_evt) + test_store.add_event(history_visibility_evt) + + for i in range(0, 100000): + user = "@resident_user_%i:test.com" % (i, ) + evt = FrozenEvent({ + "event_id": "$res_event_%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz," + }, + }) + room_state.append(evt) + test_store.add_event(evt) + + events_to_filter = [] + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = FrozenEvent({ + "event_id": "$evt%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz", + }, + }) + events_to_filter.append(evt) + room_state.append(evt) + + test_store.add_event(evt) + test_store.set_state_ids_for_event(evt, { + (e.type, e.state_key): e.event_id for e in room_state + }) + + pr = cProfile.Profile() + pr.enable() + + logger.info("Starting filtering") + start = time.time() + filtered = yield filter_events_for_server( + test_store, "test_server", events_to_filter, + ) + logger.info("Filtering took %f seconds", time.time() - start) + + pr.disable() + with open("filter_events_for_server.profile", "w+") as f: + ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps.print_stats() + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("extra", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["extra"], "zzz") + + test_large_room.skip = "Disabled by default because it's slow" + + +class _TestStore(object): + """Implements a few methods of the DataStore, so that we can test + filter_events_for_server + + """ + def __init__(self): + # data for get_events: a map from event_id to event + self.events = {} + + # data for get_state_ids_for_events mock: a map from event_id to + # a map from (type_state_key) -> event_id for the state at that + # event + self.state_ids_for_events = {} + + def add_event(self, event): + self.events[event.event_id] = event + + def set_state_ids_for_event(self, event, state): + self.state_ids_for_events[event.event_id] = state + + def get_state_ids_for_events(self, events, types): + res = {} + include_memberships = False + for (type, state_key) in types: + if type == "m.room.history_visibility": + continue + if type != "m.room.member" or state_key is not None: + raise RuntimeError( + "Unimplemented: get_state_ids with type (%s, %s)" % + (type, state_key), + ) + include_memberships = True + + if include_memberships: + for event_id in events: + res[event_id] = self.state_ids_for_events[event_id] + + else: + k = ("m.room.history_visibility", "") + for event_id in events: + hve = self.state_ids_for_events[event_id][k] + res[event_id] = {k: hve} + + return succeed(res) + + def get_events(self, events): + return succeed({ + event_id: self.events[event_id] for event_id in events + }) + + def are_users_erased(self, users): + return succeed({u: False for u in users}) -- cgit 1.5.1 From 850238b4ef1573a4162048d5ae285bb3fdccf5bb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jul 2018 10:59:02 +0100 Subject: Add unit test --- tests/util/test_stream_change_cache.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'tests') diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index fc45baaaa0..65b0f2e6fb 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -178,6 +178,22 @@ class StreamChangeCacheTests(unittest.TestCase): ), ) + # Query a subset of the entries mid-way through the stream. We should + # only get back the subset. + self.assertEqual( + cache.get_entities_changed( + [ + "bar@baz.net", + ], + stream_pos=2, + ), + set( + [ + "bar@baz.net", + ] + ), + ) + def test_max_pos(self): """ StreamChangeCache.get_max_pos_of_last_change will return the most -- cgit 1.5.1 From bc006b3c9d2a9982bc834ff5d1ec1768c85f907a Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 17 Jul 2018 20:43:18 +1000 Subject: Refactor REST API tests to use explicit reactors (#3351) --- changelog.d/3351.misc | 0 synapse/http/site.py | 5 +- synapse/rest/client/v2_alpha/register.py | 2 +- tests/rest/client/v1/test_register.py | 68 +- tests/rest/client/v1/test_rooms.py | 1156 +++++++++++---------------- tests/rest/client/v1/utils.py | 115 ++- tests/rest/client/v2_alpha/__init__.py | 61 -- tests/rest/client/v2_alpha/test_filter.py | 129 +-- tests/rest/client/v2_alpha/test_register.py | 211 ++--- tests/rest/client/v2_alpha/test_sync.py | 83 ++ tests/server.py | 10 + tests/test_server.py | 22 +- tests/unittest.py | 11 + 13 files changed, 939 insertions(+), 934 deletions(-) create mode 100644 changelog.d/3351.misc create mode 100644 tests/rest/client/v2_alpha/test_sync.py (limited to 'tests') diff --git a/changelog.d/3351.misc b/changelog.d/3351.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/http/site.py b/synapse/http/site.py index 41dd974cea..5fd30a4c2c 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -42,9 +42,10 @@ class SynapseRequest(Request): which is handling the request, and returns a context manager. """ - def __init__(self, site, *args, **kw): - Request.__init__(self, *args, **kw) + def __init__(self, site, channel, *args, **kw): + Request.__init__(self, channel, *args, **kw) self.site = site + self._channel = channel self.authenticated_entity = None self.start_time = 0 diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1918897f68..d6cf915d86 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: - defer.returnValue((403, "Guest access is disabled")) + raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index f596acb85f..f15fb36213 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -17,26 +17,22 @@ import json from mock import Mock -from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v1.register import CreateUserRestServlet +from synapse.http.server import JsonResource +from synapse.rest.client.v1.register import register_servlets +from synapse.util import Clock from tests import unittest -from tests.utils import mock_getRawHeaders +from tests.server import make_request, setup_test_homeserver class CreateUserServletTestCase(unittest.TestCase): + """ + Tests for CreateUserRestServlet. + """ def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/client/api/v1/createUser' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() - self.registration_handler = Mock() self.appservice = Mock(sender="@as:test") @@ -44,39 +40,49 @@ class CreateUserServletTestCase(unittest.TestCase): get_app_service_by_token=Mock(return_value=self.appservice) ) - # do the dance to hook things up to the hs global - handlers = Mock( - registration_handler=self.registration_handler, + handlers = Mock(registration_handler=self.registration_handler) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) - self.servlet = CreateUserRestServlet(self.hs) - @defer.inlineCallbacks def test_POST_createuser_with_valid_user(self): + + res = JsonResource(self.hs) + register_servlets(self.hs, res) + + request_data = json.dumps( + { + "localpart": "someone", + "displayname": "someone interesting", + "duration_seconds": 200, + } + ) + + url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service' + user_id = "@someone:interesting" token = "my token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "localpart": "someone", - "displayname": "someone interesting", - "duration_seconds": 200 - }) self.registration_handler.get_or_create_user = Mock( return_value=(user_id, token) ) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + request, channel = make_request(b"POST", url, request_data) + request.render(res) + + # Advance the clock because it waits + self.clock.advance(1) + + self.assertEquals(channel.result["code"], b"200") det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 895dffa095..6b5764095e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -25,949 +25,772 @@ from twisted.internet import defer import synapse.rest.client.v1.room from synapse.api.constants import Membership +from synapse.http.server import JsonResource from synapse.types import UserID +from synapse.util import Clock -from ....utils import MockHttpResource, setup_test_homeserver -from .utils import RestTestCase +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) -PATH_PREFIX = "/_matrix/client/api/v1" +from .utils import RestHelper +PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomPermissionsTestCase(RestTestCase): - """ Tests room permissions. """ - user_id = "@sid1:red" - rmcreator_id = "@notme:red" - @defer.inlineCallbacks +class RoomBase(unittest.TestCase): + rmcreator_id = None + def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - hs = yield setup_test_homeserver( + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( "red", http_client=None, + clock=self.hs_clock, + reactor=self.clock, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.ratelimiter = hs.get_ratelimiter() + self.ratelimiter = self.hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() + self.hs.get_federation_handler = Mock(return_value=Mock()) def get_user_by_access_token(token=None, allow_guest=False): return { - "user": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.helper.auth_user_id), "token_id": 1, "is_guest": False, } - hs.get_auth().get_user_by_access_token = get_user_by_access_token + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) + + self.hs.get_auth().get_user_by_req = get_user_by_req + self.hs.get_auth().get_user_by_access_token = get_user_by_access_token + self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") def _insert_client_ip(*args, **kwargs): return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - self.auth_user_id = self.rmcreator_id + self.hs.get_datastore().insert_client_ip = _insert_client_ip - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + self.resource = JsonResource(self.hs) + synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) + self.helper = RestHelper(self.hs, self.resource, self.user_id) - self.auth = hs.get_auth() - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" +class RoomPermissionsTestCase(RoomBase): + """ Tests room permissions. """ - self.created_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=False) + user_id = b"@sid1:red" + rmcreator_id = b"@notme:red" + + def setUp(self): - self.created_public_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=True) + super(RoomPermissionsTestCase, self).setUp() + + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) # send a message in one of the rooms self.created_rmid_msg_path = ( - "/rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode('ascii') + request, channel = make_request( + b"PUT", self.created_rmid_msg_path, - '{"msgtype":"m.text","body":"test msg"}' + b'{"msgtype":"m.text","body":"test msg"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # set topic for public room - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - '{"topic":"Public Room Topic"}' + request, channel = make_request( + b"PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # auth as user_id now - self.auth_user_id = self.user_id - - def tearDown(self): - pass + self.helper.auth_user_id = self.user_id - @defer.inlineCallbacks def test_send_message(self): - msg_content = '{"msgtype":"m.text","body":"hello"}' - send_msg_path = ( - "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,) - ) + msg_content = b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return b"/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)).encode('ascii'), + ) # send message in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and invited, expect 403 - yield self.invite( - room=self.created_rmid, - src=self.rmcreator_id, - targ=self.user_id - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(200, code, msg=str(response)) + self.helper.join(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_topic_perms(self): - topic_content = '{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + topic_content = b'{"topic":"My Topic Name"}' + topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, - topic_content + request, channel = make_request( + b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request( + b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 - yield self.invite( + self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) + self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default - self.auth_user_id = self.rmcreator_id - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = self.user_id + self.helper.auth_user_id = self.rmcreator_id + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(topic_content), response) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = make_request( + b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - (code, response) = yield self.mock_resource.trigger_get(path) - self.assertEquals(expect_code, code) + path = b"/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = make_request(b"GET", path) + render(request, self.resource, self.clock) + self.assertEquals(expect_code, int(channel.result["code"])) - @defer.inlineCallbacks def test_membership_basic_room_perms(self): # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room # expect all 403s - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # trying to invite people to this room should 403 - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=404) + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) - @defer.inlineCallbacks def test_membership_private_room_perms(self): room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, private room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, private room + left # expect all 200s - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_membership_public_room_perms(self): room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, public room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, public room + left # expect 200. - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_invited_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) # set [invite/join/left] of other user, expect 403s - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403) - - @defer.inlineCallbacks + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + def test_joined_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) # set invited of self, expect 403 - yield self.invite(room=room, src=self.user_id, targ=self.user_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) # set joined of self, expect 200 (NOOP) - yield self.join(room=room, user=self.user_id) + self.helper.join(room=room, user=self.user_id) other = "@burgundy:red" # set invited of other, expect 200 - yield self.invite(room=room, src=self.user_id, targ=other, - expect_code=200) + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) # set joined of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) # set left of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) # set left of self, expect 200 - yield self.leave(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) - @defer.inlineCallbacks def test_leave_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) - yield self.leave(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 403s for usr in [self.user_id, self.rmcreator_id]: - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403 + expect_code=403, ) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403 + expect_code=403, ) # It is always valid to LEAVE if you've already left (currently.) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403 + expect_code=403, ) -class RoomsMemberListTestCase(RestTestCase): +class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - self.auth_user_id = self.user_id - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + user_id = b"@sid1:red" - def tearDown(self): - pass - - @defer.inlineCallbacks def test_get_member_list(self): - room_id = yield self.create_room_as(self.user_id) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(200, code, msg=str(response)) + room_id = self.helper.create_room_as(self.user_id) + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_room(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/roomdoesnotexist/members" - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_permission(self): - room_id = yield self.create_room_as("@some_other_guy:red") - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(403, code, msg=str(response)) + room_id = self.helper.create_room_as(b"@some_other_guy:red") + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = yield self.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - yield self.invite(room=room_id, src=room_creator, - targ=self.user_id) + room_creator = b"@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = b"/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - yield self.join(room=room_id, user=self.user_id) + self.helper.join(room=room_id, user=self.user_id) # can see list now joined - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - yield self.leave(room=room_id, user=self.user_id) + self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomsCreateTestCase(RestTestCase): +class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + user_id = b"@sid1:red" - @defer.inlineCallbacks def test_post_room_no_keys(self): # POST with no config keys, expect new room id - (code, response) = yield self.mock_resource.trigger("POST", - "/createRoom", - "{}") - self.assertEquals(200, code, response) - self.assertTrue("room_id" in response) + request, channel = make_request(b"POST", b"/createRoom", b"{}") + + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), channel.result) + self.assertTrue("room_id" in channel.json_body) - @defer.inlineCallbacks def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"custom":"stuff"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private","custom":"things"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibili') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '["hello"]') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) -class RoomTopicTestCase(RestTestCase): +class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = "@sid1:red" - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id + user_id = b"@sid1:red" - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + def setUp(self): - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomTopicTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - - def tearDown(self): - pass + self.room_id = self.helper.create_room_as(self.user_id) + self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) - @defer.inlineCallbacks def test_invalid_puts(self): # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_topic(self): # nothing should be there - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(404, code, msg=str(response)) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) -class RoomMemberStateTestCase(RestTestCase): +class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() + user_id = b"@sid1:red" - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + def setUp(self): - self.room_id = yield self.create_room_as(self.user_id) + super(RoomMemberStateTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) def tearDown(self): pass - @defer.inlineCallbacks def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", path, '{}') - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid keys, wrong types - content = ('{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.user_id + urlparse.quote(self.room_id), + self.user_id, ) # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - expected_response = { - "membership": Membership.JOIN, - } - self.assertEquals(expected_response, response) + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message with custom key - content = ('{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, "Join us!" - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) -class RoomMessagesTestCase(RestTestCase): +class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() + super(RoomMessagesTestCase, self).setUp() - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + self.room_id = self.helper.create_room_as(self.user_id) - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - def tearDown(self): - pass - - @defer.inlineCallbacks def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = '{"body":"test","msgtype":{"type":"a"}}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # custom message types content = '{"body":"test","msgtype":"test.custom.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) - -# (code, response) = yield self.mock_resource.trigger("GET", path, None) -# self.assertEquals(200, code, msg=str(response)) -# self.assert_dict(json.loads(content), response) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % ( - urlparse.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = '{"body":"test2","msgtype":"m.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomInitialSyncTestCase(RestTestCase): +class RoomInitialSyncTestCase(RoomBase): """ Tests /rooms/$room_id/initialSync. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomInitialSyncTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) + self.room_id = self.helper.create_room_as(self.user_id) - @defer.inlineCallbacks def test_initial_sync(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/initialSync" % self.room_id - ) - self.assertEquals(200, code) + request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) - self.assertEquals(self.room_id, response["room_id"]) - self.assertEquals("join", response["membership"]) + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} - for event in response["state"]: + for event in channel.json_body["state"]: if "state_key" not in event: continue t = event["type"] @@ -977,75 +800,48 @@ class RoomInitialSyncTestCase(RestTestCase): self.assertTrue("m.room.create" in state) - self.assertTrue("messages" in response) - self.assertTrue("chunk" in response["messages"]) - self.assertTrue("end" in response["messages"]) + self.assertTrue("messages" in channel.json_body) + self.assertTrue("chunk" in channel.json_body["messages"]) + self.assertTrue("end" in channel.json_body["messages"]) - self.assertTrue("presence" in response) + self.assertTrue("presence" in channel.json_body) presence_by_user = { - e["content"]["user_id"]: e for e in response["presence"] + e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) -class RoomMessageListTestCase(RestTestCase): +class RoomMessageListTestCase(RoomBase): """ Tests /rooms/$room_id/messages REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token + super(RoomMessageListTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - @defer.inlineCallbacks def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 54d7ba380d..41de8e0762 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -16,13 +16,14 @@ import json import time -# twisted imports +import attr + from twisted.internet import defer from synapse.api.constants import Membership -# trial imports from tests import unittest +from tests.server import make_request, wait_until_result class RestTestCase(unittest.TestCase): @@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase): for key in required: self.assertEquals(required[key], actual[key], msg="%s mismatch. %s" % (key, actual)) + + +@attr.s +class RestHelper(object): + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + resource = attr.ib() + auth_user_id = attr.ib() + + def create_room_as(self, room_creator, is_public=True, tok=None): + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = b"/_matrix/client/r0/createRoom" + content = {} + if not is_public: + content["visibility"] = "private" + if tok: + path = path + b"?access_token=%s" % tok.encode('ascii') + + request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert channel.result["code"] == b"200", channel.result + self.auth_user_id = temp_id + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + + request, channel = make_request( + b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') + ) + + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + self.auth_user_id = temp_id + + @defer.inlineCallbacks + def register(self, user_id): + (code, response) = yield self.mock_resource.trigger( + "POST", + "/_matrix/client/r0/register", + json.dumps( + {"user": user_id, "password": "test", "type": "m.login.password"} + ), + ) + self.assertEquals(200, code) + defer.returnValue(response) + + @defer.inlineCallbacks + def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + if body is None: + body = "body_text_here" + + path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) + content = '{"msgtype":"m.text","body":"%s"}' % body + if tok: + path = path + "?access_token=%s" % tok + + (code, response) = yield self.mock_resource.trigger("PUT", path, content) + self.assertEquals(expect_code, code, msg=str(response)) diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f18a8a6027..e69de29bb2 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mock import Mock - -from twisted.internet import defer - -from synapse.types import UserID - -from tests import unittest - -from ....utils import MockHttpResource, setup_test_homeserver - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class V2AlphaRestTestCase(unittest.TestCase): - # Consumer must define - # USER_ID = - # TO_REGISTER = [] - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - datastore=self.make_datastore_mock(), - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - ) - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - for r in self.TO_REGISTER: - r.register_servlets(hs, self.mock_resource) - - def make_datastore_mock(self): - store = Mock(spec=[ - "insert_client_ip", - ]) - store.get_app_service_by_token = Mock(return_value=None) - return store diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index bb0b2f94ea..5ea9cc825f 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,35 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.types from synapse.api.errors import Codes +from synapse.http.server import JsonResource from synapse.rest.client.v2_alpha import filter from synapse.types import UserID +from synapse.util import Clock from tests import unittest - -from ....utils import MockHttpResource, setup_test_homeserver +from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock +from tests.server import make_request, setup_test_homeserver, wait_until_result PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = "@apple:test" + USER_ID = b"@apple:test" EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) - self.hs = yield setup_test_homeserver( - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() @@ -55,82 +53,103 @@ class FilterTestCase(unittest.TestCase): def get_user_by_req(request, allow_guest=False, rights="access"): return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None) + UserID.from_string(self.USER_ID), 1, False, None + ) self.auth.get_user_by_access_token = get_user_by_access_token self.auth.get_user_by_req = get_user_by_req self.store = self.hs.get_datastore() self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.mock_resource) + r.register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_add_filter(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON - ) - self.assertEquals(200, code) - self.assertEquals({"filter_id": "0"}, response) - filter = yield self.store.get_user_filter( - user_localpart='apple', - filter_id=0, + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(filter, self.EXAMPLE_FILTER) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.clock.advance(0) + self.assertEquals(filter.result, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_add_filter_for_other_user(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) + request.render(self.resource) + wait_until_result(self.clock, channel) + self.hs.is_mine = _is_mine - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_get_filter(self): - filter_id = yield self.filtering.add_user_filter( - user_localpart='apple', - user_filter=self.EXAMPLE_FILTER + filter_id = self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/%s" % (self.USER_ID, filter_id) + self.clock.advance(1) + filter_id = filter_id.result + request, channel = make_request( + b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) - self.assertEquals(200, code) - self.assertEquals(self.EXAMPLE_FILTER, response) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_get_filter_non_existant(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/12382148321" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) - self.assertEquals(400, code) - self.assertEquals(response['errcode'], Codes.NOT_FOUND) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py - @defer.inlineCallbacks def test_get_filter_invalid_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/foobar" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - @defer.inlineCallbacks def test_get_filter_no_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9b57a56070..e004d8fc73 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -2,165 +2,192 @@ import json from mock import Mock -from twisted.internet import defer from twisted.python import failure +from twisted.test.proto_helpers import MemoryReactorClock -from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError -from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.util import Clock from tests import unittest -from tests.utils import mock_getRawHeaders +from tests.server import make_request, setup_test_homeserver, wait_until_result class RegisterRestServletTestCase(unittest.TestCase): - def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/api/v2_alpha/register' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = b"/_matrix/client/r0/register" self.appservice = None - self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: self.appservice) + self.auth = Mock( + get_appservice_by_req=Mock(side_effect=lambda x: self.appservice) ) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_handler = Mock( check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None) + get_session_data=Mock(return_value=None), ) self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) # do the dance to hook it up to the hs global self.handlers = Mock( registration_handler=self.registration_handler, identity_handler=self.identity_handler, - login_handler=self.login_handler + login_handler=self.login_handler, + ) + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) + self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] - # init the thing we're testing - self.servlet = RegisterRestServlet(self.hs) + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_POST_appservice_registration_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "username": "kermit" - }) - self.appservice = { - "id": "1234" - } - self.registration_handler.appservice_register = Mock( - return_value=user_id - ) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token + self.appservice = {"id": "1234"} + self.registration_handler.appservice_register = Mock(return_value=user_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + request_data = json.dumps({"username": "kermit"}) + + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) - @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): - self.request.args = { - "access_token": "i_am_an_app_service" - } - - self.request_data = json.dumps({ - "username": "kermit" - }) self.appservice = None # no application service exists - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (401, None)) + request_data = json.dumps({"username": "kermit"}) + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): - self.request_data = json.dumps({ - "username": "kermit", - "password": 666 - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + request_data = json.dumps({"username": "kermit", "password": 666}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid password" + ) def test_POST_bad_username(self): - self.request_data = json.dumps({ - "username": 777, - "password": "monkey" - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) - - @defer.inlineCallbacks + request_data = json.dumps({"username": 777, "password": "monkey"}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid username" + ) + def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" device_id = "frogfone" - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey", - "device_id": device_id, - }) + request_data = json.dumps( + {"username": "kermit", "password": "monkey", "device_id": device_id} + ) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token - ) - self.device_handler.check_device_registered = \ - Mock(return_value=device_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + self.device_handler.check_device_registered = Mock(return_value=device_id) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, "device_id": device_id, } - self.assertDictContainsSubset(det_data, result) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.auth_handler.get_login_tuple_for_user_id( - user_id, device_id=device_id, initial_device_display_name=None) + user_id, device_id=device_id, initial_device_display_name=None + ) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey" - }) + request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=("@user:id", "t")) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], + "Registration has been disabled", + ) + + def test_POST_guest_registration(self): + user_id = "a@b" + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + self.registration_handler.register = Mock(return_value=(user_id, None)) + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": "guest_device", + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Guest access is disabled" + ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py new file mode 100644 index 0000000000..704cf97a40 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.types +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import sync +from synapse.types import UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock +from tests.server import make_request, setup_test_homeserver, wait_until_result + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.TestCase): + + USER_ID = b"@apple:test" + TO_REGISTER = [sync] + + def setUp(self): + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.auth = self.hs.get_auth() + + def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.USER_ID), + "token_id": 1, + "is_guest": False, + } + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.USER_ID), 1, False, None + ) + + self.auth.get_user_by_access_token = get_user_by_access_token + self.auth.get_user_by_req = get_user_by_req + + self.store = self.hs.get_datastore() + self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) + + for r in self.TO_REGISTER: + r.register_servlets(self.hs, self.resource) + + def test_sync_argless(self): + request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertTrue( + set( + [ + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + ] + ).issubset(set(channel.json_body.keys())) + ) diff --git a/tests/server.py b/tests/server.py index e93f5a7f94..c611dd6059 100644 --- a/tests/server.py +++ b/tests/server.py @@ -80,6 +80,11 @@ def make_request(method, path, content=b""): content, and return the Request and the Channel underneath. """ + # Decorate it to be the full path + if not path.startswith(b"/_matrix"): + path = b"/_matrix/client/r0/" + path + path = path.replace("//", "/") + if isinstance(content, text_type): content = content.encode('utf8') @@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100): clock.advance(0.1) +def render(request, resource, clock): + request.render(resource) + wait_until_result(clock, request._channel) + + class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. diff --git a/tests/test_server.py b/tests/test_server.py index 4192013f6d..7e063c0290 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase): return (200, kwargs) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo/(?P[^/]*)$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo/(?P[^/]*)$")], _callback + ) - request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") + request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") request.render(res) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) @@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase): raise Exception("boo") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) self.assertEqual(channel.result["code"], b'500') @@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase): return d res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) # No error has been raised yet @@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase): raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foo") + request, channel = make_request(b"GET", b"/_matrix/foo") request.render(res) self.assertEqual(channel.result["code"], b'403') @@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase): self.fail("shouldn't ever get here") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/foo$")], _callback) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) - request, channel = make_request(b"GET", b"/foobar") + request, channel = make_request(b"GET", b"/_matrix/foobar") request.render(res) self.assertEqual(channel.result["code"], b'400') diff --git a/tests/unittest.py b/tests/unittest.py index b25f2db5d5..b15b06726b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -109,6 +109,17 @@ class TestCase(unittest.TestCase): except AssertionError as e: raise (type(e))(e.message + " for '.%s'" % key) + def assert_dict(self, required, actual): + """Does a partial assert of a dict. + + Args: + required (dict): The keys and value which MUST be in 'actual'. + actual (dict): The test result. Extra keys will not be checked. + """ + for key in required: + self.assertEquals(required[key], actual[key], + msg="%s mismatch. %s" % (key, actual)) + def DEBUG(target): """A decorator to set the .loglevel attribute to logging.DEBUG. -- cgit 1.5.1 From 94440ae9946936339453925c249f81ea463da4d0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 Jul 2018 11:51:26 +0100 Subject: fix imports --- synapse/visibility.py | 1 + tests/test_visibility.py | 1 + 2 files changed, 2 insertions(+) (limited to 'tests') diff --git a/synapse/visibility.py b/synapse/visibility.py index 6209bc0cde..dc33f61d2b 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,6 +17,7 @@ import logging import operator import six + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 1401d831d2..8436c29fe8 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -19,6 +19,7 @@ from twisted.internet.defer import succeed from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server + import tests.unittest from tests.utils import setup_test_homeserver -- cgit 1.5.1 From d897be6a98ff2073235dc1d356a517f085dbf388 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 16 Jul 2018 15:22:27 +0100 Subject: Fix visibility of events from erased users over federation --- synapse/visibility.py | 123 ++++++++++++++++++++++++++--------------------- tests/test_visibility.py | 63 ++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 56 deletions(-) (limited to 'tests') diff --git a/synapse/visibility.py b/synapse/visibility.py index dc33f61d2b..202daa6800 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -232,7 +232,57 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, @defer.inlineCallbacks def filter_events_for_server(store, server_name, events): - # First lets check to see if all the events have a history visibility + # Whatever else we do, we need to check for senders which have requested + # erasure of their data. + erased_senders = yield store.are_users_erased( + e.sender for e in events, + ) + + def redact_disallowed(event, state): + # if the sender has been gdpr17ed, always return a redacted + # copy of the event. + if erased_senders[event.sender]: + logger.info( + "Sender of %s has been erased, redacting", + event.event_id, + ) + return prune_event(event) + + if not state: + return event + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + if visibility in ["invited", "joined"]: + # We now loop through all state events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in state.itervalues(): + if ev.type != EventTypes.Member: + continue + try: + domain = get_domain_from_id(ev.state_key) + except Exception: + continue + + if domain != server_name: + continue + + memtype = ev.membership + if memtype == Membership.JOIN: + return event + elif memtype == Membership.INVITE: + if visibility == "invited": + return event + else: + # server has no users in the room: redact + return prune_event(event) + + return event + + # Next lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( @@ -251,15 +301,24 @@ def filter_events_for_server(store, server_name, events): # If we failed to find any history visibility events then the default # is "shared" visiblity. if not visibility_ids: - defer.returnValue(events) - - event_map = yield store.get_events(visibility_ids) - all_open = all( - e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in event_map.itervalues() - ) + all_open = True + else: + event_map = yield store.get_events(visibility_ids) + all_open = all( + e.content.get("history_visibility") in (None, "shared", "world_readable") + for e in event_map.itervalues() + ) if all_open: + # all the history_visibility state affecting these events is open, so + # we don't need to filter by membership state. We *do* need to check + # for user erasure, though. + if erased_senders: + events = [ + redact_disallowed(e, None) + for e in events + ] + defer.returnValue(events) # Ok, so we're dealing with events that have non-trivial visibility @@ -314,54 +373,6 @@ def filter_events_for_server(store, server_name, events): for e_id, key_to_eid in event_to_state_ids.iteritems() } - erased_senders = yield store.are_users_erased( - e.sender for e in events, - ) - - def redact_disallowed(event, state): - # if the sender has been gdpr17ed, always return a redacted - # copy of the event. - if erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) - return prune_event(event) - - if not state: - return event - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - if visibility in ["invited", "joined"]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.itervalues(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except Exception: - continue - - if domain != server_name: - continue - - memtype = ev.membership - if memtype == Membership.JOIN: - return event - elif memtype == Membership.INVITE: - if visibility == "invited": - return event - else: - # server has no users in the room: redact - return prune_event(event) - - return event - defer.returnValue([ redact_disallowed(e, event_to_state[e.event_id]) for e in events diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8436c29fe8..0dc1a924d3 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -73,6 +73,51 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") + @tests.unittest.DEBUG + @defer.inlineCallbacks + def test_erased_user(self): + # 4 message events, from erased and unerased users, with a membership + # change in the middle of them. + events_to_filter = [] + + evt = yield self.inject_message("@unerased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@erased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_room_member("@joiner:remote_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@unerased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@erased:local_hs") + events_to_filter.append(evt) + + # the erasey user gets erased + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + + # ... and the filtering happens. + filtered = yield filter_events_for_server( + self.store, "test_server", events_to_filter, + ) + + for i in range(0, len(events_to_filter)): + self.assertEqual( + events_to_filter[i].event_id, filtered[i].event_id, + "Unexpected event at result position %i" % (i, ) + ) + + for i in (0, 3): + self.assertEqual( + events_to_filter[i].content["body"], filtered[i].content["body"], + "Unexpected event content at result position %i" % (i,) + ) + + for i in (1, 4): + self.assertNotIn("body", filtered[i].content) + @defer.inlineCallbacks def inject_visibility(self, user_id, visibility): content = {"history_visibility": visibility} @@ -109,6 +154,24 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().persist_event(event, context) defer.returnValue(event) + @defer.inlineCallbacks + def inject_message(self, user_id, content=None): + if content is None: + content = {"body": "testytest"} + builder = self.event_builder_factory.new({ + "type": "m.room.message", + "sender": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + @defer.inlineCallbacks def test_large_room(self): # see what happens when we have a large room with hundreds of thousands -- cgit 1.5.1 From 8c69b735e3186875a8bf676241f990470fb7c592 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 Jul 2018 15:33:13 +0100 Subject: Make Distributor run its processes as a background process This is more involved than it might otherwise be, because the current implementation just drops its logcontexts and runs everything in the sentinel context. It turns out that we aren't actually using a bunch of the functionality here (notably suppress_failures and the fact that Distributor.fire returns a deferred), so the easiest way to fix this is actually by simplifying a bunch of code. --- synapse/util/distributor.py | 44 +++++++++++++++-------------------- tests/test_distributor.py | 56 ++++----------------------------------------- 2 files changed, 22 insertions(+), 78 deletions(-) (limited to 'tests') diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index d91ae400eb..194da87639 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -17,20 +17,18 @@ import logging from twisted.internet import defer -from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) def user_left_room(distributor, user, room_id): - with PreserveLoggingContext(): - distributor.fire("user_left_room", user=user, room_id=room_id) + distributor.fire("user_left_room", user=user, room_id=room_id) def user_joined_room(distributor, user, room_id): - with PreserveLoggingContext(): - distributor.fire("user_joined_room", user=user, room_id=room_id) + distributor.fire("user_joined_room", user=user, room_id=room_id) class Distributor(object): @@ -44,9 +42,7 @@ class Distributor(object): model will do for today. """ - def __init__(self, suppress_failures=True): - self.suppress_failures = suppress_failures - + def __init__(self): self.signals = {} self.pre_registration = {} @@ -56,7 +52,6 @@ class Distributor(object): self.signals[name] = Signal( name, - suppress_failures=self.suppress_failures, ) if name in self.pre_registration: @@ -82,7 +77,11 @@ class Distributor(object): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - return self.signals[name].fire(*args, **kwargs) + run_as_background_process( + name, + self.signals[name].fire, + *args, **kwargs + ) class Signal(object): @@ -95,9 +94,8 @@ class Signal(object): method into all of the observers. """ - def __init__(self, name, suppress_failures): + def __init__(self, name): self.name = name - self.suppress_failures = suppress_failures self.observers = [] def observe(self, observer): @@ -107,7 +105,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -125,22 +122,17 @@ class Signal(object): failure.type, failure.value, failure.getTracebackObject())) - if not self.suppress_failures: - return failure return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - with PreserveLoggingContext(): - deferreds = [ - do(observer) - for observer in self.observers - ] - - res = yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + deferreds = [ + run_in_background(do, o) + for o in self.observers + ] - defer.returnValue(res) + return make_deferred_yieldable(defer.gatherResults( + deferreds, consumeErrors=True, + )) def __repr__(self): return "" % (self.name,) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 04a88056f1..71d11cda77 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +16,6 @@ from mock import Mock, patch -from twisted.internet import defer - from synapse.util.distributor import Distributor from . import unittest @@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase): def setUp(self): self.dist = Distributor() - @defer.inlineCallbacks def test_signal_dispatch(self): self.dist.declare("alert") observer = Mock() self.dist.observe("alert", observer) - d = self.dist.fire("alert", 1, 2, 3) - yield d - self.assertTrue(d.called) + self.dist.fire("alert", 1, 2, 3) observer.assert_called_with(1, 2, 3) - @defer.inlineCallbacks - def test_signal_dispatch_deferred(self): - self.dist.declare("whine") - - d_inner = defer.Deferred() - - def observer(): - return d_inner - - self.dist.observe("whine", observer) - - d_outer = self.dist.fire("whine") - - self.assertFalse(d_outer.called) - - d_inner.callback(None) - yield d_outer - self.assertTrue(d_outer.called) - - @defer.inlineCallbacks def test_signal_catch(self): self.dist.declare("alarm") @@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase): with patch( "synapse.util.distributor.logger", spec=["warning"] ) as mock_logger: - d = self.dist.fire("alarm", "Go") - yield d - self.assertTrue(d.called) + self.dist.fire("alarm", "Go") observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") @@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase): mock_logger.warning.call_args[0][0], str ) - @defer.inlineCallbacks - def test_signal_catch_no_suppress(self): - # Gut-wrenching - self.dist.suppress_failures = False - - self.dist.declare("whail") - - class MyException(Exception): - pass - - @defer.inlineCallbacks - def observer(): - raise MyException("Oopsie") - - self.dist.observe("whail", observer) - - d = self.dist.fire("whail") - - yield self.assertFailure(d, MyException) - self.dist.suppress_failures = True - - @defer.inlineCallbacks def test_signal_prereg(self): observer = Mock() self.dist.observe("flare", observer) self.dist.declare("flare") - yield self.dist.fire("flare", 4, 5) + self.dist.fire("flare", 4, 5) observer.assert_called_with(4, 5) -- cgit 1.5.1 From a97c845271f9a89ebdb7186d4c9d04c099bd1beb Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 19 Jul 2018 20:03:33 +1000 Subject: Move v1-only APIs into their own module & isolate deprecated ones (#3460) --- changelog.d/3460.misc | 0 setup.cfg | 1 + synapse/http/client.py | 6 +- synapse/rest/__init__.py | 43 ++- synapse/rest/client/v1/register.py | 436 ----------------------------- synapse/rest/client/v1/room.py | 5 +- synapse/rest/client/v1_only/__init__.py | 3 + synapse/rest/client/v1_only/base.py | 39 +++ synapse/rest/client/v1_only/register.py | 437 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_events.py | 90 +----- tests/rest/client/v1/test_register.py | 5 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_filter.py | 8 +- tests/rest/client/v2_alpha/test_sync.py | 8 +- 14 files changed, 549 insertions(+), 534 deletions(-) create mode 100644 changelog.d/3460.misc delete mode 100644 synapse/rest/client/v1/register.py create mode 100644 synapse/rest/client/v1_only/__init__.py create mode 100644 synapse/rest/client/v1_only/base.py create mode 100644 synapse/rest/client/v1_only/register.py (limited to 'tests') diff --git a/changelog.d/3460.misc b/changelog.d/3460.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/setup.cfg b/setup.cfg index 3f84283a38..c2620be6c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,3 +36,4 @@ known_compat = mock,six known_twisted=twisted,OpenSSL multi_line_output=3 include_trailing_comma=true +combine_as_imports=true diff --git a/synapse/http/client.py b/synapse/http/client.py index d6a0d75b2b..25b6307884 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent -from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.client import ( + Agent, + BrowserLikeRedirectAgent, + ContentDecoderAgent, + FileBodyProducer as TwistedFileBodyProducer, GzipDecoder, HTTPConnectionPool, PartialDownloadError, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 75c2a4ec8e..3418f06fd6 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import PY3 + from synapse.http.server import JsonResource from synapse.rest.client import versions -from synapse.rest.client.v1 import admin, directory, events, initial_sync -from synapse.rest.client.v1 import login as v1_login -from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher -from synapse.rest.client.v1 import register as v1_register -from synapse.rest.client.v1 import room, voip +from synapse.rest.client.v1 import ( + admin, + directory, + events, + initial_sync, + login as v1_login, + logout, + presence, + profile, + push_rule, + pusher, + room, + voip, +) from synapse.rest.client.v2_alpha import ( account, account_data, @@ -42,6 +54,11 @@ from synapse.rest.client.v2_alpha import ( user_directory, ) +if not PY3: + from synapse.rest.client.v1_only import ( + register as v1_register, + ) + class ClientRestResource(JsonResource): """A resource for version 1 of the matrix client API.""" @@ -54,14 +71,22 @@ class ClientRestResource(JsonResource): def register_servlets(client_resource, hs): versions.register_servlets(client_resource) - # "v1" - room.register_servlets(hs, client_resource) + if not PY3: + # "v1" (Python 2 only) + v1_register.register_servlets(hs, client_resource) + + # Deprecated in r0 + initial_sync.register_servlets(hs, client_resource) + room.register_deprecated_servlets(hs, client_resource) + + # Partially deprecated in r0 events.register_servlets(hs, client_resource) - v1_register.register_servlets(hs, client_resource) + + # "v1" + "r0" + room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource) - initial_sync.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py deleted file mode 100644 index 25a143af8d..0000000000 --- a/synapse/rest/client/v1/register.py +++ /dev/null @@ -1,436 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains REST servlets to do with registration: /register""" -import hmac -import logging -from hashlib import sha1 - -from twisted.internet import defer - -import synapse.util.stringutils as stringutils -from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request -from synapse.types import create_requester - -from .base import ClientV1RestServlet, client_path_patterns - -logger = logging.getLogger(__name__) - - -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - def compare_digest(a, b): - return a == b - - -class RegisterRestServlet(ClientV1RestServlet): - """Handles registration with the home server. - - This servlet is in control of the registration flow; the registration - handler doesn't have a concept of multi-stages or sessions. - """ - - PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(RegisterRestServlet, self).__init__(hs) - # sessions are stored as: - # self.sessions = { - # "session_id" : { __session_dict__ } - # } - # TODO: persistent storage - self.sessions = {} - self.enable_registration = hs.config.enable_registration - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.handlers = hs.get_handlers() - - def on_GET(self, request): - - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid - - flows = [] - if self.hs.config.enable_registration_captcha: - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([ - { - "type": LoginType.RECAPTCHA, - "stages": [ - LoginType.RECAPTCHA, - LoginType.EMAIL_IDENTITY, - LoginType.PASSWORD - ] - }, - ]) - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([ - { - "type": LoginType.RECAPTCHA, - "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] - } - ]) - else: - # only support the email-only flow if we don't require MSISDN 3PIDs - if require_email or not require_msisdn: - flows.extend([ - { - "type": LoginType.EMAIL_IDENTITY, - "stages": [ - LoginType.EMAIL_IDENTITY, LoginType.PASSWORD - ] - } - ]) - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([ - { - "type": LoginType.PASSWORD - } - ]) - return (200, {"flows": flows}) - - @defer.inlineCallbacks - def on_POST(self, request): - register_json = parse_json_object_from_request(request) - - session = (register_json["session"] - if "session" in register_json else None) - login_type = None - assert_params_in_dict(register_json, ["type"]) - - try: - login_type = register_json["type"] - - is_application_server = login_type == LoginType.APPLICATION_SERVICE - is_using_shared_secret = login_type == LoginType.SHARED_SECRET - - can_register = ( - self.enable_registration - or is_application_server - or is_using_shared_secret - ) - if not can_register: - raise SynapseError(403, "Registration has been disabled") - - stages = { - LoginType.RECAPTCHA: self._do_recaptcha, - LoginType.PASSWORD: self._do_password, - LoginType.EMAIL_IDENTITY: self._do_email_identity, - LoginType.APPLICATION_SERVICE: self._do_app_service, - LoginType.SHARED_SECRET: self._do_shared_secret, - } - - session_info = self._get_session_info(request, session) - logger.debug("%s : session info %s request info %s", - login_type, session_info, register_json) - response = yield stages[login_type]( - request, - register_json, - session_info - ) - - if "access_token" not in response: - # isn't a final response - response["session"] = session_info["id"] - - defer.returnValue((200, response)) - except KeyError as e: - logger.exception(e) - raise SynapseError(400, "Missing JSON keys for login type %s." % ( - login_type, - )) - - def on_OPTIONS(self, request): - return (200, {}) - - def _get_session_info(self, request, session_id): - if not session_id: - # create a new session - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - LoginType.EMAIL_IDENTITY: False, - LoginType.RECAPTCHA: False - } - - return self.sessions[session_id] - - def _save_session(self, session): - # TODO: Persistent storage - logger.debug("Saving session %s", session) - self.sessions[session["id"]] = session - - def _remove_session(self, session): - logger.debug("Removing session %s", session) - self.sessions.pop(session["id"]) - - @defer.inlineCallbacks - def _do_recaptcha(self, request, register_json, session): - if not self.hs.config.enable_registration_captcha: - raise SynapseError(400, "Captcha not required.") - - yield self._check_recaptcha(request, register_json, session) - - session[LoginType.RECAPTCHA] = True # mark captcha as done - self._save_session(session) - defer.returnValue({ - "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY] - }) - - @defer.inlineCallbacks - def _check_recaptcha(self, request, register_json, session): - if ("captcha_bypass_hmac" in register_json and - self.hs.config.captcha_bypass_secret): - if "user" not in register_json: - raise SynapseError(400, "Captcha bypass needs 'user'") - - want = hmac.new( - key=self.hs.config.captcha_bypass_secret, - msg=register_json["user"], - digestmod=sha1, - ).hexdigest() - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got = str(register_json["captcha_bypass_hmac"]) - - if compare_digest(want, got): - session["user"] = register_json["user"] - defer.returnValue(None) - else: - raise SynapseError( - 400, "Captcha bypass HMAC incorrect", - errcode=Codes.CAPTCHA_NEEDED - ) - - challenge = None - user_response = None - try: - challenge = register_json["challenge"] - user_response = register_json["response"] - except KeyError: - raise SynapseError(400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED) - - ip_addr = self.hs.get_ip_from_request(request) - - handler = self.handlers.registration_handler - yield handler.check_recaptcha( - ip_addr, - self.hs.config.recaptcha_private_key, - challenge, - user_response - ) - - @defer.inlineCallbacks - def _do_email_identity(self, request, register_json, session): - if (self.hs.config.enable_registration_captcha and - not session[LoginType.RECAPTCHA]): - raise SynapseError(400, "Captcha is required.") - - threepidCreds = register_json['threepidCreds'] - handler = self.handlers.registration_handler - logger.debug("Registering email. threepidcreds: %s" % (threepidCreds)) - yield handler.register_email(threepidCreds) - session["threepidCreds"] = threepidCreds # store creds for next stage - session[LoginType.EMAIL_IDENTITY] = True # mark email as done - self._save_session(session) - defer.returnValue({ - "next": LoginType.PASSWORD - }) - - @defer.inlineCallbacks - def _do_password(self, request, register_json, session): - if (self.hs.config.enable_registration_captcha and - not session[LoginType.RECAPTCHA]): - # captcha should've been done by this stage! - raise SynapseError(400, "Captcha is required.") - - if ("user" in session and "user" in register_json and - session["user"] != register_json["user"]): - raise SynapseError( - 400, "Cannot change user ID during registration" - ) - - password = register_json["password"].encode("utf-8") - desired_user_id = ( - register_json["user"].encode("utf-8") - if "user" in register_json else None - ) - - handler = self.handlers.registration_handler - (user_id, token) = yield handler.register( - localpart=desired_user_id, - password=password - ) - - if session[LoginType.EMAIL_IDENTITY]: - logger.debug("Binding emails %s to %s" % ( - session["threepidCreds"], user_id) - ) - yield handler.bind_emails(user_id, session["threepidCreds"]) - - result = { - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - } - self._remove_session(session) - defer.returnValue(result) - - @defer.inlineCallbacks - def _do_app_service(self, request, register_json, session): - as_token = self.auth.get_access_token_from_request(request) - - assert_params_in_dict(register_json, ["user"]) - user_localpart = register_json["user"].encode("utf-8") - - handler = self.handlers.registration_handler - user_id = yield handler.appservice_register( - user_localpart, as_token - ) - token = yield self.auth_handler.issue_access_token(user_id) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - - @defer.inlineCallbacks - def _do_shared_secret(self, request, register_json, session): - assert_params_in_dict(register_json, ["mac", "user", "password"]) - - if not self.hs.config.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") - - user = register_json["user"].encode("utf-8") - password = register_json["password"].encode("utf-8") - admin = register_json.get("admin", None) - - # Its important to check as we use null bytes as HMAC field separators - if b"\x00" in user: - raise SynapseError(400, "Invalid user") - if b"\x00" in password: - raise SynapseError(400, "Invalid password") - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got_mac = str(register_json["mac"]) - - want_mac = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), - digestmod=sha1, - ) - want_mac.update(user) - want_mac.update(b"\x00") - want_mac.update(password) - want_mac.update(b"\x00") - want_mac.update(b"admin" if admin else b"notadmin") - want_mac = want_mac.hexdigest() - - if compare_digest(want_mac, got_mac): - handler = self.handlers.registration_handler - user_id, token = yield handler.register( - localpart=user.lower(), - password=password, - admin=bool(admin), - ) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - else: - raise SynapseError( - 403, "HMAC incorrect", - ) - - -class CreateUserRestServlet(ClientV1RestServlet): - """Handles user creation via a server-to-server interface - """ - - PATTERNS = client_path_patterns("/createUser$", releases=()) - - def __init__(self, hs): - super(CreateUserRestServlet, self).__init__(hs) - self.store = hs.get_datastore() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_POST(self, request): - user_json = parse_json_object_from_request(request) - - access_token = self.auth.get_access_token_from_request(request) - app_service = self.store.get_app_service_by_token( - access_token - ) - if not app_service: - raise SynapseError(403, "Invalid application service token.") - - requester = create_requester(app_service.sender) - - logger.debug("creating user: %s", user_json) - response = yield self._do_create(requester, user_json) - - defer.returnValue((200, response)) - - def on_OPTIONS(self, request): - return 403, {} - - @defer.inlineCallbacks - def _do_create(self, requester, user_json): - assert_params_in_dict(user_json, ["localpart", "displayname"]) - - localpart = user_json["localpart"].encode("utf-8") - displayname = user_json["displayname"].encode("utf-8") - password_hash = user_json["password_hash"].encode("utf-8") \ - if user_json.get("password_hash") else None - - handler = self.handlers.registration_handler - user_id, token = yield handler.get_or_create_user( - requester=requester, - localpart=localpart, - displayname=displayname, - password_hash=password_hash - ) - - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - - -def register_servlets(hs, http_server): - RegisterRestServlet(hs).register(http_server) - CreateUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 3d62447854..b9512a2b61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -832,10 +832,13 @@ def register_servlets(hs, http_server): RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) - RoomInitialSyncRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + + +def register_deprecated_servlets(hs, http_server): + RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py new file mode 100644 index 0000000000..936f902ace --- /dev/null +++ b/synapse/rest/client/v1_only/__init__.py @@ -0,0 +1,3 @@ +""" +REST APIs that are only used in v1 (the legacy API). +""" diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py new file mode 100644 index 0000000000..9d4db7437c --- /dev/null +++ b/synapse/rest/client/v1_only/base.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +import re + +from synapse.api.urls import CLIENT_PREFIX + + +def v1_only_client_path_patterns(path_regex, include_in_unstable=True): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + list of SRE_Pattern + """ + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + if include_in_unstable: + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + return patterns diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py new file mode 100644 index 0000000000..3439c3c6d4 --- /dev/null +++ b/synapse/rest/client/v1_only/register.py @@ -0,0 +1,437 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains REST servlets to do with registration: /register""" +import hmac +import logging +from hashlib import sha1 + +from twisted.internet import defer + +import synapse.util.stringutils as stringutils +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.rest.client.v1.base import ClientV1RestServlet +from synapse.types import create_requester + +from .base import v1_only_client_path_patterns + +logger = logging.getLogger(__name__) + + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + def compare_digest(a, b): + return a == b + + +class RegisterRestServlet(ClientV1RestServlet): + """Handles registration with the home server. + + This servlet is in control of the registration flow; the registration + handler doesn't have a concept of multi-stages or sessions. + """ + + PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RegisterRestServlet, self).__init__(hs) + # sessions are stored as: + # self.sessions = { + # "session_id" : { __session_dict__ } + # } + # TODO: persistent storage + self.sessions = {} + self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() + + def on_GET(self, request): + + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] + if self.hs.config.enable_registration_captcha: + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([ + { + "type": LoginType.RECAPTCHA, + "stages": [ + LoginType.RECAPTCHA, + LoginType.EMAIL_IDENTITY, + LoginType.PASSWORD + ] + }, + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ + { + "type": LoginType.RECAPTCHA, + "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] + } + ]) + else: + # only support the email-only flow if we don't require MSISDN 3PIDs + if require_email or not require_msisdn: + flows.extend([ + { + "type": LoginType.EMAIL_IDENTITY, + "stages": [ + LoginType.EMAIL_IDENTITY, LoginType.PASSWORD + ] + } + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ + { + "type": LoginType.PASSWORD + } + ]) + return (200, {"flows": flows}) + + @defer.inlineCallbacks + def on_POST(self, request): + register_json = parse_json_object_from_request(request) + + session = (register_json["session"] + if "session" in register_json else None) + login_type = None + assert_params_in_dict(register_json, ["type"]) + + try: + login_type = register_json["type"] + + is_application_server = login_type == LoginType.APPLICATION_SERVICE + is_using_shared_secret = login_type == LoginType.SHARED_SECRET + + can_register = ( + self.enable_registration + or is_application_server + or is_using_shared_secret + ) + if not can_register: + raise SynapseError(403, "Registration has been disabled") + + stages = { + LoginType.RECAPTCHA: self._do_recaptcha, + LoginType.PASSWORD: self._do_password, + LoginType.EMAIL_IDENTITY: self._do_email_identity, + LoginType.APPLICATION_SERVICE: self._do_app_service, + LoginType.SHARED_SECRET: self._do_shared_secret, + } + + session_info = self._get_session_info(request, session) + logger.debug("%s : session info %s request info %s", + login_type, session_info, register_json) + response = yield stages[login_type]( + request, + register_json, + session_info + ) + + if "access_token" not in response: + # isn't a final response + response["session"] = session_info["id"] + + defer.returnValue((200, response)) + except KeyError as e: + logger.exception(e) + raise SynapseError(400, "Missing JSON keys for login type %s." % ( + login_type, + )) + + def on_OPTIONS(self, request): + return (200, {}) + + def _get_session_info(self, request, session_id): + if not session_id: + # create a new session + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + self.sessions[session_id] = { + "id": session_id, + LoginType.EMAIL_IDENTITY: False, + LoginType.RECAPTCHA: False + } + + return self.sessions[session_id] + + def _save_session(self, session): + # TODO: Persistent storage + logger.debug("Saving session %s", session) + self.sessions[session["id"]] = session + + def _remove_session(self, session): + logger.debug("Removing session %s", session) + self.sessions.pop(session["id"]) + + @defer.inlineCallbacks + def _do_recaptcha(self, request, register_json, session): + if not self.hs.config.enable_registration_captcha: + raise SynapseError(400, "Captcha not required.") + + yield self._check_recaptcha(request, register_json, session) + + session[LoginType.RECAPTCHA] = True # mark captcha as done + self._save_session(session) + defer.returnValue({ + "next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY] + }) + + @defer.inlineCallbacks + def _check_recaptcha(self, request, register_json, session): + if ("captcha_bypass_hmac" in register_json and + self.hs.config.captcha_bypass_secret): + if "user" not in register_json: + raise SynapseError(400, "Captcha bypass needs 'user'") + + want = hmac.new( + key=self.hs.config.captcha_bypass_secret, + msg=register_json["user"], + digestmod=sha1, + ).hexdigest() + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got = str(register_json["captcha_bypass_hmac"]) + + if compare_digest(want, got): + session["user"] = register_json["user"] + defer.returnValue(None) + else: + raise SynapseError( + 400, "Captcha bypass HMAC incorrect", + errcode=Codes.CAPTCHA_NEEDED + ) + + challenge = None + user_response = None + try: + challenge = register_json["challenge"] + user_response = register_json["response"] + except KeyError: + raise SynapseError(400, "Captcha response is required", + errcode=Codes.CAPTCHA_NEEDED) + + ip_addr = self.hs.get_ip_from_request(request) + + handler = self.handlers.registration_handler + yield handler.check_recaptcha( + ip_addr, + self.hs.config.recaptcha_private_key, + challenge, + user_response + ) + + @defer.inlineCallbacks + def _do_email_identity(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + raise SynapseError(400, "Captcha is required.") + + threepidCreds = register_json['threepidCreds'] + handler = self.handlers.registration_handler + logger.debug("Registering email. threepidcreds: %s" % (threepidCreds)) + yield handler.register_email(threepidCreds) + session["threepidCreds"] = threepidCreds # store creds for next stage + session[LoginType.EMAIL_IDENTITY] = True # mark email as done + self._save_session(session) + defer.returnValue({ + "next": LoginType.PASSWORD + }) + + @defer.inlineCallbacks + def _do_password(self, request, register_json, session): + if (self.hs.config.enable_registration_captcha and + not session[LoginType.RECAPTCHA]): + # captcha should've been done by this stage! + raise SynapseError(400, "Captcha is required.") + + if ("user" in session and "user" in register_json and + session["user"] != register_json["user"]): + raise SynapseError( + 400, "Cannot change user ID during registration" + ) + + password = register_json["password"].encode("utf-8") + desired_user_id = ( + register_json["user"].encode("utf-8") + if "user" in register_json else None + ) + + handler = self.handlers.registration_handler + (user_id, token) = yield handler.register( + localpart=desired_user_id, + password=password + ) + + if session[LoginType.EMAIL_IDENTITY]: + logger.debug("Binding emails %s to %s" % ( + session["threepidCreds"], user_id) + ) + yield handler.bind_emails(user_id, session["threepidCreds"]) + + result = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + } + self._remove_session(session) + defer.returnValue(result) + + @defer.inlineCallbacks + def _do_app_service(self, request, register_json, session): + as_token = self.auth.get_access_token_from_request(request) + + assert_params_in_dict(register_json, ["user"]) + user_localpart = register_json["user"].encode("utf-8") + + handler = self.handlers.registration_handler + user_id = yield handler.appservice_register( + user_localpart, as_token + ) + token = yield self.auth_handler.issue_access_token(user_id) + self._remove_session(session) + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + @defer.inlineCallbacks + def _do_shared_secret(self, request, register_json, session): + assert_params_in_dict(register_json, ["mac", "user", "password"]) + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) + + # Its important to check as we use null bytes as HMAC field separators + if b"\x00" in user: + raise SynapseError(400, "Invalid user") + if b"\x00" in password: + raise SynapseError(400, "Invalid password") + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got_mac = str(register_json["mac"]) + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=sha1, + ) + want_mac.update(user) + want_mac.update(b"\x00") + want_mac.update(password) + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") + want_mac = want_mac.hexdigest() + + if compare_digest(want_mac, got_mac): + handler = self.handlers.registration_handler + user_id, token = yield handler.register( + localpart=user.lower(), + password=password, + admin=bool(admin), + ) + self._remove_session(session) + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + else: + raise SynapseError( + 403, "HMAC incorrect", + ) + + +class CreateUserRestServlet(ClientV1RestServlet): + """Handles user creation via a server-to-server interface + """ + + PATTERNS = v1_only_client_path_patterns("/createUser$") + + def __init__(self, hs): + super(CreateUserRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_POST(self, request): + user_json = parse_json_object_from_request(request) + + access_token = self.auth.get_access_token_from_request(request) + app_service = self.store.get_app_service_by_token( + access_token + ) + if not app_service: + raise SynapseError(403, "Invalid application service token.") + + requester = create_requester(app_service.sender) + + logger.debug("creating user: %s", user_json) + response = yield self._do_create(requester, user_json) + + defer.returnValue((200, response)) + + def on_OPTIONS(self, request): + return 403, {} + + @defer.inlineCallbacks + def _do_create(self, requester, user_json): + assert_params_in_dict(user_json, ["localpart", "displayname"]) + + localpart = user_json["localpart"].encode("utf-8") + displayname = user_json["displayname"].encode("utf-8") + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None + + handler = self.handlers.registration_handler + user_id, token = yield handler.get_or_create_user( + requester=requester, + localpart=localpart, + displayname=displayname, + password_hash=password_hash + ) + + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) + CreateUserRestServlet(hs).register(http_server) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index a5af36a99c..50418153fa 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -14,100 +14,30 @@ # limitations under the License. """ Tests REST events for /events paths.""" + from mock import Mock, NonCallableMock +from six import PY3 -# twisted imports from twisted.internet import defer -import synapse.rest.client.v1.events -import synapse.rest.client.v1.register -import synapse.rest.client.v1.room - -from tests import unittest - from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase PATH_PREFIX = "/_matrix/client/api/v1" -class EventStreamPaginationApiTestCase(unittest.TestCase): - """ Tests event streaming query parameters and start/end keys used in the - Pagination stream API. """ - user_id = "sid1" - - def setUp(self): - # configure stream and inject items - pass - - def tearDown(self): - pass - - def TODO_test_long_poll(self): - # stream from 'end' key, send (self+other) message, expect message. - - # stream from 'END', send (self+other) message, expect message. - - # stream from 'end' key, send (self+other) topic, expect topic. - - # stream from 'END', send (self+other) topic, expect topic. - - # stream from 'end' key, send (self+other) invite, expect invite. - - # stream from 'END', send (self+other) invite, expect invite. - - pass - - def TODO_test_stream_forward(self): - # stream from START, expect injected items - - # stream from 'start' key, expect same content - - # stream from 'end' key, expect nothing - - # stream from 'END', expect nothing - - # The following is needed for cases where content is removed e.g. you - # left a room, so the token you're streaming from is > the one that - # would be returned naturally from START>END. - # stream from very new token (higher than end key), expect same token - # returned as end key - pass - - def TODO_test_limits(self): - # stream from a key, expect limit_num items - - # stream from START, expect limit_num items - - pass - - def TODO_test_range(self): - # stream from key to key, expect X items - - # stream from key to END, expect X items - - # stream from START to key, expect X items - - # stream from START to END, expect all items - pass - - def TODO_test_direction(self): - # stream from END to START and fwds, expect newest first - - # stream from END to START and bwds, expect oldest first - - # stream from START to END and fwds, expect oldest first - - # stream from START to END and bwds, expect newest first - - pass - - class EventStreamPermissionsTestCase(RestTestCase): """ Tests event streaming (GET /events). """ + if PY3: + skip = "Skip on Py3 until ported to use not V1 only register." + @defer.inlineCallbacks def setUp(self): + import synapse.rest.client.v1.events + import synapse.rest.client.v1_only.register + import synapse.rest.client.v1.room + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( @@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource) + synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index f15fb36213..83a23cd8fe 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -16,11 +16,12 @@ import json from mock import Mock +from six import PY3 from twisted.test.proto_helpers import MemoryReactorClock from synapse.http.server import JsonResource -from synapse.rest.client.v1.register import register_servlets +from synapse.rest.client.v1_only.register import register_servlets from synapse.util import Clock from tests import unittest @@ -31,6 +32,8 @@ class CreateUserServletTestCase(unittest.TestCase): """ Tests for CreateUserRestServlet. """ + if PY3: + skip = "Not ported to Python 3." def setUp(self): self.registration_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6b5764095e..00fc796787 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,6 @@ import json from mock import Mock, NonCallableMock from six.moves.urllib import parse as urlparse -# twisted imports from twisted.internet import defer import synapse.rest.client.v1.room @@ -86,6 +85,7 @@ class RoomBase(unittest.TestCase): self.resource = JsonResource(self.hs) synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) + synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) self.helper = RestHelper(self.hs, self.resource, self.user_id) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 5ea9cc825f..e890f0feac 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -21,8 +21,12 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest -from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock -from tests.server import make_request, setup_test_homeserver, wait_until_result +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 704cf97a40..03ec3993b2 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -20,8 +20,12 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest -from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock -from tests.server import make_request, setup_test_homeserver, wait_until_result +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) PATH_PREFIX = "/_matrix/client/v2_alpha" -- cgit 1.5.1