summary refs log tree commit diff
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2015-01-27 18:46:03 +0000
committerPaul "LeoNerd" Evans <paul@matrix.org>2015-01-27 18:46:03 +0000
commit06cc1470129d443f71bfc81ba716f63b9505467d (patch)
tree48a551ad093909d2544ed1459900ab74c6148d75
parentMore unit-testing of REST errors (diff)
downloadsynapse-06cc1470129d443f71bfc81ba716f63b9505467d.tar.xz
Initial stab at real SQL storage implementation of user filter definitions
-rw-r--r--synapse/storage/__init__.py1
-rw-r--r--synapse/storage/filtering.py49
-rw-r--r--synapse/storage/schema/filtering.sql24
-rw-r--r--tests/api/test_filtering.py19
4 files changed, 78 insertions, 15 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index efa63031bd..7c5631d014 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -61,6 +61,7 @@ SCHEMAS = [
     "event_edges",
     "event_signatures",
     "media_repository",
+    "filtering",
 ]
 
 
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 18e0e7c298..e98eaf8032 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -17,6 +17,8 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 
+import json
+
 
 # TODO(paul)
 _filters_for_user = {}
@@ -25,22 +27,41 @@ _filters_for_user = {}
 class FilteringStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_user_filter(self, user_localpart, filter_id):
-        filters = _filters_for_user.get(user_localpart, None)
-
-        if not filters or filter_id >= len(filters):
-            raise KeyError()
+        def_json = yield self._simple_select_one_onecol(
+            table="user_filters",
+            keyvalues={
+                "user_id": user_localpart,
+                "filter_id": filter_id,
+            },
+            retcol="definition",
+            allow_none=False,
+        )
 
-        # trivial yield to make it a generator so d.iC works
-        yield
-        defer.returnValue(filters[filter_id])
+        defer.returnValue(json.loads(def_json))
 
-    @defer.inlineCallbacks
     def add_user_filter(self, user_localpart, definition):
-        filters = _filters_for_user.setdefault(user_localpart, [])
+        def_json = json.dumps(definition)
+
+        # Need an atomic transaction to SELECT the maximal ID so far then
+        # INSERT a new one
+        def _do_txn(txn):
+            sql = (
+                "SELECT MAX(filter_id) FROM user_filters "
+                "WHERE user_id = ?"
+            )
+            txn.execute(sql, (user_localpart,))
+            max_id = txn.fetchone()[0]
+            if max_id is None:
+                filter_id = 0
+            else:
+                filter_id = max_id + 1
+
+            sql = (
+                "INSERT INTO user_filters (user_id, filter_id, definition)"
+                "VALUES(?, ?, ?)"
+            )
+            txn.execute(sql, (user_localpart, filter_id, def_json))
 
-        filter_id = len(filters)
-        filters.append(definition)
+            return filter_id
 
-        # trivial yield, see above
-        yield
-        defer.returnValue(filter_id)
+        return self.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql
new file mode 100644
index 0000000000..795aca4afd
--- /dev/null
+++ b/synapse/storage/schema/filtering.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS user_filters(
+  user_id TEXT,
+  filter_id INTEGER,
+  definition TEXT,
+  FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
+  user_id, filter_id
+);
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index fecadd1056..149948374d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -53,16 +53,33 @@ class FilteringTestCase(unittest.TestCase):
 
         self.filtering = hs.get_filtering()
 
+        self.datastore = hs.get_datastore()
+
     @defer.inlineCallbacks
-    def test_filter(self):
+    def test_add_filter(self):
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=user_localpart,
             definition={"type": ["m.*"]},
         )
+
         self.assertEquals(filter_id, 0)
+        self.assertEquals({"type": ["m.*"]},
+            (yield self.datastore.get_user_filter(
+                user_localpart=user_localpart,
+                filter_id=0,
+            ))
+        )
+
+    @defer.inlineCallbacks
+    def test_get_filter(self):
+        filter_id = yield self.datastore.add_user_filter(
+            user_localpart=user_localpart,
+            definition={"type": ["m.*"]},
+        )
 
         filter = yield self.filtering.get_user_filter(
             user_localpart=user_localpart,
             filter_id=filter_id,
         )
+
         self.assertEquals(filter, {"type": ["m.*"]})