summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-07-16 18:18:36 +0100
committerErik Johnston <erik@matrix.org>2015-07-16 18:18:36 +0100
commitb6d4a4c6d883a4c3084938e11fc5bb654b5779e4 (patch)
treedbcdb1eb3459c5a53c7bdb8117d2aaee26332366 /synapse/storage
parentMerge pull request #203 from matrix-org/erikj/room_creation_presets (diff)
parentDocs (diff)
downloadsynapse-b6d4a4c6d883a4c3084938e11fc5bb654b5779e4.tar.xz
Merge pull request #199 from matrix-org/erikj/receipts
Implement read receipts.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py3
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/receipts.py348
-rw-r--r--synapse/storage/schema/delta/21/receipts.sql38
-rw-r--r--synapse/storage/util/id_generators.py7
5 files changed, 396 insertions, 3 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e089d81675..71d5d92500 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -39,6 +39,8 @@ from .signatures import SignatureStore
 from .filtering import FilteringStore
 from .end_to_end_keys import EndToEndKeyStore
 
+from .receipts import ReceiptsStore
+
 
 import fnmatch
 import imp
@@ -75,6 +77,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 PushRuleStore,
                 ApplicationServiceTransactionStore,
                 EventsStore,
+                ReceiptsStore,
                 EndToEndKeyStore,
                 ):
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8d33def6c6..8f812f0fd7 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -329,13 +329,14 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator()
+        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
new file mode 100644
index 0000000000..d515a0a15c
--- /dev/null
+++ b/synapse/storage/receipts.py
@@ -0,0 +1,348 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore, cached
+
+from twisted.internet import defer
+
+from synapse.util import unwrapFirstError
+
+from blist import sorteddict
+import logging
+import ujson as json
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(ReceiptsStore, self).__init__(hs)
+
+        self._receipts_stream_cache = _RoomStreamChangeCache()
+
+    @defer.inlineCallbacks
+    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+        """Get receipts for multiple rooms for sending to clients.
+
+        Args:
+            room_ids (list): List of room_ids.
+            to_key (int): Max stream id to fetch receipts upto.
+            from_key (int): Min stream id to fetch receipts from. None fetches
+                from the start.
+
+        Returns:
+            list: A list of receipts.
+        """
+        room_ids = set(room_ids)
+
+        if from_key:
+            room_ids = yield self._receipts_stream_cache.get_rooms_changed(
+                self, room_ids, from_key
+            )
+
+        results = yield defer.gatherResults(
+            [
+                self.get_linearized_receipts_for_room(
+                    room_id, to_key, from_key=from_key
+                )
+                for room_id in room_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        defer.returnValue([ev for res in results for ev in res])
+
+    @defer.inlineCallbacks
+    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+        """Get receipts for a single room for sending to clients.
+
+        Args:
+            room_ids (str): The room id.
+            to_key (int): Max stream id to fetch receipts upto.
+            from_key (int): Min stream id to fetch receipts from. None fetches
+                from the start.
+
+        Returns:
+            list: A list of receipts.
+        """
+        def f(txn):
+            if from_key:
+                sql = (
+                    "SELECT * FROM receipts_linearized WHERE"
+                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
+                )
+
+                txn.execute(
+                    sql,
+                    (room_id, from_key, to_key)
+                )
+            else:
+                sql = (
+                    "SELECT * FROM receipts_linearized WHERE"
+                    " room_id = ? AND stream_id <= ?"
+                )
+
+                txn.execute(
+                    sql,
+                    (room_id, to_key)
+                )
+
+            rows = self.cursor_to_dict(txn)
+
+            return rows
+
+        rows = yield self.runInteraction(
+            "get_linearized_receipts_for_room", f
+        )
+
+        if not rows:
+            defer.returnValue([])
+
+        content = {}
+        for row in rows:
+            content.setdefault(
+                row["event_id"], {}
+            ).setdefault(
+                row["receipt_type"], {}
+            )[row["user_id"]] = json.loads(row["data"])
+
+        defer.returnValue([{
+            "type": "m.receipt",
+            "room_id": room_id,
+            "content": content,
+        }])
+
+    def get_max_receipt_stream_id(self):
+        return self._receipts_id_gen.get_max_token(self)
+
+    @cached
+    @defer.inlineCallbacks
+    def get_graph_receipts_for_room(self, room_id):
+        """Get receipts for sending to remote servers.
+        """
+        rows = yield self._simple_select_list(
+            table="receipts_graph",
+            keyvalues={"room_id": room_id},
+            retcols=["receipt_type", "user_id", "event_id"],
+            desc="get_linearized_receipts_for_room",
+        )
+
+        result = {}
+        for row in rows:
+            result.setdefault(
+                row["user_id"], {}
+            ).setdefault(
+                row["receipt_type"], []
+            ).append(row["event_id"])
+
+        defer.returnValue(result)
+
+    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+                                      user_id, event_id, data, stream_id):
+
+        # We don't want to clobber receipts for more recent events, so we
+        # have to compare orderings of existing receipts
+        sql = (
+            "SELECT topological_ordering, stream_ordering, event_id FROM events"
+            " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+            " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+        )
+
+        txn.execute(sql, (room_id, receipt_type, user_id))
+        results = txn.fetchall()
+
+        if results:
+            res = self._simple_select_one_txn(
+                txn,
+                table="events",
+                retcols=["topological_ordering", "stream_ordering"],
+                keyvalues={"event_id": event_id},
+            )
+            topological_ordering = int(res["topological_ordering"])
+            stream_ordering = int(res["stream_ordering"])
+
+            for to, so, _ in results:
+                if int(to) > topological_ordering:
+                    return False
+                elif int(to) == topological_ordering and int(so) >= stream_ordering:
+                    return False
+
+        self._simple_delete_txn(
+            txn,
+            table="receipts_linearized",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+            }
+        )
+
+        self._simple_insert_txn(
+            txn,
+            table="receipts_linearized",
+            values={
+                "stream_id": stream_id,
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+                "event_id": event_id,
+                "data": json.dumps(data),
+            }
+        )
+
+        return True
+
+    @defer.inlineCallbacks
+    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+        """Insert a receipt, either from local client or remote server.
+
+        Automatically does conversion between linearized and graph
+        representations.
+        """
+        if not event_ids:
+            return
+
+        if len(event_ids) == 1:
+            linearized_event_id = event_ids[0]
+        else:
+            # we need to points in graph -> linearized form.
+            # TODO: Make this better.
+            def graph_to_linear(txn):
+                query = (
+                    "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
+                    " SELECT max(stream_ordering) WHERE event_id IN (%s)"
+                    ")"
+                ) % (",".join(["?"] * len(event_ids)))
+
+                txn.execute(query, [room_id] + event_ids)
+                rows = txn.fetchall()
+                if rows:
+                    return rows[0][0]
+                else:
+                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
+            linearized_event_id = yield self.runInteraction(
+                "insert_receipt_conv", graph_to_linear
+            )
+
+        stream_id_manager = yield self._receipts_id_gen.get_next(self)
+        with stream_id_manager as stream_id:
+            yield self._receipts_stream_cache.room_has_changed(
+                self, room_id, stream_id
+            )
+            have_persisted = yield self.runInteraction(
+                "insert_linearized_receipt",
+                self.insert_linearized_receipt_txn,
+                room_id, receipt_type, user_id, linearized_event_id,
+                data,
+                stream_id=stream_id,
+            )
+
+            if not have_persisted:
+                defer.returnValue(None)
+
+        yield self.insert_graph_receipt(
+            room_id, receipt_type, user_id, event_ids, data
+        )
+
+        max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        defer.returnValue((stream_id, max_persisted_id))
+
+    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
+                             data):
+        return self.runInteraction(
+            "insert_graph_receipt",
+            self.insert_graph_receipt_txn,
+            room_id, receipt_type, user_id, event_ids, data
+        )
+
+    def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
+                                 user_id, event_ids, data):
+        self._simple_delete_txn(
+            txn,
+            table="receipts_graph",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+            }
+        )
+        self._simple_insert_txn(
+            txn,
+            table="receipts_graph",
+            values={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+                "event_ids": json.dumps(event_ids),
+                "data": json.dumps(data),
+            }
+        )
+
+
+class _RoomStreamChangeCache(object):
+    """Keeps track of the stream_id of the latest change in rooms.
+
+    Given a list of rooms and stream key, it will give a subset of rooms that
+    may have changed since that key. If the key is too old then the cache
+    will simply return all rooms.
+    """
+    def __init__(self, size_of_cache=1000):
+        self._size_of_cache = size_of_cache
+        self._room_to_key = {}
+        self._cache = sorteddict()
+        self._earliest_key = None
+
+    @defer.inlineCallbacks
+    def get_rooms_changed(self, store, room_ids, key):
+        """Returns subset of room ids that have had new receipts since the
+        given key. If the key is too old it will just return the given list.
+        """
+        if key > (yield self._get_earliest_key(store)):
+            keys = self._cache.keys()
+            i = keys.bisect_right(key)
+
+            result = set(
+                self._cache[k] for k in keys[i:]
+            ).intersection(room_ids)
+        else:
+            result = room_ids
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def room_has_changed(self, store, room_id, key):
+        """Informs the cache that the room has been changed at the given key.
+        """
+        if key > (yield self._get_earliest_key(store)):
+            old_key = self._room_to_key.get(room_id, None)
+            if old_key:
+                key = max(key, old_key)
+                self._cache.pop(old_key, None)
+            self._cache[key] = room_id
+
+            while len(self._cache) > self._size_of_cache:
+                k, r = self._cache.popitem()
+                self._earliest_key = max(k, self._earliest_key)
+                self._room_to_key.pop(r, None)
+
+    @defer.inlineCallbacks
+    def _get_earliest_key(self, store):
+        if self._earliest_key is None:
+            self._earliest_key = yield store.get_max_receipt_stream_id()
+            self._earliest_key = int(self._earliest_key)
+
+        defer.returnValue(self._earliest_key)
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql
new file mode 100644
index 0000000000..2f64d609fc
--- /dev/null
+++ b/synapse/storage/schema/delta/21/receipts.sql
@@ -0,0 +1,38 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS receipts_graph(
+    room_id TEXT NOT NULL,
+    receipt_type TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    event_ids TEXT NOT NULL,
+    data TEXT NOT NULL,
+    CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
+CREATE TABLE IF NOT EXISTS receipts_linearized (
+    stream_id BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    receipt_type TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    event_id TEXT NOT NULL,
+    data TEXT NOT NULL,
+    CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
+CREATE INDEX receipts_linearized_id ON receipts_linearized(
+  stream_id
+);
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 83eab63098..e956df62c7 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,7 +72,10 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next_txn(txn) as stream_id:
             # ... persist event ...
     """
-    def __init__(self):
+    def __init__(self, table, column):
+        self.table = table
+        self.column = column
+
         self._lock = threading.Lock()
 
         self._current_max = None
@@ -157,7 +160,7 @@ class StreamIdGenerator(object):
 
     def _get_or_compute_current_max(self, txn):
         with self._lock:
-            txn.execute("SELECT MAX(stream_ordering) FROM events")
+            txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
             rows = txn.fetchall()
             val, = rows[0]