summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/filtering.py246
-rw-r--r--synapse/rest/client/v2_alpha/__init__.py7
-rw-r--r--synapse/rest/client/v2_alpha/filter.py104
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/__init__.py4
-rw-r--r--synapse/storage/filtering.py67
-rw-r--r--synapse/storage/schema/delta/v13.sql24
-rw-r--r--synapse/storage/schema/filtering.sql24
8 files changed, 479 insertions, 2 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
new file mode 100644
index 0000000000..7e239138b7
--- /dev/null
+++ b/synapse/api/filtering.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, RoomID
+
+
+class Filtering(object):
+
+    def __init__(self, hs):
+        super(Filtering, self).__init__()
+        self.store = hs.get_datastore()
+
+    def get_user_filter(self, user_localpart, filter_id):
+        return self.store.get_user_filter(user_localpart, filter_id)
+
+    def add_user_filter(self, user_localpart, user_filter):
+        self._check_valid_filter(user_filter)
+        return self.store.add_user_filter(user_localpart, user_filter)
+
+    def filter_public_user_data(self, events, user, filter_id):
+        return self._filter_on_key(
+            events, user, filter_id, ["public_user_data"]
+        )
+
+    def filter_private_user_data(self, events, user, filter_id):
+        return self._filter_on_key(
+            events, user, filter_id, ["private_user_data"]
+        )
+
+    def filter_room_state(self, events, user, filter_id):
+        return self._filter_on_key(
+            events, user, filter_id, ["room", "state"]
+        )
+
+    def filter_room_events(self, events, user, filter_id):
+        return self._filter_on_key(
+            events, user, filter_id, ["room", "events"]
+        )
+
+    def filter_room_ephemeral(self, events, user, filter_id):
+        return self._filter_on_key(
+            events, user, filter_id, ["room", "ephemeral"]
+        )
+
+    # TODO(paul): surely we should probably add a delete_user_filter or
+    #   replace_user_filter at some point? There's no REST API specified for
+    #   them however
+
+    @defer.inlineCallbacks
+    def _filter_on_key(self, events, user, filter_id, keys):
+        filter_json = yield self.get_user_filter(user.localpart, filter_id)
+        if not filter_json:
+            defer.returnValue(events)
+
+        try:
+            # extract the right definition from the filter
+            definition = filter_json
+            for key in keys:
+                definition = definition[key]
+            defer.returnValue(self._filter_with_definition(events, definition))
+        except KeyError:
+            # return all events if definition isn't specified.
+            defer.returnValue(events)  
+
+    def _filter_with_definition(self, events, definition):
+        return [e for e in events if self._passes_definition(definition, e)]
+
+    def _passes_definition(self, definition, event):
+        """Check if the event passes through the given definition.
+
+        Args:
+            definition(dict): The definition to check against.
+            event(Event): The event to check.
+        Returns:
+            True if the event passes through the filter.
+        """
+        # Algorithm notes:
+        # For each key in the definition, check the event meets the criteria:
+        #   * For types: Literal match or prefix match (if ends with wildcard)
+        #   * For senders/rooms: Literal match only
+        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
+        #     and 'not_types' then it is treated as only being in 'not_types')
+        
+        # room checks
+        if hasattr(event, "room_id"):
+            room_id = event.room_id
+            allow_rooms = definition["rooms"] if "rooms" in definition else None
+            reject_rooms = (
+                definition["not_rooms"] if "not_rooms" in definition else None
+            )
+            if reject_rooms and room_id in reject_rooms:
+                return False
+            if allow_rooms and room_id not in allow_rooms:
+                return False
+
+        # sender checks
+        if hasattr(event, "sender"):
+            # Should we be including event.state_key for some event types?
+            sender = event.sender
+            allow_senders = (
+                definition["senders"] if "senders" in definition else None
+            )
+            reject_senders = (
+                definition["not_senders"] if "not_senders" in definition else None
+            )
+            if reject_senders and sender in reject_senders:
+                return False
+            if allow_senders and sender not in allow_senders:
+                return False
+
+        # type checks
+        if "not_types" in definition:
+            for def_type in definition["not_types"]:
+                if self._event_matches_type(event, def_type):
+                    return False
+        if "types" in definition:
+            included = False
+            for def_type in definition["types"]:
+                if self._event_matches_type(event, def_type):
+                    included = True
+                    break
+            if not included:
+                return False
+
+        return True
+
+    def _event_matches_type(self, event, def_type):
+        if def_type.endswith("*"):
+            type_prefix = def_type[:-1]
+            return event.type.startswith(type_prefix)
+        else:
+            return event.type == def_type
+
+    def _check_valid_filter(self, user_filter):
+        """Check if the provided filter is valid.
+
+        This inspects all definitions contained within the filter.
+
+        Args:
+            user_filter(dict): The filter
+        Raises:
+            SynapseError: If the filter is not valid.
+        """
+        # NB: Filters are the complete json blobs. "Definitions" are an
+        # individual top-level key e.g. public_user_data. Filters are made of
+        # many definitions.
+
+        top_level_definitions = [
+            "public_user_data", "private_user_data", "server_data"
+        ]
+
+        room_level_definitions = [
+            "state", "events", "ephemeral"
+        ]
+
+        for key in top_level_definitions:
+            if key in user_filter:
+                self._check_definition(user_filter[key])
+
+        if "room" in user_filter:
+            for key in room_level_definitions:
+                if key in user_filter["room"]:
+                    self._check_definition(user_filter["room"][key])
+
+
+    def _check_definition(self, definition):
+        """Check if the provided definition is valid.
+
+        This inspects not only the types but also the values to make sure they
+        make sense.
+
+        Args:
+            definition(dict): The filter definition
+        Raises:
+            SynapseError: If there was a problem with this definition.
+        """
+        # NB: Filters are the complete json blobs. "Definitions" are an
+        # individual top-level key e.g. public_user_data. Filters are made of
+        # many definitions.
+        if type(definition) != dict:
+            raise SynapseError(
+                400, "Expected JSON object, not %s" % (definition,)
+            )
+
+        # check rooms are valid room IDs
+        room_id_keys = ["rooms", "not_rooms"]
+        for key in room_id_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for room_id in definition[key]:
+                    RoomID.from_string(room_id)
+
+        # check senders are valid user IDs
+        user_id_keys = ["senders", "not_senders"]
+        for key in user_id_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for user_id in definition[key]:
+                    UserID.from_string(user_id)
+
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        event_keys = ["types", "not_types"]
+        for key in event_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for event_type in definition[key]:
+                    if not isinstance(event_type, basestring):
+                        raise SynapseError(400, "Event type should be a string")
+
+        try:
+            event_format = definition["format"]
+            if event_format not in ["federation", "events"]:
+                raise SynapseError(400, "Invalid format: %s" % (event_format,))
+        except KeyError:
+            pass  # format is optional
+
+        try:
+            event_select_list = definition["select"]
+            for select_key in event_select_list:
+                if select_key not in ["event_id", "origin_server_ts",
+                                      "thread_id", "content", "content.body"]:
+                    raise SynapseError(400, "Bad select: %s" % (select_key,))
+        except KeyError:
+            pass  # select is optional
+
+        if ("bundle_updates" in definition and
+                type(definition["bundle_updates"]) != bool):
+            raise SynapseError(400, "Bad bundle_updates: expected bool.")
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index bb740e2803..4349579ab7 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -14,6 +14,11 @@
 # limitations under the License.
 
 
+from . import (
+    filter
+)
+
+
 from synapse.http.server import JsonResource
 
 
@@ -26,4 +31,4 @@ class ClientV2AlphaRestResource(JsonResource):
 
     @staticmethod
     def register_servlets(client_resource, hs):
-        pass
+        filter.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
new file mode 100644
index 0000000000..81a3e95155
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_v2_pattern
+
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class GetFilterRestServlet(RestServlet):
+    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
+
+    def __init__(self, hs):
+        super(GetFilterRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, filter_id):
+        target_user = UserID.from_string(user_id)
+        auth_user = yield self.auth.get_user_by_req(request)
+
+        if target_user != auth_user:
+            raise AuthError(403, "Cannot get filters for other users")
+
+        if not self.hs.is_mine(target_user):
+            raise SynapseError(400, "Can only get filters for local users")
+
+        try:
+            filter_id = int(filter_id)
+        except:
+            raise SynapseError(400, "Invalid filter_id")
+
+        try:
+            filter = yield self.filtering.get_user_filter(
+                user_localpart=target_user.localpart,
+                filter_id=filter_id,
+            )
+
+            defer.returnValue((200, filter))
+        except KeyError:
+            raise SynapseError(400, "No such filter")
+
+
+class CreateFilterRestServlet(RestServlet):
+    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
+
+    def __init__(self, hs):
+        super(CreateFilterRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id):
+        target_user = UserID.from_string(user_id)
+        auth_user = yield self.auth.get_user_by_req(request)
+
+        if target_user != auth_user:
+            raise AuthError(403, "Cannot create filters for other users")
+
+        if not self.hs.is_mine(target_user):
+            raise SynapseError(400, "Can only create filters for local users")
+
+        try:
+            content = json.loads(request.content.read())
+
+            # TODO(paul): check for required keys and invalid keys
+        except:
+            raise SynapseError(400, "Invalid filter definition")
+
+        filter_id = yield self.filtering.add_user_filter(
+            user_localpart=target_user.localpart,
+            user_filter=content,
+        )
+
+        defer.returnValue((200, {"filter_id": str(filter_id)}))
+
+
+def register_servlets(hs, http_server):
+    GetFilterRestServlet(hs).register(http_server)
+    CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
index f152f0321e..23bdad0c7c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -33,6 +33,7 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.keyring import Keyring
 from synapse.push.pusherpool import PusherPool
 from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering
 
 
 class BaseHomeServer(object):
@@ -81,6 +82,7 @@ class BaseHomeServer(object):
         'keyring',
         'pusherpool',
         'event_builder_factory',
+        'filtering',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -200,5 +202,8 @@ class HomeServer(BaseHomeServer):
             hostname=self.hostname,
         )
 
+    def build_filtering(self):
+        return Filtering(self)
+
     def build_pusherpool(self):
         return PusherPool(self)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 277581b4e2..89a1e60c2b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -32,9 +32,9 @@ from .event_federation import EventFederationStore
 from .pusher import PusherStore
 from .push_rule import PushRuleStore
 from .media_repository import MediaRepositoryStore
-
 from .state import StateStore
 from .signatures import SignatureStore
+from .filtering import FilteringStore
 
 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json
@@ -64,6 +64,7 @@ SCHEMAS = [
     "event_signatures",
     "pusher",
     "media_repository",
+    "filtering",
 ]
 
 
@@ -85,6 +86,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 DirectoryStore, KeyStore, StateStore, SignatureStore,
                 EventFederationStore,
                 MediaRepositoryStore,
+                FilteringStore,
                 PusherStore,
                 PushRuleStore
                 ):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
new file mode 100644
index 0000000000..cb01c2040f
--- /dev/null
+++ b/synapse/storage/filtering.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+import json
+
+
+# TODO(paul)
+_filters_for_user = {}
+
+
+class FilteringStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def get_user_filter(self, user_localpart, filter_id):
+        def_json = yield self._simple_select_one_onecol(
+            table="user_filters",
+            keyvalues={
+                "user_id": user_localpart,
+                "filter_id": filter_id,
+            },
+            retcol="filter_json",
+            allow_none=False,
+        )
+
+        defer.returnValue(json.loads(def_json))
+
+    def add_user_filter(self, user_localpart, user_filter):
+        def_json = json.dumps(user_filter)
+
+        # Need an atomic transaction to SELECT the maximal ID so far then
+        # INSERT a new one
+        def _do_txn(txn):
+            sql = (
+                "SELECT MAX(filter_id) FROM user_filters "
+                "WHERE user_id = ?"
+            )
+            txn.execute(sql, (user_localpart,))
+            max_id = txn.fetchone()[0]
+            if max_id is None:
+                filter_id = 0
+            else:
+                filter_id = max_id + 1
+
+            sql = (
+                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
+                "VALUES(?, ?, ?)"
+            )
+            txn.execute(sql, (user_localpart, filter_id, def_json))
+
+            return filter_id
+
+        return self.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/v13.sql
new file mode 100644
index 0000000000..beb39ca201
--- /dev/null
+++ b/synapse/storage/schema/delta/v13.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS user_filters(
+  user_id TEXT,
+  filter_id INTEGER,
+  filter_json TEXT,
+  FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
+  user_id, filter_id
+);
diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql
new file mode 100644
index 0000000000..beb39ca201
--- /dev/null
+++ b/synapse/storage/schema/filtering.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS user_filters(
+  user_id TEXT,
+  filter_id INTEGER,
+  filter_json TEXT,
+  FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
+  user_id, filter_id
+);