summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py354
-rw-r--r--synapse/api/constants.py14
-rw-r--r--synapse/api/errors.py16
-rw-r--r--synapse/api/filtering.py200
4 files changed, 448 insertions, 136 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1e3b0fbfb7..8111b34428 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,15 +14,20 @@
 # limitations under the License.
 
 """This module contains classes for authenticating the user."""
+from canonicaljson import encode_canonical_json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json, SignatureVerifyException
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
+from synapse.types import RoomID, UserID, EventID
 from synapse.util.logutils import log_function
-from synapse.types import UserID, ClientInfo
+from unpaddedbase64 import decode_base64
 
 import logging
+import pymacaroons
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +35,7 @@ logger = logging.getLogger(__name__)
 AuthEventTypes = (
     EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
     EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
+    EventTypes.ThirdPartyInvite,
 )
 
 
@@ -40,6 +46,13 @@ class Auth(object):
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+        self._KNOWN_CAVEAT_PREFIXES = set([
+            "gen = ",
+            "guest = ",
+            "type = ",
+            "time < ",
+            "user_id = ",
+        ])
 
     def check(self, event, auth_events):
         """ Checks if this event is correctly authed.
@@ -52,6 +65,8 @@ class Auth(object):
         Returns:
             True if the auth checks pass.
         """
+        self.check_size_limits(event)
+
         try:
             if not hasattr(event, "room_id"):
                 raise AuthError(500, "Event has no room_id: %s" % event)
@@ -65,6 +80,23 @@ class Auth(object):
                 # FIXME
                 return True
 
+            creation_event = auth_events.get((EventTypes.Create, ""), None)
+
+            if not creation_event:
+                raise SynapseError(
+                    403,
+                    "Room %r does not exist" % (event.room_id,)
+                )
+
+            creating_domain = RoomID.from_string(event.room_id).domain
+            originating_domain = UserID.from_string(event.sender).domain
+            if creating_domain != originating_domain:
+                if not self.can_federate(event, auth_events):
+                    raise AuthError(
+                        403,
+                        "This room has been marked as unfederatable."
+                    )
+
             # FIXME: Temp hack
             if event.type == EventTypes.Aliases:
                 return True
@@ -91,7 +123,7 @@ class Auth(object):
                 self._check_power_levels(event, auth_events)
 
             if event.type == EventTypes.Redaction:
-                self._check_redaction(event, auth_events)
+                self.check_redaction(event, auth_events)
 
             logger.debug("Allowing! %s", event)
         except AuthError as e:
@@ -102,8 +134,39 @@ class Auth(object):
             logger.info("Denying! %s", event)
             raise
 
+    def check_size_limits(self, event):
+        def too_big(field):
+            raise EventSizeError("%s too large" % (field,))
+
+        if len(event.user_id) > 255:
+            too_big("user_id")
+        if len(event.room_id) > 255:
+            too_big("room_id")
+        if event.is_state() and len(event.state_key) > 255:
+            too_big("state_key")
+        if len(event.type) > 255:
+            too_big("type")
+        if len(event.event_id) > 255:
+            too_big("event_id")
+        if len(encode_canonical_json(event.get_pdu_json())) > 65536:
+            too_big("event")
+
     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id, current_state=None):
+        """Check if the user is currently joined in the room
+        Args:
+            room_id(str): The room to check.
+            user_id(str): The user to check.
+            current_state(dict): Optional map of the current state of the room.
+                If provided then that map is used to check whether they are a
+                member of the room. Otherwise the current membership is
+                loaded from the database.
+        Raises:
+            AuthError if the user is not in the room.
+        Returns:
+            A deferred membership event for the user if the user is in
+            the room.
+        """
         if current_state:
             member = current_state.get(
                 (EventTypes.Member, user_id),
@@ -120,6 +183,33 @@ class Auth(object):
         defer.returnValue(member)
 
     @defer.inlineCallbacks
+    def check_user_was_in_room(self, room_id, user_id):
+        """Check if the user was in the room at some point.
+        Args:
+            room_id(str): The room to check.
+            user_id(str): The user to check.
+        Raises:
+            AuthError if the user was never in the room.
+        Returns:
+            A deferred membership event for the user if the user was in the
+            room. This will be the join event if they are currently joined to
+            the room. This will be the leave event if they have left the room.
+        """
+        member = yield self.state.get_current_state(
+            room_id=room_id,
+            event_type=EventTypes.Member,
+            state_key=user_id
+        )
+        membership = member.membership if member else None
+
+        if membership not in (Membership.JOIN, Membership.LEAVE):
+            raise AuthError(403, "User %s not in room %s" % (
+                user_id, room_id
+            ))
+
+        defer.returnValue(member)
+
+    @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
         curr_state = yield self.state.get_current_state(room_id)
 
@@ -153,6 +243,11 @@ class Auth(object):
                 user_id, room_id, repr(member)
             ))
 
+    def can_federate(self, event, auth_events):
+        creation_event = auth_events.get((EventTypes.Create, ""))
+
+        return creation_event.content.get("m.federate", True) is True
+
     @log_function
     def is_membership_change_allowed(self, event, auth_events):
         membership = event.content["membership"]
@@ -168,6 +263,15 @@ class Auth(object):
 
         target_user_id = event.state_key
 
+        creating_domain = RoomID.from_string(event.room_id).domain
+        target_domain = UserID.from_string(target_user_id).domain
+        if creating_domain != target_domain:
+            if not self.can_federate(event, auth_events):
+                raise AuthError(
+                    403,
+                    "This room has been marked as unfederatable."
+                )
+
         # get info about the caller
         key = (EventTypes.Member, event.user_id, )
         caller = auth_events.get(key)
@@ -213,8 +317,17 @@ class Auth(object):
             }
         )
 
+        if Membership.INVITE == membership and "third_party_invite" in event.content:
+            if not self._verify_third_party_invite(event, auth_events):
+                raise AuthError(403, "You are not invited to this room.")
+            return True
+
         if Membership.JOIN != membership:
-            # JOIN is the only action you can perform if you're not in the room
+            if (caller_invited
+                    and Membership.LEAVE == membership
+                    and target_user_id == event.user_id):
+                return True
+
             if not caller_in_room:  # caller isn't joined
                 raise AuthError(
                     403,
@@ -278,6 +391,66 @@ class Auth(object):
 
         return True
 
+    def _verify_third_party_invite(self, event, auth_events):
+        """
+        Validates that the invite event is authorized by a previous third-party invite.
+
+        Checks that the public key, and keyserver, match those in the third party invite,
+        and that the invite event has a signature issued using that public key.
+
+        Args:
+            event: The m.room.member invite event being validated.
+            auth_events: All relevant previous context events which may be used
+                for authorization decisions.
+
+        Return:
+            True if the event fulfills the expectations of a previous third party
+            invite event.
+        """
+        if "third_party_invite" not in event.content:
+            return False
+        if "signed" not in event.content["third_party_invite"]:
+            return False
+        signed = event.content["third_party_invite"]["signed"]
+        for key in {"mxid", "token"}:
+            if key not in signed:
+                return False
+
+        token = signed["token"]
+
+        invite_event = auth_events.get(
+            (EventTypes.ThirdPartyInvite, token,)
+        )
+        if not invite_event:
+            return False
+
+        if event.user_id != invite_event.user_id:
+            return False
+        try:
+            public_key = invite_event.content["public_key"]
+            if signed["mxid"] != event.state_key:
+                return False
+            if signed["token"] != token:
+                return False
+            for server, signature_block in signed["signatures"].items():
+                for key_name, encoded_signature in signature_block.items():
+                    if not key_name.startswith("ed25519:"):
+                        return False
+                    verify_key = decode_verify_key_bytes(
+                        key_name,
+                        decode_base64(public_key)
+                    )
+                    verify_signed_json(signed, server, verify_key)
+
+                    # We got the public key from the invite, so we know that the
+                    # correct server signed the signed bundle.
+                    # The caller is responsible for checking that the signing
+                    # server has not revoked that public key.
+                    return True
+            return False
+        except (KeyError, SignatureVerifyException,):
+            return False
+
     def _get_power_level_event(self, auth_events):
         key = (EventTypes.PowerLevels, "", )
         return auth_events.get(key)
@@ -316,15 +489,15 @@ class Auth(object):
             return default
 
     @defer.inlineCallbacks
-    def get_user_by_req(self, request):
+    def get_user_by_req(self, request, allow_guest=False):
         """ Get a registered user's ID.
 
         Args:
             request - An HTTP request with an access_token query parameter.
         Returns:
-            tuple : of UserID and device string:
-                User ID object of the user making the request
-                ClientInfo object of the client instance the user is using
+            tuple of (user, token_id, is_guest):
+                UserID object, access token ID (str or None),
+                and whether the user is a guest (bool)
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
@@ -354,17 +527,15 @@ class Auth(object):
 
                 request.authenticated_entity = user_id
 
-                defer.returnValue(
-                    (UserID.from_string(user_id), ClientInfo("", ""))
-                )
+                defer.returnValue((UserID.from_string(user_id), "", False))
                 return
             except KeyError:
                 pass  # normal users won't have the user_id query parameter set.
 
-            user_info = yield self.get_user_by_token(access_token)
+            user_info = yield self._get_user_by_access_token(access_token)
             user = user_info["user"]
-            device_id = user_info["device_id"]
             token_id = user_info["token_id"]
+            is_guest = user_info["is_guest"]
 
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
@@ -375,14 +546,18 @@ class Auth(object):
                 self.store.insert_client_ip(
                     user=user,
                     access_token=access_token,
-                    device_id=user_info["device_id"],
                     ip=ip_addr,
                     user_agent=user_agent
                 )
 
+            if is_guest and not allow_guest:
+                raise AuthError(
+                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+                )
+
             request.authenticated_entity = user.to_string()
 
-            defer.returnValue((user, ClientInfo(device_id, token_id)))
+            defer.returnValue((user, token_id, is_guest,))
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -390,30 +565,124 @@ class Auth(object):
             )
 
     @defer.inlineCallbacks
-    def get_user_by_token(self, token):
+    def _get_user_by_access_token(self, token):
         """ Get a registered user's ID.
 
         Args:
             token (str): The access token to get the user by.
         Returns:
-            dict : dict that includes the user, device_id, and whether the
-                user is a server admin.
+            dict : dict that includes the user and the ID of their access token.
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
-        ret = yield self.store.get_user_by_token(token)
+        try:
+            ret = yield self._get_user_from_macaroon(token)
+        except AuthError:
+            # TODO(daniel): Remove this fallback when all existing access tokens
+            # have been re-issued as macaroons.
+            ret = yield self._look_up_user_by_access_token(token)
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def _get_user_from_macaroon(self, macaroon_str):
+        try:
+            macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
+            self.validate_macaroon(
+                macaroon, "access",
+                [lambda c: c.startswith("time < ")]
+            )
+
+            user_prefix = "user_id = "
+            user = None
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id.startswith(user_prefix):
+                    user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
+                elif caveat.caveat_id == "guest = true":
+                    guest = True
+
+            if user is None:
+                raise AuthError(
+                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                    errcode=Codes.UNKNOWN_TOKEN
+                )
+
+            if guest:
+                ret = {
+                    "user": user,
+                    "is_guest": True,
+                    "token_id": None,
+                }
+            else:
+                # This codepath exists so that we can actually return a
+                # token ID, because we use token IDs in place of device
+                # identifiers throughout the codebase.
+                # TODO(daniel): Remove this fallback when device IDs are
+                # properly implemented.
+                ret = yield self._look_up_user_by_access_token(macaroon_str)
+                if ret["user"] != user:
+                    logger.error(
+                        "Macaroon user (%s) != DB user (%s)",
+                        user,
+                        ret["user"]
+                    )
+                    raise AuthError(
+                        self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                        "User mismatch in macaroon",
+                        errcode=Codes.UNKNOWN_TOKEN
+                    )
+            defer.returnValue(ret)
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+
+    def validate_macaroon(self, macaroon, type_string, additional_validation_functions):
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = " + type_string)
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_exact("guest = true")
+
+        for validation_function in additional_validation_functions:
+            v.satisfy_general(validation_function)
+        v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_general(self._verify_recognizes_caveats)
+        v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+    def verify_expiry(self, caveat):
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix):])
+        now = self.hs.get_clock().time_msec()
+        return now < expiry
+
+    def _verify_recognizes_caveats(self, caveat):
+        first_space = caveat.find(" ")
+        if first_space < 0:
+            return False
+        second_space = caveat.find(" ", first_space + 1)
+        if second_space < 0:
+            return False
+        return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
+
+    @defer.inlineCallbacks
+    def _look_up_user_by_access_token(self, token):
+        ret = yield self.store.get_user_by_access_token(token)
         if not ret:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
                 errcode=Codes.UNKNOWN_TOKEN
             )
         user_info = {
-            "admin": bool(ret.get("admin", False)),
-            "device_id": ret.get("device_id"),
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
+            "is_guest": False,
         }
-
         defer.returnValue(user_info)
 
     @defer.inlineCallbacks
@@ -488,6 +757,16 @@ class Auth(object):
             else:
                 if member_event:
                     auth_ids.append(member_event.event_id)
+
+            if e_type == Membership.INVITE:
+                if "third_party_invite" in event.content:
+                    key = (
+                        EventTypes.ThirdPartyInvite,
+                        event.content["third_party_invite"]["token"]
+                    )
+                    third_party_invite = current_state.get(key)
+                    if third_party_invite:
+                        auth_ids.append(third_party_invite.event_id)
         elif member_event:
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
@@ -548,16 +827,35 @@ class Auth(object):
 
         return True
 
-    def _check_redaction(self, event, auth_events):
+    def check_redaction(self, event, auth_events):
+        """Check whether the event sender is allowed to redact the target event.
+
+        Returns:
+            True if the sender is allowed to redact the target event if the
+            target event was created by them.
+            False if the sender is allowed to redact the target event with no
+            further checks.
+
+        Raises:
+            AuthError if the event sender is definitely not allowed to redact
+            the target event.
+        """
         user_level = self._get_user_power_level(event.user_id, auth_events)
 
         redact_level = self._get_named_level(auth_events, "redact", 50)
 
-        if user_level < redact_level:
-            raise AuthError(
-                403,
-                "You don't have permission to redact events"
-            )
+        if user_level >= redact_level:  # meeting the level is enough
+            return False
+
+        redacter_domain = EventID.from_string(event.event_id).domain
+        redactee_domain = EventID.from_string(event.redacts).domain
+        if redacter_domain == redactee_domain:
+            return True
+
+        raise AuthError(
+            403,
+            "You don't have permission to redact events"
+        )
 
     def _check_power_levels(self, event, auth_events):
         user_list = event.content.get("users", {})
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 1423986c1e..c2450b771a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -27,16 +27,6 @@ class Membership(object):
     LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
 
 
-class Feedback(object):
-
-    """Represents the types of feedback a user can send in response to a
-    message."""
-
-    DELIVERED = u"delivered"
-    READ = u"read"
-    LIST = (DELIVERED, READ)
-
-
 class PresenceState(object):
     """Represents the presence state of a user."""
     OFFLINE = u"offline"
@@ -73,11 +63,12 @@ class EventTypes(object):
     PowerLevels = "m.room.power_levels"
     Aliases = "m.room.aliases"
     Redaction = "m.room.redaction"
-    Feedback = "m.room.message.feedback"
+    ThirdPartyInvite = "m.room.third_party_invite"
 
     RoomHistoryVisibility = "m.room.history_visibility"
     CanonicalAlias = "m.room.canonical_alias"
     RoomAvatar = "m.room.avatar"
+    GuestAccess = "m.room.guest_access"
 
     # These are used for validation
     Message = "m.room.message"
@@ -94,3 +85,4 @@ class RejectedReason(object):
 class RoomCreationPreset(object):
     PRIVATE_CHAT = "private_chat"
     PUBLIC_CHAT = "public_chat"
+    TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c3b4d971a8..d4037b3d55 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -33,6 +33,7 @@ class Codes(object):
     NOT_FOUND = "M_NOT_FOUND"
     MISSING_TOKEN = "M_MISSING_TOKEN"
     UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
+    GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
     LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
     CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
     CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
@@ -47,7 +48,6 @@ class CodeMessageException(RuntimeError):
     """An exception with integer code and message string attributes."""
 
     def __init__(self, code, msg):
-        logger.info("%s: %s, %s", type(self).__name__, code, msg)
         super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
         self.code = code
         self.msg = msg
@@ -77,11 +77,6 @@ class SynapseError(CodeMessageException):
         )
 
 
-class RoomError(SynapseError):
-    """An error raised when a room event fails."""
-    pass
-
-
 class RegistrationError(SynapseError):
     """An error raised when a registration event fails."""
     pass
@@ -125,6 +120,15 @@ class AuthError(SynapseError):
         super(AuthError, self).__init__(*args, **kwargs)
 
 
+class EventSizeError(SynapseError):
+    """An error raised when an event is too big."""
+
+    def __init__(self, *args, **kwargs):
+        if "errcode" not in kwargs:
+            kwargs["errcode"] = Codes.TOO_LARGE
+        super(EventSizeError, self).__init__(413, *args, **kwargs)
+
+
 class EventStreamError(SynapseError):
     """An error raised when there a problem with the event stream."""
     def __init__(self, *args, **kwargs):
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4d570b74f8..aaa2433cae 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -24,7 +24,7 @@ class Filtering(object):
 
     def get_user_filter(self, user_localpart, filter_id):
         result = self.store.get_user_filter(user_localpart, filter_id)
-        result.addCallback(Filter)
+        result.addCallback(FilterCollection)
         return result
 
     def add_user_filter(self, user_localpart, user_filter):
@@ -50,11 +50,11 @@ class Filtering(object):
         # many definitions.
 
         top_level_definitions = [
-            "public_user_data", "private_user_data", "server_data"
+            "presence"
         ]
 
         room_level_definitions = [
-            "state", "events", "ephemeral"
+            "state", "timeline", "ephemeral", "private_user_data"
         ]
 
         for key in top_level_definitions:
@@ -114,116 +114,134 @@ class Filtering(object):
                     if not isinstance(event_type, basestring):
                         raise SynapseError(400, "Event type should be a string")
 
-        if "format" in definition:
-            event_format = definition["format"]
-            if event_format not in ["federation", "events"]:
-                raise SynapseError(400, "Invalid format: %s" % (event_format,))
 
-        if "select" in definition:
-            event_select_list = definition["select"]
-            for select_key in event_select_list:
-                if select_key not in ["event_id", "origin_server_ts",
-                                      "thread_id", "content", "content.body"]:
-                    raise SynapseError(400, "Bad select: %s" % (select_key,))
+class FilterCollection(object):
+    def __init__(self, filter_json):
+        self.filter_json = filter_json
 
-        if ("bundle_updates" in definition and
-                type(definition["bundle_updates"]) != bool):
-            raise SynapseError(400, "Bad bundle_updates: expected bool.")
+        self.room_timeline_filter = Filter(
+            self.filter_json.get("room", {}).get("timeline", {})
+        )
 
+        self.room_state_filter = Filter(
+            self.filter_json.get("room", {}).get("state", {})
+        )
 
-class Filter(object):
-    def __init__(self, filter_json):
-        self.filter_json = filter_json
+        self.room_ephemeral_filter = Filter(
+            self.filter_json.get("room", {}).get("ephemeral", {})
+        )
+
+        self.room_private_user_data = Filter(
+            self.filter_json.get("room", {}).get("private_user_data", {})
+        )
+
+        self.presence_filter = Filter(
+            self.filter_json.get("presence", {})
+        )
+
+    def timeline_limit(self):
+        return self.room_timeline_filter.limit()
+
+    def presence_limit(self):
+        return self.presence_filter.limit()
 
-    def filter_public_user_data(self, events):
-        return self._filter_on_key(events, ["public_user_data"])
+    def ephemeral_limit(self):
+        return self.room_ephemeral_filter.limit()
 
-    def filter_private_user_data(self, events):
-        return self._filter_on_key(events, ["private_user_data"])
+    def filter_presence(self, events):
+        return self.presence_filter.filter(events)
 
     def filter_room_state(self, events):
-        return self._filter_on_key(events, ["room", "state"])
+        return self.room_state_filter.filter(events)
 
-    def filter_room_events(self, events):
-        return self._filter_on_key(events, ["room", "events"])
+    def filter_room_timeline(self, events):
+        return self.room_timeline_filter.filter(events)
 
     def filter_room_ephemeral(self, events):
-        return self._filter_on_key(events, ["room", "ephemeral"])
+        return self.room_ephemeral_filter.filter(events)
 
-    def _filter_on_key(self, events, keys):
-        filter_json = self.filter_json
-        if not filter_json:
-            return events
+    def filter_room_private_user_data(self, events):
+        return self.room_private_user_data.filter(events)
 
-        try:
-            # extract the right definition from the filter
-            definition = filter_json
-            for key in keys:
-                definition = definition[key]
-            return self._filter_with_definition(events, definition)
-        except KeyError:
-            # return all events if definition isn't specified.
-            return events
 
-    def _filter_with_definition(self, events, definition):
-        return [e for e in events if self._passes_definition(definition, e)]
+class Filter(object):
+    def __init__(self, filter_json):
+        self.filter_json = filter_json
 
-    def _passes_definition(self, definition, event):
-        """Check if the event passes through the given definition.
+    def check(self, event):
+        """Checks whether the filter matches the given event.
 
-        Args:
-            definition(dict): The definition to check against.
-            event(Event): The event to check.
         Returns:
-            True if the event passes through the filter.
+            bool: True if the event matches
         """
-        # Algorithm notes:
-        # For each key in the definition, check the event meets the criteria:
-        #   * For types: Literal match or prefix match (if ends with wildcard)
-        #   * For senders/rooms: Literal match only
-        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
-        #     and 'not_types' then it is treated as only being in 'not_types')
-
-        # room checks
-        if hasattr(event, "room_id"):
-            room_id = event.room_id
-            allow_rooms = definition.get("rooms", None)
-            reject_rooms = definition.get("not_rooms", None)
-            if reject_rooms and room_id in reject_rooms:
-                return False
-            if allow_rooms and room_id not in allow_rooms:
-                return False
+        if isinstance(event, dict):
+            return self.check_fields(
+                event.get("room_id", None),
+                event.get("sender", None),
+                event.get("type", None),
+            )
+        else:
+            return self.check_fields(
+                getattr(event, "room_id", None),
+                getattr(event, "sender", None),
+                event.type,
+            )
 
-        # sender checks
-        if hasattr(event, "sender"):
-            # Should we be including event.state_key for some event types?
-            sender = event.sender
-            allow_senders = definition.get("senders", None)
-            reject_senders = definition.get("not_senders", None)
-            if reject_senders and sender in reject_senders:
-                return False
-            if allow_senders and sender not in allow_senders:
+    def check_fields(self, room_id, sender, event_type):
+        """Checks whether the filter matches the given event fields.
+
+        Returns:
+            bool: True if the event fields match
+        """
+        literal_keys = {
+            "rooms": lambda v: room_id == v,
+            "senders": lambda v: sender == v,
+            "types": lambda v: _matches_wildcard(event_type, v)
+        }
+
+        for name, match_func in literal_keys.items():
+            not_name = "not_%s" % (name,)
+            disallowed_values = self.filter_json.get(not_name, [])
+            if any(map(match_func, disallowed_values)):
                 return False
 
-        # type checks
-        if "not_types" in definition:
-            for def_type in definition["not_types"]:
-                if self._event_matches_type(event, def_type):
+            allowed_values = self.filter_json.get(name, None)
+            if allowed_values is not None:
+                if not any(map(match_func, allowed_values)):
                     return False
-        if "types" in definition:
-            included = False
-            for def_type in definition["types"]:
-                if self._event_matches_type(event, def_type):
-                    included = True
-                    break
-            if not included:
-                return False
 
         return True
 
-    def _event_matches_type(self, event, def_type):
-        if def_type.endswith("*"):
-            type_prefix = def_type[:-1]
-            return event.type.startswith(type_prefix)
-        else:
-            return event.type == def_type
+    def filter_rooms(self, room_ids):
+        """Apply the 'rooms' filter to a given list of rooms.
+
+        Args:
+            room_ids (list): A list of room_ids.
+
+        Returns:
+            list: A list of room_ids that match the filter
+        """
+        room_ids = set(room_ids)
+
+        disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+        room_ids -= disallowed_rooms
+
+        allowed_rooms = self.filter_json.get("rooms", None)
+        if allowed_rooms is not None:
+            room_ids &= set(allowed_rooms)
+
+        return room_ids
+
+    def filter(self, events):
+        return filter(self.check, events)
+
+    def limit(self):
+        return self.filter_json.get("limit", 10)
+
+
+def _matches_wildcard(actual_value, filter_value):
+    # A trailing "*" makes filter_value a prefix pattern; else compare exactly.
+    if not filter_value.endswith("*"):
+        return actual_value == filter_value
+    wanted_prefix = filter_value[:-1]
+    return actual_value.startswith(wanted_prefix)