diff options
author | Richard van der Hoff <richard@matrix.org> | 2020-05-06 15:56:03 +0100 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2020-05-06 15:56:03 +0100 |
commit | 62ee86211965e366dc285a72e3dc3f6d87a76b85 (patch) | |
tree | aa3ae2f0d827cdd182f38e3388b3879c35917139 /synapse | |
parent | Merge pull request #7428 from matrix-org/rav/cross_signing_keys_cache (diff) | |
parent | Stop Auth methods from polling the config on every req. (#7420) (diff) | |
download | synapse-62ee86211965e366dc285a72e3dc3f6d87a76b85.tar.xz |
Merge branch 'release-v1.13.0' into develop
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/auth.py | 83 | ||||
-rw-r--r-- | synapse/api/auth_blocking.py | 104 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 12 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 51 | ||||
-rw-r--r-- | synapse/replication/tcp/redis.py | 52 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/devices.py | 4 | ||||
-rw-r--r-- | synapse/util/caches/stream_change_cache.py | 19 |
7 files changed, 208 insertions, 117 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c5d1eb952b..1ad5ff9410 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,16 +26,15 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes +from synapse.api.auth_blocking import AuthBlocking +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, InvalidClientTokenError, MissingClientTokenError, - ResourceLimitError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.config.server import is_threepid_reserved from synapse.events import EventBase from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -77,7 +76,11 @@ class Auth(object): self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) + self._auth_blocking = AuthBlocking(self.hs) + self._account_validity = hs.config.account_validity + self._track_appservice_user_ips = hs.config.track_appservice_user_ips + self._macaroon_secret_key = hs.config.macaroon_secret_key @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): @@ -191,7 +194,7 @@ class Auth(object): opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self.hs.config.track_appservice_user_ips: + if ip_addr and self._track_appservice_user_ips: yield self.store.insert_client_ip( user_id=user_id, access_token=access_token, @@ -454,7 +457,7 @@ class Auth(object): # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self._macaroon_secret_key) def _verify_expiry(self, caveat): prefix = "time < " @@ -663,71 +666,5 @@ class Auth(object): % (user_id, room_id), ) - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): - """Checks if the user should be rejected for some external reason, - such as monthly active user limiting or global disable flag - - Args: - user_id(str|None): If present, checks for presence against existing - MAU cohort - - threepid(dict|None): If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. - - user_type(str|None): If present, is used to decide whether to check against - certain blocking reasons like MAU. - """ - - # Never fail an auth check for the server notices users or support user - # This can be a problem where event creation is prohibited due to blocking - if user_id is not None: - if user_id == self.hs.config.server_notices_mxid: - return - if (yield self.store.is_support_user(user_id)): - return - - if self.hs.config.hs_disabled: - raise ResourceLimitError( - 403, - self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=self.hs.config.admin_contact, - limit_type=LimitBlockingTypes.HS_DISABLED, - ) - if self.hs.config.limit_usage_by_mau is True: - assert not (user_id and threepid) - - # If the user is already part of the MAU cohort or a trial user - if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) - if timestamp: - return - - is_trial = yield self.store.is_trial_user(user_id) - if is_trial: - return - elif threepid: - # If the user does not exist yet, but is signing up with a - # reserved threepid then pass auth check - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - return - elif user_type == UserTypes.SUPPORT: - # If the user does not exist yet and is of type "support", - # allow registration. Support users are excluded from MAU checks. - return - # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() - if current_mau >= self.hs.config.max_mau_value: - raise ResourceLimitError( - 403, - "Monthly Active User Limit Exceeded", - admin_contact=self.hs.config.admin_contact, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, - ) + def check_auth_blocking(self, *args, **kwargs): + return self._auth_blocking.check_auth_blocking(*args, **kwargs) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py new file mode 100644 index 0000000000..5c499b6b4e --- /dev/null +++ b/synapse/api/auth_blocking.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import LimitBlockingTypes, UserTypes +from synapse.api.errors import Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved + +logger = logging.getLogger(__name__) + + +class AuthBlocking(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + self._server_notices_mxid = hs.config.server_notices_mxid + self._hs_disabled = hs.config.hs_disabled + self._hs_disabled_message = hs.config.hs_disabled_message + self._admin_contact = hs.config.admin_contact + self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + + @defer.inlineCallbacks + def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + + Args: + user_id(str|None): If present, checks for presence against existing + MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. + + user_type(str|None): If present, is used to decide whether to check against + certain blocking reasons like MAU. + """ + + # Never fail an auth check for the server notices users or support user + # This can be a problem where event creation is prohibited due to blocking + if user_id is not None: + if user_id == self._server_notices_mxid: + return + if (yield self.store.is_support_user(user_id)): + return + + if self._hs_disabled: + raise ResourceLimitError( + 403, + self._hs_disabled_message, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self._admin_contact, + limit_type=LimitBlockingTypes.HS_DISABLED, + ) + if self._limit_usage_by_mau is True: + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user + if user_id: + timestamp = yield self.store.user_last_seen_monthly_active(user_id) + if timestamp: + return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid): + return + elif user_type == UserTypes.SUPPORT: + # If the user does not exist yet and is of type "support", + # allow registration. Support users are excluded from MAU checks. + return + # Else if there is no room in the MAU bucket, bail + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self._max_mau_value: + raise ResourceLimitError( + 403, + "Monthly Active User Limit Exceeded", + admin_contact=self._admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4f76b7a743..00718d7f2d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1143,10 +1143,14 @@ class SyncHandler(object): user_id ) - tracked_users = set(users_who_share_room) - - # Always tell the user about their own devices - tracked_users.add(user_id) + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index bf4f1a5949..b14a3d9fca 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -81,9 +81,6 @@ class ReplicationCommandHandler: self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() - # Set of streams that we've caught up with. - self._streams_connected = set() # type: Set[str] - self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -99,9 +96,13 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReconnectingClientFactory] - # The currently connected connections. + # The currently connected connections. (The list of places we need to send + # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] + # For each connection, the incoming streams that are coming from that connection + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -257,9 +258,11 @@ class ReplicationCommandHandler: # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. with await self._position_linearizer.queue(cmd.stream_name): - if stream_name not in self._streams_connected: - # If the stream isn't marked as connected then we haven't seen a - # `POSITION` command yet, and so we may have missed some rows. + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: # Let's drop the row for now, on the assumption we'll receive a # `POSITION` soon and we'll catch up correctly then. logger.debug( @@ -302,21 +305,25 @@ class ReplicationCommandHandler: # Ignore POSITION that are just our own echoes return - stream = self._streams.get(cmd.stream_name) + logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) + + stream_name = cmd.stream_name + stream = self._streams.get(stream_name) if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + logger.error("Got POSITION for unknown stream: %s", stream_name) return # We protect catching up with a linearizer in case the replication # connection reconnects under us. - with await self._position_linearizer.queue(cmd.stream_name): + with await self._position_linearizer.queue(stream_name): # We're about to go and catch up with the stream, so remove from set # of connected streams. - self._streams_connected.discard(cmd.stream_name) + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) # We clear the pending batches for the stream as the fetching of the # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(cmd.stream_name, []) + self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. current_token = stream.current_token() @@ -326,6 +333,12 @@ class ReplicationCommandHandler: # between then and now. missing_updates = cmd.token != current_token while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) ( updates, current_token, @@ -341,16 +354,18 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, + stream_name, cmd.instance_name, token, [stream.parse_row(row) for row in rows], ) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + await self._replication_data_handler.on_position(stream_name, cmd.token) - self._streams_connected.add(cmd.stream_name) + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand @@ -408,6 +423,12 @@ class ReplicationCommandHandler: def lost_connection(self, connection: AbstractConnection): """Called when a connection is closed/lost. """ + # we no longer need _streams_by_connection for this connection. + streams = self._streams_by_connection.pop(connection, None) + if streams: + logger.info( + "Lost replication connection; streams now disconnected: %s", streams + ) try: self._connections.remove(connection) except ValueError: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 41c623d737..db69f92557 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import PreserveLoggingContext +from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( Command, @@ -41,8 +41,14 @@ logger = logging.getLogger(__name__) class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. - Parses incoming messages from redis into replication commands, and passes - them to `ReplicationCommandHandler` + This class fulfils two functions: + + (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis + connection, parsing *incoming* messages into replication commands, and passing them + to `ReplicationCommandHandler` + + (b) it implements the AbstractConnection API, where it sends *outgoing* commands + onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom constructor, so instead we expect the defined attributes below to be set @@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Attributes: handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to (not anything to - do with Synapse replication streams). + stream_name: The *redis* stream name to subscribe to and publish from + (not anything to do with Synapse replication streams). outbound_redis_connection: The connection to redis to use to send commands. """ @@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): + logger.info("Connected to redis") super().connectionMade() - logger.info("Connected to redis instance") - self.subscribe(self.stream_name) - self.send_command(ReplicateCommand()) - + run_as_background_process("subscribe-replication", self._send_subscribe) self.handler.new_connection(self) + async def _send_subscribe(self): + # it's important to make sure that we only send the REPLICATE command once we + # have successfully subscribed to the stream - otherwise we might miss the + # POSITION response sent back by the other end. + logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) + await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info( + "Successfully subscribed to redis stream, sending REPLICATE command" + ) + await self._async_send_command(ReplicateCommand()) + logger.info("REPLICATE successfully sent") + def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ @@ -120,8 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): + logger.info("Lost connection to redis") super().connectionLost(reason) - logger.info("Lost connection to redis instance") self.handler.lost_connection(self) def send_command(self, cmd: Command): @@ -130,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ + run_as_background_process("send-cmd", self._async_send_command, cmd) + + async def _async_send_command(self, cmd: Command): + """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -140,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() - async def _send(): - with PreserveLoggingContext(): - # Note that we use the other connection as we can't send - # commands using the subscription connection. - await self.outbound_redis_connection.publish( - self.stream_name, encoded_string - ) - - run_as_background_process("send-cmd", _send) + await make_deferred_yieldable( + self.outbound_redis_connection.publish(self.stream_name, encoded_string) + ) class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 536cef3abd..fe6d6ecfe0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -536,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore): # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = list( - self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key ) if not to_check: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 38dc3f501e..e54f80d76e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,12 +14,13 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Mapping, Optional, Set +from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types from sortedcontainers import SortedDict +from synapse.types import Collection from synapse.util import caches logger = logging.getLogger(__name__) @@ -85,8 +86,8 @@ class StreamChangeCache: return False def get_entities_changed( - self, entities: Iterable[EntityType], stream_pos: int - ) -> Set[EntityType]: + self, entities: Collection[EntityType], stream_pos: int + ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the @@ -94,7 +95,17 @@ class StreamChangeCache: """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: - result = set(changed_entities).intersection(entities) + # We now do an intersection, trying to do so in the most efficient + # way possible (some of these sets are *large*). First check in the + # given iterable is already set that we can reuse, otherwise we + # create a set of the *smallest* of the two iterables and call + # `intersection(..)` on it (this can be twice as fast as the reverse). + if isinstance(entities, (set, frozenset)): + result = entities.intersection(changed_entities) + elif len(changed_entities) < len(entities): + result = set(changed_entities).intersection(entities) + else: + result = set(entities).intersection(changed_entities) self.metrics.inc_hits() else: result = set(entities) |