summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-10-20 16:57:51 +0100
committerErik Johnston <erik@matrix.org>2015-10-20 16:57:51 +0100
commit44e2933bf87d278ef0d860ee04e61ecac9744480 (patch)
treec6a15fc8843f99a9f9ce0e47ab2b4254f5d285e4 /synapse/api
parentExplicitly check for Sqlite3Engine (diff)
parentDocstring (diff)
downloadsynapse-44e2933bf87d278ef0d860ee04e61ecac9744480.tar.xz
Merge branch 'erikj/filter_refactor' into erikj/search
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py27
-rw-r--r--synapse/api/filtering.py156
2 files changed, 78 insertions, 105 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5c83aafa7d..494c8ac3d4 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 """This module contains classes for authenticating the user."""
-from nacl.exceptions import BadSignatureError
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json, SignatureVerifyException
 
 from twisted.internet import defer
 
@@ -26,7 +27,6 @@ from synapse.util import third_party_invites
 from unpaddedbase64 import decode_base64
 
 import logging
-import nacl.signing
 import pymacaroons
 
 logger = logging.getLogger(__name__)
@@ -308,7 +308,11 @@ class Auth(object):
         )
 
         if Membership.JOIN != membership:
-            # JOIN is the only action you can perform if you're not in the room
+            if (caller_invited
+                    and Membership.LEAVE == membership
+                    and target_user_id == event.user_id):
+                return True
+
             if not caller_in_room:  # caller isn't joined
                 raise AuthError(
                     403,
@@ -416,16 +420,23 @@ class Auth(object):
                     key_validity_url
                 )
                 return False
-            for _, signature_block in join_third_party_invite["signatures"].items():
+            signed = join_third_party_invite["signed"]
+            if signed["mxid"] != event.user_id:
+                return False
+            if signed["token"] != token:
+                return False
+            for server, signature_block in signed["signatures"].items():
                 for key_name, encoded_signature in signature_block.items():
                     if not key_name.startswith("ed25519:"):
                         return False
-                    verify_key = nacl.signing.VerifyKey(decode_base64(public_key))
-                    signature = decode_base64(encoded_signature)
-                    verify_key.verify(token, signature)
+                    verify_key = decode_verify_key_bytes(
+                        key_name,
+                        decode_base64(public_key)
+                    )
+                    verify_signed_json(signed, server, verify_key)
                     return True
             return False
-        except (KeyError, BadSignatureError,):
+        except (KeyError, SignatureVerifyException,):
             return False
 
     def _get_power_level_event(self, auth_events):
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index e79e91e7eb..60b6648e0d 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -24,7 +24,7 @@ class Filtering(object):
 
     def get_user_filter(self, user_localpart, filter_id):
         result = self.store.get_user_filter(user_localpart, filter_id)
-        result.addCallback(Filter)
+        result.addCallback(FilterCollection)
         return result
 
     def add_user_filter(self, user_localpart, user_filter):
@@ -131,125 +131,87 @@ class Filtering(object):
             raise SynapseError(400, "Bad bundle_updates: expected bool.")
 
 
-class Filter(object):
+class FilterCollection(object):
     def __init__(self, filter_json):
         self.filter_json = filter_json
 
+        self.room_timeline_filter = Filter(
+            self.filter_json.get("room", {}).get("timeline", {})
+        )
+
+        self.room_state_filter = Filter(
+            self.filter_json.get("room", {}).get("state", {})
+        )
+
+        self.room_ephemeral_filter = Filter(
+            self.filter_json.get("room", {}).get("ephemeral", {})
+        )
+
+        self.presence_filter = Filter(
+            self.filter_json.get("presence", {})
+        )
+
     def timeline_limit(self):
-        return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10)
+        return self.room_timeline_filter.limit()
 
     def presence_limit(self):
-        return self.filter_json.get("presence", {}).get("limit", 10)
+        return self.presence_filter.limit()
 
     def ephemeral_limit(self):
-        return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10)
+        return self.room_ephemeral_filter.limit()
 
     def filter_presence(self, events):
-        return self._filter_on_key(events, ["presence"])
+        return self.presence_filter.filter(events)
 
     def filter_room_state(self, events):
-        return self._filter_on_key(events, ["room", "state"])
+        return self.room_state_filter.filter(events)
 
     def filter_room_timeline(self, events):
-        return self._filter_on_key(events, ["room", "timeline"])
+        return self.room_timeline_filter.filter(events)
 
     def filter_room_ephemeral(self, events):
-        return self._filter_on_key(events, ["room", "ephemeral"])
-
-    def _filter_on_key(self, events, keys):
-        filter_json = self.filter_json
-        if not filter_json:
-            return events
-
-        try:
-            # extract the right definition from the filter
-            definition = filter_json
-            for key in keys:
-                definition = definition[key]
-            return self._filter_with_definition(events, definition)
-        except KeyError:
-            # return all events if definition isn't specified.
-            return events
-
-    def _filter_with_definition(self, events, definition):
-        return [e for e in events if self._passes_definition(definition, e)]
-
-    def _passes_definition(self, definition, event):
-        """Check if the event passes the filter definition
-        Args:
-            definition(dict): The filter definition to check against
-            event(dict or Event): The event to check
-        Returns:
-            True if the event passes the filter in the definition
-        """
-        if type(event) is dict:
-            room_id = event.get("room_id")
-            sender = event.get("sender")
-            event_type = event["type"]
-        else:
-            room_id = getattr(event, "room_id", None)
-            sender = getattr(event, "sender", None)
-            event_type = event.type
-        return self._event_passes_definition(
-            definition, room_id, sender, event_type
-        )
+        return self.room_ephemeral_filter.filter(events)
 
-    def _event_passes_definition(self, definition, room_id, sender,
-                                 event_type):
-        """Check if the event passes through the given definition.
 
-        Args:
-            definition(dict): The definition to check against.
-            room_id(str): The id of the room this event is in or None.
-            sender(str): The sender of the event
-            event_type(str): The type of the event.
+class Filter(object):
+    def __init__(self, filter_json):
+        self.filter_json = filter_json
+
+    def check(self, event):
+        """Checks whether the filter matches the given event.
+
         Returns:
-            True if the event passes through the filter.
+            bool: True if the event matches
         """
-        # Algorithm notes:
-        # For each key in the definition, check the event meets the criteria:
-        #   * For types: Literal match or prefix match (if ends with wildcard)
-        #   * For senders/rooms: Literal match only
-        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
-        #     and 'not_types' then it is treated as only being in 'not_types')
-
-        # room checks
-        if room_id is not None:
-            allow_rooms = definition.get("rooms", None)
-            reject_rooms = definition.get("not_rooms", None)
-            if reject_rooms and room_id in reject_rooms:
-                return False
-            if allow_rooms and room_id not in allow_rooms:
+        literal_keys = {
+            "rooms": lambda v: event.room_id == v,
+            "senders": lambda v: event.sender == v,
+            "types": lambda v: _matches_wildcard(event.type, v)
+        }
+
+        for name, match_func in literal_keys.items():
+            not_name = "not_%s" % (name,)
+            disallowed_values = self.filter_json.get(not_name, [])
+            if any(map(match_func, disallowed_values)):
                 return False
 
-        # sender checks
-        if sender is not None:
-            allow_senders = definition.get("senders", None)
-            reject_senders = definition.get("not_senders", None)
-            if reject_senders and sender in reject_senders:
-                return False
-            if allow_senders and sender not in allow_senders:
-                return False
-
-        # type checks
-        if "not_types" in definition:
-            for def_type in definition["not_types"]:
-                if self._event_matches_type(event_type, def_type):
+            allowed_values = self.filter_json.get(name, None)
+            if allowed_values is not None:
+                if not any(map(match_func, allowed_values)):
                     return False
-        if "types" in definition:
-            included = False
-            for def_type in definition["types"]:
-                if self._event_matches_type(event_type, def_type):
-                    included = True
-                    break
-            if not included:
-                return False
 
         return True
 
-    def _event_matches_type(self, event_type, def_type):
-        if def_type.endswith("*"):
-            type_prefix = def_type[:-1]
-            return event_type.startswith(type_prefix)
-        else:
-            return event_type == def_type
+    def filter(self, events):
+        return filter(self.check, events)
+
+    def limit(self):
+        return self.filter_json.get("limit", 10)
+
+
+def _matches_wildcard(actual_value, filter_value):
+    if filter_value.endswith("*"):
+        type_prefix = filter_value[:-1]
+        return actual_value.startswith(type_prefix)
+    else:
+        return actual_value == filter_value