summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/schema/delta/25/tags.sql38
-rw-r--r--synapse/storage/tags.py216
-rw-r--r--synapse/storage/transactions.py48
4 files changed, 272 insertions, 32 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a1bd9c4ce9..e7443f2838 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore
 
 from .receipts import ReceiptsStore
 from .search import SearchStore
+from .tags import TagsStore
 
 
 import logging
@@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 ReceiptsStore,
                 EndToEndKeyStore,
                 SearchStore,
+                TagsStore,
                 ):
 
     def __init__(self, hs):
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql
new file mode 100644
index 0000000000..527424c998
--- /dev/null
+++ b/synapse/storage/schema/delta/25/tags.sql
@@ -0,0 +1,38 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS room_tags(
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    tag     TEXT NOT NULL,  -- The name of the tag.
+    content TEXT NOT NULL,  -- The JSON content of the tag.
+    CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag)
+);
+
+CREATE TABLE IF NOT EXISTS room_tags_revisions (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    stream_id BIGINT NOT NULL, -- The current version of the room tags.
+    CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
+);
+
+CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id(
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_id  BIGINT NOT NULL,
+    CHECK (Lock='X')
+);
+
+INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
new file mode 100644
index 0000000000..641ea250f0
--- /dev/null
+++ b/synapse/storage/tags.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
+from twisted.internet import defer
+from .util.id_generators import StreamIdGenerator
+
+import ujson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TagsStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(TagsStore, self).__init__(hs)
+
+        self._private_user_data_id_gen = StreamIdGenerator(
+            "private_user_data_max_stream_id", "stream_id"
+        )
+
+    def get_max_private_user_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._private_user_data_id_gen.get_max_token(self)
+
+    @cached()
+    def get_tags_for_user(self, user_id):
+        """Get all the tags for a user.
+
+
+        Args:
+            user_id(str): The user to get the tags for.
+        Returns:
+            A deferred dict mapping from room_id strings to lists of tag
+            strings.
+        """
+
+        deferred = self._simple_select_list(
+            "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+        )
+
+        @deferred.addCallback
+        def tags_by_room(rows):
+            tags_by_room = {}
+            for row in rows:
+                room_tags = tags_by_room.setdefault(row["room_id"], {})
+                room_tags[row["tag"]] = json.loads(row["content"])
+            return tags_by_room
+
+        return deferred
+
+    @defer.inlineCallbacks
+    def get_updated_tags(self, user_id, stream_id):
+        """Get all the tags for the rooms where the tags have changed since the
+        given version
+
+        Args:
+            user_id(str): The user to get the tags for.
+            stream_id(int): The earliest update to get for the user.
+        Returns:
+            A deferred dict mapping from room_id strings to lists of tag
+            strings for all the rooms that changed since the stream_id token.
+        """
+        def get_updated_tags_txn(txn):
+            sql = (
+                "SELECT room_id from room_tags_revisions"
+                " WHERE user_id = ? AND stream_id > ?"
+            )
+            txn.execute(sql, (user_id, stream_id))
+            room_ids = [row[0] for row in txn.fetchall()]
+            return room_ids
+
+        room_ids = yield self.runInteraction(
+            "get_updated_tags", get_updated_tags_txn
+        )
+
+        results = {}
+        if room_ids:
+            tags_by_room = yield self.get_tags_for_user(user_id)
+            for room_id in room_ids:
+                results[room_id] = tags_by_room[room_id]
+
+        defer.returnValue(results)
+
+    def get_tags_for_room(self, user_id, room_id):
+        """Get all the tags for the given room
+        Args:
+            user_id(str): The user to get tags for
+            room_id(str): The room to get tags for
+        Returns:
+            A deferred list of string tags.
+        """
+        return self._simple_select_list(
+            table="room_tags",
+            keyvalues={"user_id": user_id, "room_id": room_id},
+            retcols=("tag", "content"),
+            desc="get_tags_for_room",
+        ).addCallback(lambda rows: {
+            row["tag"]: json.loads(row["content"]) for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def add_tag_to_room(self, user_id, room_id, tag, content):
+        """Add a tag to a room for a user.
+        Args:
+            user_id(str): The user to add a tag for.
+            room_id(str): The room to add a tag for.
+            tag(str): The tag name to add.
+            content(dict): A json object to associate with the tag.
+        Returns:
+            A deferred that completes once the tag has been added.
+        """
+        content_json = json.dumps(content)
+
+        def add_tag_txn(txn, next_id):
+            self._simple_upsert_txn(
+                txn,
+                table="room_tags",
+                keyvalues={
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "tag": tag,
+                },
+                values={
+                    "content": content_json,
+                }
+            )
+            self._update_revision_txn(txn, user_id, room_id, next_id)
+
+        with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+            yield self.runInteraction("add_tag", add_tag_txn, next_id)
+
+        self.get_tags_for_user.invalidate((user_id,))
+
+        result = yield self._private_user_data_id_gen.get_max_token(self)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def remove_tag_from_room(self, user_id, room_id, tag):
+        """Remove a tag from a room for a user.
+        Returns:
+            A deferred that completes once the tag has been removed
+        """
+        def remove_tag_txn(txn, next_id):
+            sql = (
+                "DELETE FROM room_tags "
+                " WHERE user_id = ? AND room_id = ? AND tag = ?"
+            )
+            txn.execute(sql, (user_id, room_id, tag))
+            self._update_revision_txn(txn, user_id, room_id, next_id)
+
+        with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+
+        self.get_tags_for_user.invalidate((user_id,))
+
+        result = yield self._private_user_data_id_gen.get_max_token(self)
+        defer.returnValue(result)
+
+    def _update_revision_txn(self, txn, user_id, room_id, next_id):
+        """Update the latest revision of the tags for the given user and room.
+
+        Args:
+            txn: The database cursor
+            user_id(str): The ID of the user.
+            room_id(str): The ID of the room.
+            next_id(int): The the revision to advance to.
+        """
+
+        update_max_id_sql = (
+            "UPDATE private_user_data_max_stream_id"
+            " SET stream_id = ?"
+            " WHERE stream_id < ?"
+        )
+        txn.execute(update_max_id_sql, (next_id, next_id))
+
+        update_sql = (
+            "UPDATE room_tags_revisions"
+            " SET stream_id = ?"
+            " WHERE user_id = ?"
+            " AND room_id = ?"
+        )
+        txn.execute(update_sql, (next_id, user_id, room_id))
+
+        if txn.rowcount == 0:
+            insert_sql = (
+                "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
+                " VALUES (?, ?, ?)"
+            )
+            try:
+                txn.execute(insert_sql, (user_id, room_id, next_id))
+            except self.database_engine.module.IntegrityError:
+                # Ignore insertion errors. It doesn't matter if the row wasn't
+                # inserted because if two updates happend concurrently the one
+                # with the higher stream_id will not be reported to a client
+                # unless the previous update has completed. It doesn't matter
+                # which stream_id ends up in the table, as long as it is higher
+                # than the id that the client has.
+                pass
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 15695e9831..4e0d7c9774 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore):
             retry_interval (int) - how long until next retry in ms
         """
 
-        # As this is the new value, we might as well prefill the cache
-        self.get_destination_retry_timings.prefill(
-            destination,
-            {
-                "destination": destination,
-                "retry_last_ts": retry_last_ts,
-                "retry_interval": retry_interval
-            },
-        )
-
         # XXX: we could chose to not bother persisting this if our cache thinks
         # this is a NOOP
         return self.runInteraction(
@@ -275,31 +265,25 @@ class TransactionStore(SQLBaseStore):
 
     def _set_destination_retry_timings(self, txn, destination,
                                        retry_last_ts, retry_interval):
-        query = (
-            "UPDATE destinations"
-            " SET retry_last_ts = ?, retry_interval = ?"
-            " WHERE destination = ?"
-        )
+        txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
 
-        txn.execute(
-            query,
-            (
-                retry_last_ts, retry_interval, destination,
-            )
+        self._simple_upsert_txn(
+            txn,
+            "destinations",
+            keyvalues={
+                "destination": destination,
+            },
+            values={
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval,
+            },
+            insertion_values={
+                "destination": destination,
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval,
+            }
         )
 
-        if txn.rowcount == 0:
-            # destination wasn't already in table. Insert it.
-            self._simple_insert_txn(
-                txn,
-                table="destinations",
-                values={
-                    "destination": destination,
-                    "retry_last_ts": retry_last_ts,
-                    "retry_interval": retry_interval,
-                }
-            )
-
     def get_destinations_needing_retry(self):
         """Get all destinations which are due a retry for sending a transaction.