diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/handlers/room_member.py | 287 | ||||
-rw-r--r-- | synapse/replication/slave/storage/profile.py | 21 | ||||
-rw-r--r-- | synapse/server.py | 6 | ||||
-rw-r--r-- | synapse/storage/profile.py | 50 | ||||
-rw-r--r-- | synapse/storage/room.py | 30 | ||||
-rw-r--r-- | synapse/storage/state.py | 34 |
6 files changed, 282 insertions, 146 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0127cf4166..9977be8831 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import abc import logging from signedjson.key import decode_verify_key_bytes @@ -31,6 +31,7 @@ from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room + logger = logging.getLogger(__name__) id_server_scheme = "https://" @@ -42,6 +43,8 @@ class RoomMemberHandler(object): # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. + __metaclass__ = abc.ABCMeta + def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() @@ -61,9 +64,87 @@ class RoomMemberHandler(object): self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") + @abc.abstractmethod + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Try and join a room that this server is not in + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers that can be used + to join via. + room_id (str): Room that we are trying to join + user (UserID): User who is trying to join + content (dict): A dict that should be used as the content of the + join event. + + Returns: + Deferred + """ + raise NotImplementedError() + + @abc.abstractmethod + def _remote_reject_invite(self, remote_room_hosts, room_id, target): + """Attempt to reject an invite for a room this server is not in. If we + fail to do so we locally mark the invite as rejected. + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers to use to try and + reject invite + room_id (str) + target (UserID): The user rejecting the invite + + Returns: + Deferred[dict]: A dictionary to be returned to the client, may + include event_id etc, or nothing if we locally rejected + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_joined_room(self, target, room_id): + """Notifies distributor on master process that the user has joined the + room. + + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_left_room(self, target, room_id): + """Notifies distributor on master process that the user has left the + room. + + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() @defer.inlineCallbacks def _local_membership_update( @@ -127,83 +208,16 @@ class RoomMemberHandler(object): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target, room_id) + yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + yield self._user_left_room(target, room_id) defer.returnValue(event) @defer.inlineCallbacks - def _remote_join(self, remote_room_hosts, room_id, user, content): - """Try and join a room that this server is not in - - Args: - remote_room_hosts (list[str]): List of servers that can be used - to join via. - room_id (str): Room that we are trying to join - user (UserID): User who is trying to join - content (dict): A dict that should be used as the content of the - join event. - - Returns: - Deferred - """ - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, - ) - yield user_joined_room(self.distributor, user, room_id) - - @defer.inlineCallbacks - def _remote_reject_invite(self, remote_room_hosts, room_id, target): - """Attempt to reject an invite for a room this server is not in. If we - fail to do so we locally mark the invite as rejected. - - Args: - remote_room_hosts (list[str]): List of servers to use to try and - reject invite - room_id (str) - target (UserID): The user rejecting the invite - - Returns: - Deferred[dict]: A dictionary to be returned to the client, may - include event_id etc, or nothing if we locally rejected - """ - fed_handler = self.federation_handler - try: - ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), - ) - defer.returnValue(ret) - except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. - # - # The 'except' clause is very broad, but we need to - # capture everything from DNS failures upwards - # - logger.warn("Failed to reject invite: %s", e) - - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) - defer.returnValue({}) - - @defer.inlineCallbacks def update_membership( self, requester, @@ -356,7 +370,7 @@ class RoomMemberHandler(object): content["kind"] = "guest" ret = yield self._remote_join( - remote_room_hosts, room_id, target, content + requester, remote_room_hosts, room_id, target, content ) defer.returnValue(ret) @@ -378,7 +392,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - remote_room_hosts, room_id, target, + requester, remote_room_hosts, room_id, target, ) defer.returnValue(res) @@ -476,12 +490,12 @@ class RoomMemberHandler(object): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target_user, room_id) + yield self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + yield self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -672,6 +686,7 @@ class RoomMemberHandler(object): token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( + requester=requester, id_server=id_server, medium=medium, address=address, @@ -708,6 +723,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( self, + requester, id_server, medium, address, @@ -724,6 +740,7 @@ class RoomMemberHandler(object): Asks an identity server for a third party invite. Args: + requester (Requester) id_server (str): hostname + optional port for the identity server. medium (str): The literal string "email". address (str): The third party address being invited. @@ -766,8 +783,8 @@ class RoomMemberHandler(object): } if self.config.invite_3pid_guest: - rh = self.registration_handler - guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest( + guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest( + requester=requester, medium=medium, address=address, inviter_user_id=inviter_user_id, @@ -801,27 +818,6 @@ class RoomMemberHandler(object): defer.returnValue((token, public_keys, fallback_public_key, display_name)) @defer.inlineCallbacks - def forget(self, user, room_id): - user_id = user.to_string() - - member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN - ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) - - if membership: - yield self.store.forget(user_id, room_id) - - @defer.inlineCallbacks def _is_host_in_room(self, current_state_ids): # Have we just created the room, and is this about to be the very # first member event? @@ -842,3 +838,94 @@ class RoomMemberHandler(object): defer.returnValue(True) defer.returnValue(False) + + +class RoomMemberMasterHandler(RoomMemberHandler): + def __init__(self, hs): + super(RoomMemberMasterHandler, self).__init__(hs) + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Implements RoomMemberHandler._remote_join + """ + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield self._user_joined_room(user, room_id) + + @defer.inlineCallbacks + def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + """Implements RoomMemberHandler._remote_reject_invite + """ + fed_handler = self.federation_handler + try: + ret = yield fed_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + target.to_string(), + ) + defer.returnValue(ret) + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + defer.returnValue({}) + + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Implements RoomMemberHandler.get_or_register_3pid_guest + """ + rg = self.registration_handler + return rg.get_or_register_3pid_guest(medium, address, inviter_user_id) + + def _user_joined_room(self, target, room_id): + """Implements RoomMemberHandler._user_joined_room + """ + return user_joined_room(self.distributor, target, room_id) + + def _user_left_room(self, target, room_id): + """Implements RoomMemberHandler._user_left_room + """ + return user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership not in [ + Membership.LEAVE, Membership.BAN + ]: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py new file mode 100644 index 0000000000..46c28d4171 --- /dev/null +++ b/synapse/replication/slave/storage/profile.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.storage.profile import ProfileWorkerStore + + +class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/server.py b/synapse/server.py index 43c6e0a6d6..763f0f5a68 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -47,7 +47,7 @@ from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.room_list import RoomListHandler -from synapse.handlers.room_member import RoomMemberHandler +from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler @@ -392,7 +392,9 @@ class HomeServer(object): return SpamChecker(self) def build_room_member_handler(self): - return RoomMemberHandler(self) + if self.config.worker_app: + raise Exception("Can't use RoomMemberHandler on workers") + return RoomMemberMasterHandler(self) def build_federation_registry(self): return FederationHandlerRegistry() diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index ec02e73bc2..8612bd5ecc 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -21,14 +21,7 @@ from synapse.api.errors import StoreError from ._base import SQLBaseStore -class ProfileStore(SQLBaseStore): - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", - ) - +class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: @@ -61,14 +54,6 @@ class ProfileStore(SQLBaseStore): desc="get_profile_displayname", ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - def get_profile_avatar_url(self, user_localpart): return self._simple_select_one_onecol( table="profiles", @@ -77,14 +62,6 @@ class ProfileStore(SQLBaseStore): desc="get_profile_avatar_url", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, - desc="set_profile_avatar_url", - ) - def get_from_remote_profile_cache(self, user_id): return self._simple_select_one( table="remote_profile_cache", @@ -94,6 +71,31 @@ class ProfileStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) + +class ProfileStore(ProfileWorkerStore): + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", + values={"user_id": user_localpart}, + desc="create_profile", + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + + def set_profile_avatar_url(self, user_localpart, new_avatar_url): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"avatar_url": new_avatar_url}, + desc="set_profile_avatar_url", + ) + def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 7f2c08d7a6..34ed84ea22 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -157,6 +157,18 @@ class RoomWorkerStore(SQLBaseStore): "get_public_room_changes", get_public_room_changes_txn ) + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self._simple_select_one_onecol( + table="blocked_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + class RoomStore(RoomWorkerStore, SearchStore): @@ -485,18 +497,6 @@ class RoomStore(RoomWorkerStore, SearchStore): else: defer.returnValue(None) - @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( - table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, - retcol="1", - allow_none=True, - desc="is_room_blocked", - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( @@ -507,7 +507,11 @@ class RoomStore(RoomWorkerStore, SearchStore): }, desc="block_room", ) - self.is_room_blocked.invalidate((room_id,)) + yield self.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, (room_id,), + ) def get_media_mxcs_in_room(self, room_id): """Retrieves all the local and remote media MXC URIs in a given room diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 2b325e1c1f..ffa4246031 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -240,6 +240,9 @@ class StateGroupWorkerStore(SQLBaseStore): ( "AND type = ? AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -259,10 +262,19 @@ class StateGroupWorkerStore(SQLBaseStore): key = (typ, state_key) results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.extend(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -279,7 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" @@ -292,9 +304,17 @@ class StateGroupWorkerStore(SQLBaseStore): if (typ, state_key) not in results[group] ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( |