summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-07-01 11:41:55 +0100
committerErik Johnston <erik@matrix.org>2015-07-01 17:19:12 +0100
commit80a61330ee794147b213b1d54f2292a1c9adc002 (patch)
treee6832467f89a0829cc8a7c76288234b4facb082c /synapse/storage
parentAdd tables for receipts (diff)
downloadsynapse-80a61330ee794147b213b1d54f2292a1c9adc002.tar.xz
Add basic storage functions for handling of receipts
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/receipts.py162
-rw-r--r--synapse/storage/schema/delta/21/receipts.sql31
-rw-r--r--synapse/storage/util/id_generators.py7
4 files changed, 186 insertions, 17 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8d33def6c6..8f812f0fd7 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -329,13 +329,14 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator()
+        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
new file mode 100644
index 0000000000..0168e74a0d
--- /dev/null
+++ b/synapse/storage/receipts.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore, cached
+
+from twisted.internet import defer
+
+
+class ReceiptStore(SQLBaseStore):
+
+    @cached
+    @defer.inlineCallbacks
+    def get_linearized_receipts_for_room(self, room_id):
+        rows = yield self._simple_select_list(
+            table="receipts_linearized",
+            keyvalues={"room_id": room_id},
+            retcols=["receipt_type", "user_id", "event_id"],
+            desc="get_linearized_receipts_for_room",
+        )
+
+        result = {}
+        for row in rows:
+            result.setdefault(
+                row["event_id"], {}
+            ).setdefault(
+                row["receipt_type"], []
+            ).append(row["user_id"])
+
+        defer.returnValue(result)
+
+    @cached
+    @defer.inlineCallbacks
+    def get_graph_receipts_for_room(self, room_id):
+        rows = yield self._simple_select_list(
+            table="receipts_graph",
+            keyvalues={"room_id": room_id},
+            retcols=["receipt_type", "user_id", "event_id"],
+            desc="get_linearized_receipts_for_room",
+        )
+
+        result = {}
+        for row in rows:
+            result.setdefault(
+                row["user_id"], {}
+            ).setdefault(
+                row["receipt_type"], []
+            ).append(row["event_id"])
+
+        defer.returnValue(result)
+
+    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+                                      user_id, event_id, stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="receipts_linearized",
+            keyvalues={
+                "stream_id": stream_id,
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+            }
+        )
+
+        self._simple_insert_txn(
+            txn,
+            table="receipts_linearized",
+            values={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+                "event_id": event_id,
+            }
+        )
+
+    @defer.inlineCallbacks
+    def insert_receipt(self, room_id, receipt_type, user_id, event_ids):
+        if not event_ids:
+            return
+
+        if len(event_ids) == 1:
+            linearized_event_id = event_ids[0]
+        else:
+            # we need to points in graph -> linearized form.
+            def graph_to_linear(txn):
+                query = (
+                    "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
+                    " SELECT max(stream_ordering) WHERE event_id IN (%s)"
+                    ")"
+                ) % (",".join(["?"] * len(event_ids)))
+
+                txn.execute(query, [room_id] + event_ids)
+                rows = txn.fetchall()
+                if rows:
+                    return rows[0][0]
+                else:
+                    # TODO: ARGH?!
+                    return None
+
+            linearized_event_id = yield self.runInteraction(
+                graph_to_linear, desc="insert_receipt_conv"
+            )
+
+        stream_id_manager = yield self._stream_id_gen.get_next(self)
+        with stream_id_manager() as stream_id:
+            yield self.runInteraction(
+                self.insert_linearized_receipt_txn,
+                room_id, receipt_type, user_id, linearized_event_id,
+                stream_id=stream_id,
+                desc="insert_linearized_receipt"
+            )
+
+        yield self.insert_graph_receipt(
+            room_id, receipt_type, user_id, event_ids
+        )
+
+        max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        defer.returnValue((stream_id, max_persisted_id))
+
+    def insert_graph_receipt(self, room_id, receipt_type,
+                             user_id, event_ids):
+        return self.runInteraction(
+            self.insert_graph_receipt_txn,
+            room_id, receipt_type, user_id, event_ids,
+            desc="insert_graph_receipt"
+        )
+
+    def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
+                                 user_id, event_ids):
+        self._simple_delete_txn(
+            txn,
+            table="receipts_graph",
+            keyvalues={
+                "room_id": room_id,
+                "receipt_type": receipt_type,
+                "user_id": user_id,
+            }
+        )
+        self._simple_insert_many_txn(
+            txn,
+            table="receipts_graph",
+            values=[
+                {
+                    "room_id": room_id,
+                    "receipt_type": receipt_type,
+                    "user_id": user_id,
+                    "event_id": event_id,
+                }
+                for event_id in event_ids
+            ],
+        )
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql
index da9e18e903..ccd64ec7f4 100644
--- a/synapse/storage/schema/delta/21/receipts.sql
+++ b/synapse/storage/schema/delta/21/receipts.sql
@@ -1,16 +1,18 @@
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 
 CREATE TABLE IF NOT EXISTS receipts_graph(
     room_id TEXT NOT NULL,
@@ -24,12 +26,13 @@ CREATE INDEX receipts_graph_room_tuple ON receipts_graph(
 );
 
 CREATE TABLE IF NOT EXISTS receipts_linearized (
+    stream_id BIGINT NOT NULL,
     room_id TEXT NOT NULL,
     receipt_type TEXT NOT NULL,
     user_id TEXT NOT NULL,
     event_id TEXT NOT NULL
 );
 
-CREATE INDEX receipts_graph_room_tuple ON receipts_graph(
+CREATE INDEX receipts_linearized_room_tuple ON receipts_graph(
   room_id, receipt_type, user_id
 );
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 89d1643f10..b39006315d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,7 +72,10 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next_txn(txn) as stream_id:
             # ... persist event ...
     """
-    def __init__(self):
+    def __init__(self, table, column):
+        self.table = table
+        self.column = column
+
         self._lock = threading.Lock()
 
         self._current_max = None
@@ -126,7 +129,7 @@ class StreamIdGenerator(object):
 
     def _get_or_compute_current_max(self, txn):
         with self._lock:
-            txn.execute("SELECT MAX(stream_ordering) FROM events")
+            txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
             rows = txn.fetchall()
             val, = rows[0]