diff options
-rw-r--r-- | changelog.d/7420.misc | 1 | ||||
-rw-r--r-- | changelog.d/7423.misc | 1 | ||||
-rw-r--r-- | changelog.d/7427.feature | 1 | ||||
-rw-r--r-- | synapse/api/auth.py | 83 | ||||
-rw-r--r-- | synapse/api/auth_blocking.py | 104 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 12 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 51 | ||||
-rw-r--r-- | synapse/replication/tcp/redis.py | 52 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/devices.py | 4 | ||||
-rw-r--r-- | synapse/util/caches/stream_change_cache.py | 19 | ||||
-rw-r--r-- | tests/api/test_auth.py | 36 | ||||
-rw-r--r-- | tests/handlers/test_auth.py | 23 | ||||
-rw-r--r-- | tests/handlers/test_sync.py | 13 | ||||
-rw-r--r-- | tests/test_mau.py | 14 |
14 files changed, 264 insertions, 150 deletions
diff --git a/changelog.d/7420.misc b/changelog.d/7420.misc new file mode 100644 index 0000000000..e834a9163e --- /dev/null +++ b/changelog.d/7420.misc @@ -0,0 +1 @@ +Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request. \ No newline at end of file diff --git a/changelog.d/7423.misc b/changelog.d/7423.misc new file mode 100644 index 0000000000..eb1767ac13 --- /dev/null +++ b/changelog.d/7423.misc @@ -0,0 +1 @@ +Speed up fetching device lists changes when handling `/sync` requests. diff --git a/changelog.d/7427.feature b/changelog.d/7427.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7427.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c5d1eb952b..1ad5ff9410 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,16 +26,15 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes +from synapse.api.auth_blocking import AuthBlocking +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, InvalidClientTokenError, MissingClientTokenError, - ResourceLimitError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.config.server import is_threepid_reserved from synapse.events import EventBase from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -77,7 +76,11 @@ class Auth(object): self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) + self._auth_blocking = AuthBlocking(self.hs) + self._account_validity = hs.config.account_validity + self._track_appservice_user_ips = hs.config.track_appservice_user_ips + self._macaroon_secret_key = hs.config.macaroon_secret_key @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): @@ -191,7 +194,7 @@ class Auth(object): opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self.hs.config.track_appservice_user_ips: + if ip_addr and self._track_appservice_user_ips: yield self.store.insert_client_ip( user_id=user_id, access_token=access_token, @@ -454,7 +457,7 @@ class Auth(object): # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self._macaroon_secret_key) def _verify_expiry(self, caveat): prefix = "time < " @@ -663,71 +666,5 @@ class Auth(object): % (user_id, room_id), ) - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): - """Checks if the user should be rejected for some external reason, - such as monthly active user limiting or global disable flag - - Args: - user_id(str|None): If present, checks for presence against existing - MAU cohort - - threepid(dict|None): If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. - - user_type(str|None): If present, is used to decide whether to check against - certain blocking reasons like MAU. - """ - - # Never fail an auth check for the server notices users or support user - # This can be a problem where event creation is prohibited due to blocking - if user_id is not None: - if user_id == self.hs.config.server_notices_mxid: - return - if (yield self.store.is_support_user(user_id)): - return - - if self.hs.config.hs_disabled: - raise ResourceLimitError( - 403, - self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=self.hs.config.admin_contact, - limit_type=LimitBlockingTypes.HS_DISABLED, - ) - if self.hs.config.limit_usage_by_mau is True: - assert not (user_id and threepid) - - # If the user is already part of the MAU cohort or a trial user - if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) - if timestamp: - return - - is_trial = yield self.store.is_trial_user(user_id) - if is_trial: - return - elif threepid: - # If the user does not exist yet, but is signing up with a - # reserved threepid then pass auth check - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - return - elif user_type == UserTypes.SUPPORT: - # If the user does not exist yet and is of type "support", - # allow registration. Support users are excluded from MAU checks. - return - # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() - if current_mau >= self.hs.config.max_mau_value: - raise ResourceLimitError( - 403, - "Monthly Active User Limit Exceeded", - admin_contact=self.hs.config.admin_contact, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, - ) + def check_auth_blocking(self, *args, **kwargs): + return self._auth_blocking.check_auth_blocking(*args, **kwargs) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py new file mode 100644 index 0000000000..5c499b6b4e --- /dev/null +++ b/synapse/api/auth_blocking.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import LimitBlockingTypes, UserTypes +from synapse.api.errors import Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved + +logger = logging.getLogger(__name__) + + +class AuthBlocking(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + self._server_notices_mxid = hs.config.server_notices_mxid + self._hs_disabled = hs.config.hs_disabled + self._hs_disabled_message = hs.config.hs_disabled_message + self._admin_contact = hs.config.admin_contact + self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + + @defer.inlineCallbacks + def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + + Args: + user_id(str|None): If present, checks for presence against existing + MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. + + user_type(str|None): If present, is used to decide whether to check against + certain blocking reasons like MAU. + """ + + # Never fail an auth check for the server notices users or support user + # This can be a problem where event creation is prohibited due to blocking + if user_id is not None: + if user_id == self._server_notices_mxid: + return + if (yield self.store.is_support_user(user_id)): + return + + if self._hs_disabled: + raise ResourceLimitError( + 403, + self._hs_disabled_message, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self._admin_contact, + limit_type=LimitBlockingTypes.HS_DISABLED, + ) + if self._limit_usage_by_mau is True: + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user + if user_id: + timestamp = yield self.store.user_last_seen_monthly_active(user_id) + if timestamp: + return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid): + return + elif user_type == UserTypes.SUPPORT: + # If the user does not exist yet and is of type "support", + # allow registration. Support users are excluded from MAU checks. + return + # Else if there is no room in the MAU bucket, bail + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self._max_mau_value: + raise ResourceLimitError( + 403, + "Monthly Active User Limit Exceeded", + admin_contact=self._admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4f76b7a743..00718d7f2d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1143,10 +1143,14 @@ class SyncHandler(object): user_id ) - tracked_users = set(users_who_share_room) - - # Always tell the user about their own devices - tracked_users.add(user_id) + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index bf4f1a5949..b14a3d9fca 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -81,9 +81,6 @@ class ReplicationCommandHandler: self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() - # Set of streams that we've caught up with. - self._streams_connected = set() # type: Set[str] - self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -99,9 +96,13 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReconnectingClientFactory] - # The currently connected connections. + # The currently connected connections. (The list of places we need to send + # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] + # For each connection, the incoming streams that are coming from that connection + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -257,9 +258,11 @@ class ReplicationCommandHandler: # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. with await self._position_linearizer.queue(cmd.stream_name): - if stream_name not in self._streams_connected: - # If the stream isn't marked as connected then we haven't seen a - # `POSITION` command yet, and so we may have missed some rows. + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: # Let's drop the row for now, on the assumption we'll receive a # `POSITION` soon and we'll catch up correctly then. logger.debug( @@ -302,21 +305,25 @@ class ReplicationCommandHandler: # Ignore POSITION that are just our own echoes return - stream = self._streams.get(cmd.stream_name) + logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) + + stream_name = cmd.stream_name + stream = self._streams.get(stream_name) if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + logger.error("Got POSITION for unknown stream: %s", stream_name) return # We protect catching up with a linearizer in case the replication # connection reconnects under us. - with await self._position_linearizer.queue(cmd.stream_name): + with await self._position_linearizer.queue(stream_name): # We're about to go and catch up with the stream, so remove from set # of connected streams. - self._streams_connected.discard(cmd.stream_name) + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) # We clear the pending batches for the stream as the fetching of the # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(cmd.stream_name, []) + self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. current_token = stream.current_token() @@ -326,6 +333,12 @@ class ReplicationCommandHandler: # between then and now. missing_updates = cmd.token != current_token while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) ( updates, current_token, @@ -341,16 +354,18 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, + stream_name, cmd.instance_name, token, [stream.parse_row(row) for row in rows], ) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + await self._replication_data_handler.on_position(stream_name, cmd.token) - self._streams_connected.add(cmd.stream_name) + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand @@ -408,6 +423,12 @@ class ReplicationCommandHandler: def lost_connection(self, connection: AbstractConnection): """Called when a connection is closed/lost. """ + # we no longer need _streams_by_connection for this connection. + streams = self._streams_by_connection.pop(connection, None) + if streams: + logger.info( + "Lost replication connection; streams now disconnected: %s", streams + ) try: self._connections.remove(connection) except ValueError: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 41c623d737..db69f92557 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import PreserveLoggingContext +from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( Command, @@ -41,8 +41,14 @@ logger = logging.getLogger(__name__) class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. - Parses incoming messages from redis into replication commands, and passes - them to `ReplicationCommandHandler` + This class fulfils two functions: + + (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis + connection, parsing *incoming* messages into replication commands, and passing them + to `ReplicationCommandHandler` + + (b) it implements the AbstractConnection API, where it sends *outgoing* commands + onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom constructor, so instead we expect the defined attributes below to be set @@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Attributes: handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to (not anything to - do with Synapse replication streams). + stream_name: The *redis* stream name to subscribe to and publish from + (not anything to do with Synapse replication streams). outbound_redis_connection: The connection to redis to use to send commands. """ @@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): + logger.info("Connected to redis") super().connectionMade() - logger.info("Connected to redis instance") - self.subscribe(self.stream_name) - self.send_command(ReplicateCommand()) - + run_as_background_process("subscribe-replication", self._send_subscribe) self.handler.new_connection(self) + async def _send_subscribe(self): + # it's important to make sure that we only send the REPLICATE command once we + # have successfully subscribed to the stream - otherwise we might miss the + # POSITION response sent back by the other end. + logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) + await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info( + "Successfully subscribed to redis stream, sending REPLICATE command" + ) + await self._async_send_command(ReplicateCommand()) + logger.info("REPLICATE successfully sent") + def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ @@ -120,8 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): + logger.info("Lost connection to redis") super().connectionLost(reason) - logger.info("Lost connection to redis instance") self.handler.lost_connection(self) def send_command(self, cmd: Command): @@ -130,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ + run_as_background_process("send-cmd", self._async_send_command, cmd) + + async def _async_send_command(self, cmd: Command): + """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -140,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() - async def _send(): - with PreserveLoggingContext(): - # Note that we use the other connection as we can't send - # commands using the subscription connection. - await self.outbound_redis_connection.publish( - self.stream_name, encoded_string - ) - - run_as_background_process("send-cmd", _send) + await make_deferred_yieldable( + self.outbound_redis_connection.publish(self.stream_name, encoded_string) + ) class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 536cef3abd..fe6d6ecfe0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -536,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore): # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = list( - self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key ) if not to_check: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 38dc3f501e..e54f80d76e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,12 +14,13 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Mapping, Optional, Set +from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types from sortedcontainers import SortedDict +from synapse.types import Collection from synapse.util import caches logger = logging.getLogger(__name__) @@ -85,8 +86,8 @@ class StreamChangeCache: return False def get_entities_changed( - self, entities: Iterable[EntityType], stream_pos: int - ) -> Set[EntityType]: + self, entities: Collection[EntityType], stream_pos: int + ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the @@ -94,7 +95,17 @@ class StreamChangeCache: """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: - result = set(changed_entities).intersection(entities) + # We now do an intersection, trying to do so in the most efficient + # way possible (some of these sets are *large*). First check in the + # given iterable is already set that we can reuse, otherwise we + # create a set of the *smallest* of the two iterables and call + # `intersection(..)` on it (this can be twice as fast as the reverse). + if isinstance(entities, (set, frozenset)): + result = entities.intersection(changed_entities) + elif len(changed_entities) < len(entities): + result = set(changed_entities).intersection(entities) + else: + result = set(entities).intersection(changed_entities) self.metrics.inc_hits() else: result = set(entities) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cc0b10e7f6..0bfb86bf1f 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.auth._auth_blocking + self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 50 lots_of_users = 100 small_number_of_users = 1 # Ensure no error thrown yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) @@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): - self.hs.config.max_mau_value = 50 - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed @@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_reserved_threepid(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} - self.hs.config.mau_limits_reserved_threepids = [threepid] + self.auth_blocking._mau_limits_reserved_threepids = [threepid] with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred(self.auth.check_auth_blocking()) @@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_hs_disabled(self): - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) @@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase): """ # this should be the default, but we had a bug where the test was doing the wrong # thing, so let's make it explicit - self.hs.config.server_notices_mxid = None + self.auth_blocking._server_notices_mxid = None - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) @@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True user = "@user:server" - self.hs.config.server_notices_mxid = user - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._server_notices_mxid = user + self.auth_blocking._hs_disabled_message = "Reason for being disabled" yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 52c4ac8b11..c01b04e1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests - self.hs.config.max_mau_value = 50 + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking._max_mau_value = 50 + self.small_number_of_users = 1 self.large_number_of_users = 100 @@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) @@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_parity(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 4cbe9784ed..e178d7765b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - def test_wait_for_sync_for_user_auth_blocking(self): + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) self.reactor.advance(100) # So we get not 0 time - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True e = self.get_failure( self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.hs.config.hs_disabled = False + self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) diff --git a/tests/test_mau.py b/tests/test_mau.py index 1fbe0d51ff..eb159e3ba5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from mock import Mock +from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync @@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.hs_disabled = False self.hs.config.max_mau_value = 2 - self.hs.config.mau_trial_days = 0 self.hs.config.server_notices_mxid = "@server:red" self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_room_name = "Test Server Notice Room" + self.hs.config.mau_trial_days = 0 + + # AuthBlocking reads config options during hs creation. Recreate the + # hs' copy of AuthBlocking after we've updated config values above + self.auth_blocking = AuthBlocking(self.hs) + self.hs.get_auth()._auth_blocking = self.auth_blocking + return self.hs def test_simple_deny_mau(self): @@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_trial_users_cant_come_back(self): + self.auth_blocking._mau_trial_days = 1 self.hs.config.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_tracked_but_not_limited(self): - self.hs.config.max_mau_value = 1 # should not matter - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 1 # should not matter + self.auth_blocking._limit_usage_by_mau = False self.hs.config.mau_stats_only = True # Simply being able to create 2 users indicates that the |