diff options
author | Erik Johnston <erik@matrix.org> | 2015-11-17 15:45:43 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2015-11-17 15:45:43 +0000 |
commit | d3861b44424aa6f03cc65719bb1527330157abea (patch) | |
tree | 4377eb0dc5e221862489bdcc802e50e2f1f41cb1 /synapse/api | |
parent | Merge branch 'hotfixes-v0.10.0-r2' of github.com:matrix-org/synapse (diff) | |
parent | Slightly more aggressive retry timers at HTTP level (diff) | |
download | synapse-d3861b44424aa6f03cc65719bb1527330157abea.tar.xz |
Merge branch 'release-v0.11.0' of github.com:matrix-org/synapse v0.11.0
Diffstat (limited to 'synapse/api')
-rw-r--r-- | synapse/api/auth.py | 354 | ||||
-rw-r--r-- | synapse/api/constants.py | 14 | ||||
-rw-r--r-- | synapse/api/errors.py | 16 | ||||
-rw-r--r-- | synapse/api/filtering.py | 200 |
4 files changed, 448 insertions, 136 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1e3b0fbfb7..8111b34428 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,15 +14,20 @@ # limitations under the License. """This module contains classes for authenticating the user.""" +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError +from synapse.types import RoomID, UserID, EventID from synapse.util.logutils import log_function -from synapse.types import UserID, ClientInfo +from unpaddedbase64 import decode_base64 import logging +import pymacaroons logger = logging.getLogger(__name__) @@ -30,6 +35,7 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.ThirdPartyInvite, ) @@ -40,6 +46,13 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + self._KNOWN_CAVEAT_PREFIXES = set([ + "gen = ", + "guest = ", + "type = ", + "time < ", + "user_id = ", + ]) def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -52,6 +65,8 @@ class Auth(object): Returns: True if the auth checks pass. """ + self.check_size_limits(event) + try: if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) @@ -65,6 +80,23 @@ class Auth(object): # FIXME return True + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + + creating_domain = RoomID.from_string(event.room_id).domain + originating_domain = UserID.from_string(event.sender).domain + if creating_domain != originating_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # FIXME: Temp hack if event.type == EventTypes.Aliases: return True @@ -91,7 +123,7 @@ class Auth(object): self._check_power_levels(event, auth_events) if event.type == EventTypes.Redaction: - self._check_redaction(event, auth_events) + self.check_redaction(event, auth_events) logger.debug("Allowing! %s", event) except AuthError as e: @@ -102,8 +134,39 @@ class Auth(object): logger.info("Denying! %s", event) raise + def check_size_limits(self, event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): + """Check if the user is currently joined in the room + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user is not in the room. + Returns: + A deferred membership event for the user if the user is in + the room. + """ if current_state: member = current_state.get( (EventTypes.Member, user_id), @@ -120,6 +183,33 @@ class Auth(object): defer.returnValue(member) @defer.inlineCallbacks + def check_user_was_in_room(self, room_id, user_id): + """Check if the user was in the room at some point. + Args: + room_id(str): The room to check. + user_id(str): The user to check. + Raises: + AuthError if the user was never in the room. + Returns: + A deferred membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. + """ + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership not in (Membership.JOIN, Membership.LEAVE): + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) + + defer.returnValue(member) + + @defer.inlineCallbacks def check_host_in_room(self, room_id, host): curr_state = yield self.state.get_current_state(room_id) @@ -153,6 +243,11 @@ class Auth(object): user_id, room_id, repr(member) )) + def can_federate(self, event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + @log_function def is_membership_change_allowed(self, event, auth_events): membership = event.content["membership"] @@ -168,6 +263,15 @@ class Auth(object): target_user_id = event.state_key + creating_domain = RoomID.from_string(event.room_id).domain + target_domain = UserID.from_string(target_user_id).domain + if creating_domain != target_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # get info about the caller key = (EventTypes.Member, event.user_id, ) caller = auth_events.get(key) @@ -213,8 +317,17 @@ class Auth(object): } ) + if Membership.INVITE == membership and "third_party_invite" in event.content: + if not self._verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") + return True + if Membership.JOIN != membership: - # JOIN is the only action you can perform if you're not in the room + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + if not caller_in_room: # caller isn't joined raise AuthError( 403, @@ -278,6 +391,66 @@ class Auth(object): return True + def _verify_third_party_invite(self, event, auth_events): + """ + Validates that the invite event is authorized by a previous third-party invite. + + Checks that the public key, and keyserver, match those in the third party invite, + and that the invite event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if "third_party_invite" not in event.content: + return False + if "signed" not in event.content["third_party_invite"]: + return False + signed = event.content["third_party_invite"]["signed"] + for key in {"mxid", "token"}: + if key not in signed: + return False + + token = signed["token"] + + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + return False + + if event.user_id != invite_event.user_id: + return False + try: + public_key = invite_event.content["public_key"] + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + return False + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + return False + except (KeyError, SignatureVerifyException,): + return False + def _get_power_level_event(self, auth_events): key = (EventTypes.PowerLevels, "", ) return auth_events.get(key) @@ -316,15 +489,15 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request): + def get_user_by_req(self, request, allow_guest=False): """ Get a registered user's ID. Args: request - An HTTP request with an access_token query parameter. Returns: - tuple : of UserID and device string: - User ID object of the user making the request - ClientInfo object of the client instance the user is using + tuple of: + UserID (str) + Access token ID (str) Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -354,17 +527,15 @@ class Auth(object): request.authenticated_entity = user_id - defer.returnValue( - (UserID.from_string(user_id), ClientInfo("", "")) - ) + defer.returnValue((UserID.from_string(user_id), "", False)) return except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_token(access_token) + user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] - device_id = user_info["device_id"] token_id = user_info["token_id"] + is_guest = user_info["is_guest"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -375,14 +546,18 @@ class Auth(object): self.store.insert_client_ip( user=user, access_token=access_token, - device_id=user_info["device_id"], ip=ip_addr, user_agent=user_agent ) + if is_guest and not allow_guest: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + request.authenticated_entity = user.to_string() - defer.returnValue((user, ClientInfo(device_id, token_id))) + defer.returnValue((user, token_id, is_guest,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -390,30 +565,124 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_token(self, token): + def _get_user_by_access_token(self, token): """ Get a registered user's ID. Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user, device_id, and whether the - user is a server admin. + dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ - ret = yield self.store.get_user_by_token(token) + try: + ret = yield self._get_user_from_macaroon(token) + except AuthError: + # TODO(daniel): Remove this fallback when all existing access tokens + # have been re-issued as macaroons. + ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) + + @defer.inlineCallbacks + def _get_user_from_macaroon(self, macaroon_str): + try: + macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) + self.validate_macaroon( + macaroon, "access", + [lambda c: c.startswith("time < ")] + ) + + user_prefix = "user_id = " + user = None + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) + elif caveat.caveat_id == "guest = true": + guest = True + + if user is None: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + + if guest: + ret = { + "user": user, + "is_guest": True, + "token_id": None, + } + else: + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN + ) + + def validate_macaroon(self, macaroon, type_string, additional_validation_functions): + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = " + type_string) + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_exact("guest = true") + + for validation_function in additional_validation_functions: + v.satisfy_general(validation_function) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + v = pymacaroons.Verifier() + v.satisfy_general(self._verify_recognizes_caveats) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def verify_expiry(self, caveat): + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix):]) + now = self.hs.get_clock().time_msec() + return now < expiry + + def _verify_recognizes_caveats(self, caveat): + first_space = caveat.find(" ") + if first_space < 0: + return False + second_space = caveat.find(" ", first_space + 1) + if second_space < 0: + return False + return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES + + @defer.inlineCallbacks + def _look_up_user_by_access_token(self, token): + ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), + "is_guest": False, } - defer.returnValue(user_info) @defer.inlineCallbacks @@ -488,6 +757,16 @@ class Auth(object): else: if member_event: auth_ids.append(member_event.event_id) + + if e_type == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["token"] + ) + third_party_invite = current_state.get(key) + if third_party_invite: + auth_ids.append(third_party_invite.event_id) elif member_event: if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) @@ -548,16 +827,35 @@ class Auth(object): return True - def _check_redaction(self, event, auth_events): + def check_redaction(self, event, auth_events): + """Check whether the event sender is allowed to redact the target event. + + Returns: + True if the the sender is allowed to redact the target event if the + target event was created by them. + False if the sender is allowed to redact the target event with no + further checks. + + Raises: + AuthError if the event sender is definitely not allowed to redact + the target event. + """ user_level = self._get_user_power_level(event.user_id, auth_events) redact_level = self._get_named_level(auth_events, "redact", 50) - if user_level < redact_level: - raise AuthError( - 403, - "You don't have permission to redact events" - ) + if user_level > redact_level: + return False + + redacter_domain = EventID.from_string(event.event_id).domain + redactee_domain = EventID.from_string(event.redacts).domain + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) def _check_power_levels(self, event, auth_events): user_list = event.content.get("users", {}) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 1423986c1e..c2450b771a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -27,16 +27,6 @@ class Membership(object): LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) -class Feedback(object): - - """Represents the types of feedback a user can send in response to a - message.""" - - DELIVERED = u"delivered" - READ = u"read" - LIST = (DELIVERED, READ) - - class PresenceState(object): """Represents the presence state of a user.""" OFFLINE = u"offline" @@ -73,11 +63,12 @@ class EventTypes(object): PowerLevels = "m.room.power_levels" Aliases = "m.room.aliases" Redaction = "m.room.redaction" - Feedback = "m.room.message.feedback" + ThirdPartyInvite = "m.room.third_party_invite" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" RoomAvatar = "m.room.avatar" + GuestAccess = "m.room.guest_access" # These are used for validation Message = "m.room.message" @@ -94,3 +85,4 @@ class RejectedReason(object): class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" + TRUSTED_PRIVATE_CHAT = "trusted_private_chat" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c3b4d971a8..d4037b3d55 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -33,6 +33,7 @@ class Codes(object): NOT_FOUND = "M_NOT_FOUND" MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" + GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" @@ -47,7 +48,6 @@ class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes.""" def __init__(self, code, msg): - logger.info("%s: %s, %s", type(self).__name__, code, msg) super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code self.msg = msg @@ -77,11 +77,6 @@ class SynapseError(CodeMessageException): ) -class RoomError(SynapseError): - """An error raised when a room event fails.""" - pass - - class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass @@ -125,6 +120,15 @@ class AuthError(SynapseError): super(AuthError, self).__init__(*args, **kwargs) +class EventSizeError(SynapseError): + """An error raised when an event is too big.""" + + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.TOO_LARGE + super(EventSizeError, self).__init__(413, *args, **kwargs) + + class EventStreamError(SynapseError): """An error raised when there a problem with the event stream.""" def __init__(self, *args, **kwargs): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4d570b74f8..aaa2433cae 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -24,7 +24,7 @@ class Filtering(object): def get_user_filter(self, user_localpart, filter_id): result = self.store.get_user_filter(user_localpart, filter_id) - result.addCallback(Filter) + result.addCallback(FilterCollection) return result def add_user_filter(self, user_localpart, user_filter): @@ -50,11 +50,11 @@ class Filtering(object): # many definitions. top_level_definitions = [ - "public_user_data", "private_user_data", "server_data" + "presence" ] room_level_definitions = [ - "state", "events", "ephemeral" + "state", "timeline", "ephemeral", "private_user_data" ] for key in top_level_definitions: @@ -114,116 +114,134 @@ class Filtering(object): if not isinstance(event_type, basestring): raise SynapseError(400, "Event type should be a string") - if "format" in definition: - event_format = definition["format"] - if event_format not in ["federation", "events"]: - raise SynapseError(400, "Invalid format: %s" % (event_format,)) - if "select" in definition: - event_select_list = definition["select"] - for select_key in event_select_list: - if select_key not in ["event_id", "origin_server_ts", - "thread_id", "content", "content.body"]: - raise SynapseError(400, "Bad select: %s" % (select_key,)) +class FilterCollection(object): + def __init__(self, filter_json): + self.filter_json = filter_json - if ("bundle_updates" in definition and - type(definition["bundle_updates"]) != bool): - raise SynapseError(400, "Bad bundle_updates: expected bool.") + self.room_timeline_filter = Filter( + self.filter_json.get("room", {}).get("timeline", {}) + ) + self.room_state_filter = Filter( + self.filter_json.get("room", {}).get("state", {}) + ) -class Filter(object): - def __init__(self, filter_json): - self.filter_json = filter_json + self.room_ephemeral_filter = Filter( + self.filter_json.get("room", {}).get("ephemeral", {}) + ) + + self.room_private_user_data = Filter( + self.filter_json.get("room", {}).get("private_user_data", {}) + ) + + self.presence_filter = Filter( + self.filter_json.get("presence", {}) + ) + + def timeline_limit(self): + return self.room_timeline_filter.limit() + + def presence_limit(self): + return self.presence_filter.limit() - def filter_public_user_data(self, events): - return self._filter_on_key(events, ["public_user_data"]) + def ephemeral_limit(self): + return self.room_ephemeral_filter.limit() - def filter_private_user_data(self, events): - return self._filter_on_key(events, ["private_user_data"]) + def filter_presence(self, events): + return self.presence_filter.filter(events) def filter_room_state(self, events): - return self._filter_on_key(events, ["room", "state"]) + return self.room_state_filter.filter(events) - def filter_room_events(self, events): - return self._filter_on_key(events, ["room", "events"]) + def filter_room_timeline(self, events): + return self.room_timeline_filter.filter(events) def filter_room_ephemeral(self, events): - return self._filter_on_key(events, ["room", "ephemeral"]) + return self.room_ephemeral_filter.filter(events) - def _filter_on_key(self, events, keys): - filter_json = self.filter_json - if not filter_json: - return events + def filter_room_private_user_data(self, events): + return self.room_private_user_data.filter(events) - try: - # extract the right definition from the filter - definition = filter_json - for key in keys: - definition = definition[key] - return self._filter_with_definition(events, definition) - except KeyError: - # return all events if definition isn't specified. - return events - def _filter_with_definition(self, events, definition): - return [e for e in events if self._passes_definition(definition, e)] +class Filter(object): + def __init__(self, filter_json): + self.filter_json = filter_json - def _passes_definition(self, definition, event): - """Check if the event passes through the given definition. + def check(self, event): + """Checks whether the filter matches the given event. - Args: - definition(dict): The definition to check against. - event(Event): The event to check. Returns: - True if the event passes through the filter. + bool: True if the event matches """ - # Algorithm notes: - # For each key in the definition, check the event meets the criteria: - # * For types: Literal match or prefix match (if ends with wildcard) - # * For senders/rooms: Literal match only - # * "not_" checks take presedence (e.g. if "m.*" is in both 'types' - # and 'not_types' then it is treated as only being in 'not_types') - - # room checks - if hasattr(event, "room_id"): - room_id = event.room_id - allow_rooms = definition.get("rooms", None) - reject_rooms = definition.get("not_rooms", None) - if reject_rooms and room_id in reject_rooms: - return False - if allow_rooms and room_id not in allow_rooms: - return False + if isinstance(event, dict): + return self.check_fields( + event.get("room_id", None), + event.get("sender", None), + event.get("type", None), + ) + else: + return self.check_fields( + getattr(event, "room_id", None), + getattr(event, "sender", None), + event.type, + ) - # sender checks - if hasattr(event, "sender"): - # Should we be including event.state_key for some event types? - sender = event.sender - allow_senders = definition.get("senders", None) - reject_senders = definition.get("not_senders", None) - if reject_senders and sender in reject_senders: - return False - if allow_senders and sender not in allow_senders: + def check_fields(self, room_id, sender, event_type): + """Checks whether the filter matches the given event fields. + + Returns: + bool: True if the event fields match + """ + literal_keys = { + "rooms": lambda v: room_id == v, + "senders": lambda v: sender == v, + "types": lambda v: _matches_wildcard(event_type, v) + } + + for name, match_func in literal_keys.items(): + not_name = "not_%s" % (name,) + disallowed_values = self.filter_json.get(not_name, []) + if any(map(match_func, disallowed_values)): return False - # type checks - if "not_types" in definition: - for def_type in definition["not_types"]: - if self._event_matches_type(event, def_type): + allowed_values = self.filter_json.get(name, None) + if allowed_values is not None: + if not any(map(match_func, allowed_values)): return False - if "types" in definition: - included = False - for def_type in definition["types"]: - if self._event_matches_type(event, def_type): - included = True - break - if not included: - return False return True - def _event_matches_type(self, event, def_type): - if def_type.endswith("*"): - type_prefix = def_type[:-1] - return event.type.startswith(type_prefix) - else: - return event.type == def_type + def filter_rooms(self, room_ids): + """Apply the 'rooms' filter to a given list of rooms. + + Args: + room_ids (list): A list of room_ids. + + Returns: + list: A list of room_ids that match the filter + """ + room_ids = set(room_ids) + + disallowed_rooms = set(self.filter_json.get("not_rooms", [])) + room_ids -= disallowed_rooms + + allowed_rooms = self.filter_json.get("rooms", None) + if allowed_rooms is not None: + room_ids &= set(allowed_rooms) + + return room_ids + + def filter(self, events): + return filter(self.check, events) + + def limit(self): + return self.filter_json.get("limit", 10) + + +def _matches_wildcard(actual_value, filter_value): + if filter_value.endswith("*"): + type_prefix = filter_value[:-1] + return actual_value.startswith(type_prefix) + else: + return actual_value == filter_value |