summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/room_member.py287
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/storage/room.py30
-rw-r--r--synapse/storage/state.py34
4 files changed, 235 insertions, 122 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0127cf4166..9977be8831 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import abc
 import logging
 
 from signedjson.key import decode_verify_key_bytes
@@ -31,6 +31,7 @@ from synapse.types import UserID, RoomID
 from synapse.util.async import Linearizer
 from synapse.util.distributor import user_left_room, user_joined_room
 
+
 logger = logging.getLogger(__name__)
 
 id_server_scheme = "https://"
@@ -42,6 +43,8 @@ class RoomMemberHandler(object):
     #   API that takes ID strings and returns pagination chunks. These concerns
     #   ought to be separated out a lot better.
 
+    __metaclass__ = abc.ABCMeta
+
     def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
@@ -61,9 +64,87 @@ class RoomMemberHandler(object):
         self.clock = hs.get_clock()
         self.spam_checker = hs.get_spam_checker()
 
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
-        self.distributor.declare("user_left_room")
+    @abc.abstractmethod
+    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+        """Try and join a room that this server is not in
+
+        Args:
+            requester (Requester)
+            remote_room_hosts (list[str]): List of servers that can be used
+                to join via.
+            room_id (str): Room that we are trying to join
+            user (UserID): User who is trying to join
+            content (dict): A dict that should be used as the content of the
+                join event.
+
+        Returns:
+            Deferred
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _remote_reject_invite(self, remote_room_hosts, room_id, target):
+        """Attempt to reject an invite for a room this server is not in. If we
+        fail to do so we locally mark the invite as rejected.
+
+        Args:
+            requester (Requester)
+            remote_room_hosts (list[str]): List of servers to use to try and
+                reject invite
+            room_id (str)
+            target (UserID): The user rejecting the invite
+
+        Returns:
+            Deferred[dict]: A dictionary to be returned to the client, may
+            include event_id etc, or nothing if we locally rejected
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+        """Get a guest access token for a 3PID, creating a guest account if
+        one doesn't already exist.
+
+        Args:
+            requester (Requester)
+            medium (str)
+            address (str)
+            inviter_user_id (str): The user ID who is trying to invite the
+                3PID
+
+        Returns:
+            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+            3PID guest account.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _user_joined_room(self, target, room_id):
+        """Notifies distributor on master process that the user has joined the
+        room.
+
+        Args:
+            target (UserID)
+            room_id (str)
+
+        Returns:
+            Deferred|None
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _user_left_room(self, target, room_id):
+        """Notifies distributor on master process that the user has left the
+        room.
+
+        Args:
+            target (UserID)
+            room_id (str)
+
+        Returns:
+            Deferred|None
+        """
+        raise NotImplementedError()
 
     @defer.inlineCallbacks
     def _local_membership_update(
@@ -127,83 +208,16 @@ class RoomMemberHandler(object):
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
             if newly_joined:
-                yield user_joined_room(self.distributor, target, room_id)
+                yield self._user_joined_room(target, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target, room_id)
+                    yield self._user_left_room(target, room_id)
 
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def _remote_join(self, remote_room_hosts, room_id, user, content):
-        """Try and join a room that this server is not in
-
-        Args:
-            remote_room_hosts (list[str]): List of servers that can be used
-                to join via.
-            room_id (str): Room that we are trying to join
-            user (UserID): User who is trying to join
-            content (dict): A dict that should be used as the content of the
-                join event.
-
-        Returns:
-            Deferred
-        """
-        if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
-
-        # We don't do an auth check if we are doing an invite
-        # join dance for now, since we're kinda implicitly checking
-        # that we are allowed to join when we decide whether or not we
-        # need to do the invite/join dance.
-        yield self.federation_handler.do_invite_join(
-            remote_room_hosts,
-            room_id,
-            user.to_string(),
-            content,
-        )
-        yield user_joined_room(self.distributor, user, room_id)
-
-    @defer.inlineCallbacks
-    def _remote_reject_invite(self, remote_room_hosts, room_id, target):
-        """Attempt to reject an invite for a room this server is not in. If we
-        fail to do so we locally mark the invite as rejected.
-
-        Args:
-            remote_room_hosts (list[str]): List of servers to use to try and
-                reject invite
-            room_id (str)
-            target (UserID): The user rejecting the invite
-
-        Returns:
-            Deferred[dict]: A dictionary to be returned to the client, may
-            include event_id etc, or nothing if we locally rejected
-        """
-        fed_handler = self.federation_handler
-        try:
-            ret = yield fed_handler.do_remotely_reject_invite(
-                remote_room_hosts,
-                room_id,
-                target.to_string(),
-            )
-            defer.returnValue(ret)
-        except Exception as e:
-            # if we were unable to reject the exception, just mark
-            # it as rejected on our end and plough ahead.
-            #
-            # The 'except' clause is very broad, but we need to
-            # capture everything from DNS failures upwards
-            #
-            logger.warn("Failed to reject invite: %s", e)
-
-            yield self.store.locally_reject_invite(
-                target.to_string(), room_id
-            )
-            defer.returnValue({})
-
-    @defer.inlineCallbacks
     def update_membership(
             self,
             requester,
@@ -356,7 +370,7 @@ class RoomMemberHandler(object):
                     content["kind"] = "guest"
 
                 ret = yield self._remote_join(
-                    remote_room_hosts, room_id, target, content
+                    requester, remote_room_hosts, room_id, target, content
                 )
                 defer.returnValue(ret)
 
@@ -378,7 +392,7 @@ class RoomMemberHandler(object):
                     # send the rejection to the inviter's HS.
                     remote_room_hosts = remote_room_hosts + [inviter.domain]
                     res = yield self._remote_reject_invite(
-                        remote_room_hosts, room_id, target,
+                        requester, remote_room_hosts, room_id, target,
                     )
                     defer.returnValue(res)
 
@@ -476,12 +490,12 @@ class RoomMemberHandler(object):
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
             if newly_joined:
-                yield user_joined_room(self.distributor, target_user, room_id)
+                yield self._user_joined_room(target_user, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target_user, room_id)
+                    yield self._user_left_room(target_user, room_id)
 
     @defer.inlineCallbacks
     def _can_guest_join(self, current_state_ids):
@@ -672,6 +686,7 @@ class RoomMemberHandler(object):
 
         token, public_keys, fallback_public_key, display_name = (
             yield self._ask_id_server_for_third_party_invite(
+                requester=requester,
                 id_server=id_server,
                 medium=medium,
                 address=address,
@@ -708,6 +723,7 @@ class RoomMemberHandler(object):
     @defer.inlineCallbacks
     def _ask_id_server_for_third_party_invite(
             self,
+            requester,
             id_server,
             medium,
             address,
@@ -724,6 +740,7 @@ class RoomMemberHandler(object):
         Asks an identity server for a third party invite.
 
         Args:
+            requester (Requester)
             id_server (str): hostname + optional port for the identity server.
             medium (str): The literal string "email".
             address (str): The third party address being invited.
@@ -766,8 +783,8 @@ class RoomMemberHandler(object):
         }
 
         if self.config.invite_3pid_guest:
-            rh = self.registration_handler
-            guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
+            guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+                requester=requester,
                 medium=medium,
                 address=address,
                 inviter_user_id=inviter_user_id,
@@ -801,27 +818,6 @@ class RoomMemberHandler(object):
         defer.returnValue((token, public_keys, fallback_public_key, display_name))
 
     @defer.inlineCallbacks
-    def forget(self, user, room_id):
-        user_id = user.to_string()
-
-        member = yield self.state_handler.get_current_state(
-            room_id=room_id,
-            event_type=EventTypes.Member,
-            state_key=user_id
-        )
-        membership = member.membership if member else None
-
-        if membership is not None and membership not in [
-            Membership.LEAVE, Membership.BAN
-        ]:
-            raise SynapseError(400, "User %s in room %s" % (
-                user_id, room_id
-            ))
-
-        if membership:
-            yield self.store.forget(user_id, room_id)
-
-    @defer.inlineCallbacks
     def _is_host_in_room(self, current_state_ids):
         # Have we just created the room, and is this about to be the very
         # first member event?
@@ -842,3 +838,94 @@ class RoomMemberHandler(object):
                 defer.returnValue(True)
 
         defer.returnValue(False)
+
+
+class RoomMemberMasterHandler(RoomMemberHandler):
+    def __init__(self, hs):
+        super(RoomMemberMasterHandler, self).__init__(hs)
+
+        self.distributor = hs.get_distributor()
+        self.distributor.declare("user_joined_room")
+        self.distributor.declare("user_left_room")
+
+    @defer.inlineCallbacks
+    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+        """Implements RoomMemberHandler._remote_join
+        """
+        if len(remote_room_hosts) == 0:
+            raise SynapseError(404, "No known servers")
+
+        # We don't do an auth check if we are doing an invite
+        # join dance for now, since we're kinda implicitly checking
+        # that we are allowed to join when we decide whether or not we
+        # need to do the invite/join dance.
+        yield self.federation_handler.do_invite_join(
+            remote_room_hosts,
+            room_id,
+            user.to_string(),
+            content,
+        )
+        yield self._user_joined_room(user, room_id)
+
+    @defer.inlineCallbacks
+    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+        """Implements RoomMemberHandler._remote_reject_invite
+        """
+        fed_handler = self.federation_handler
+        try:
+            ret = yield fed_handler.do_remotely_reject_invite(
+                remote_room_hosts,
+                room_id,
+                target.to_string(),
+            )
+            defer.returnValue(ret)
+        except Exception as e:
+            # if we were unable to reject the exception, just mark
+            # it as rejected on our end and plough ahead.
+            #
+            # The 'except' clause is very broad, but we need to
+            # capture everything from DNS failures upwards
+            #
+            logger.warn("Failed to reject invite: %s", e)
+
+            yield self.store.locally_reject_invite(
+                target.to_string(), room_id
+            )
+            defer.returnValue({})
+
+    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+        """Implements RoomMemberHandler.get_or_register_3pid_guest
+        """
+        rg = self.registration_handler
+        return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+
+    def _user_joined_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_joined_room
+        """
+        return user_joined_room(self.distributor, target, room_id)
+
+    def _user_left_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_left_room
+        """
+        return user_left_room(self.distributor, target, room_id)
+
+    @defer.inlineCallbacks
+    def forget(self, user, room_id):
+        user_id = user.to_string()
+
+        member = yield self.state_handler.get_current_state(
+            room_id=room_id,
+            event_type=EventTypes.Member,
+            state_key=user_id
+        )
+        membership = member.membership if member else None
+
+        if membership is not None and membership not in [
+            Membership.LEAVE, Membership.BAN
+        ]:
+            raise SynapseError(400, "User %s in room %s" % (
+                user_id, room_id
+            ))
+
+        if membership:
+            yield self.store.forget(user_id, room_id)
diff --git a/synapse/server.py b/synapse/server.py
index 43c6e0a6d6..763f0f5a68 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -47,7 +47,7 @@ from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.room_list import RoomListHandler
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import RoomMemberMasterHandler
 from synapse.handlers.set_password import SetPasswordHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
@@ -392,7 +392,9 @@ class HomeServer(object):
         return SpamChecker(self)
 
     def build_room_member_handler(self):
-        return RoomMemberHandler(self)
+        if self.config.worker_app:
+            raise Exception("Can't use RoomMemberHandler on workers")
+        return RoomMemberMasterHandler(self)
 
     def build_federation_registry(self):
         return FederationHandlerRegistry()
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 7f2c08d7a6..34ed84ea22 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -157,6 +157,18 @@ class RoomWorkerStore(SQLBaseStore):
             "get_public_room_changes", get_public_room_changes_txn
         )
 
+    @cached(max_entries=10000)
+    def is_room_blocked(self, room_id):
+        return self._simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="1",
+            allow_none=True,
+            desc="is_room_blocked",
+        )
+
 
 class RoomStore(RoomWorkerStore, SearchStore):
 
@@ -485,18 +497,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
         else:
             defer.returnValue(None)
 
-    @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self._simple_select_one_onecol(
-            table="blocked_rooms",
-            keyvalues={
-                "room_id": room_id,
-            },
-            retcol="1",
-            allow_none=True,
-            desc="is_room_blocked",
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         yield self._simple_insert(
@@ -507,7 +507,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
             },
             desc="block_room",
         )
-        self.is_room_blocked.invalidate((room_id,))
+        yield self.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked, (room_id,),
+        )
 
     def get_media_mxcs_in_room(self, room_id):
         """Retrieves all the local and remote media MXC URIs in a given room
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2b325e1c1f..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -240,6 +240,9 @@ class StateGroupWorkerStore(SQLBaseStore):
                     (
                         "AND type = ? AND state_key = ?",
                         (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
                     )
                     for etype, state_key in types
                 ]
@@ -259,10 +262,19 @@ class StateGroupWorkerStore(SQLBaseStore):
                         key = (typ, state_key)
                         results[group][key] = event_id
         else:
+            where_args = []
+            where_clauses = []
+            wildcard_types = False
             if types is not None:
-                where_clause = "AND (%s)" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                for typ in types:
+                    if typ[1] is None:
+                        where_clauses.append("(type = ?)")
+                        where_args.extend(typ[0])
+                        wildcard_types = True
+                    else:
+                        where_clauses.append("(type = ? AND state_key = ?)")
+                        where_args.extend([typ[0], typ[1]])
+                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
 
@@ -279,7 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
@@ -292,9 +304,17 @@ class StateGroupWorkerStore(SQLBaseStore):
                         if (typ, state_key) not in results[group]
                     )
 
-                    # If the lengths match then we must have all the types,
-                    # so no need to go walk further down the tree.
-                    if types is not None and len(results[group]) == len(types):
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        types is not None and
+                        not wildcard_types and
+                        len(results[group]) == len(types)
+                    ):
                         break
 
                     next_group = self._simple_select_one_onecol_txn(