diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 5 | ||||
-rw-r--r-- | synapse/storage/_base.py | 3 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 311 | ||||
-rw-r--r-- | synapse/storage/schema/delta/21/receipts.sql | 42 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 7 |
5 files changed, 364 insertions, 4 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c137f47820..2bc88a7954 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -38,6 +38,8 @@ from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore +from .receipts import ReceiptsStore + import fnmatch import imp @@ -51,7 +53,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 20 +SCHEMA_VERSION = 21 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -74,6 +76,7 @@ class DataStore(RoomMemberStore, RoomStore, PushRuleStore, ApplicationServiceTransactionStore, EventsStore, + ReceiptsStore, ): def __init__(self, hs): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8d33def6c6..8f812f0fd7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -329,13 +329,14 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine - self._stream_id_gen = StreamIdGenerator() + self._stream_id_gen = StreamIdGenerator("events", "stream_ordering") self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id") def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py new file mode 100644 index 0000000000..503f68f858 --- /dev/null +++ b/synapse/storage/receipts.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore, cached + +from twisted.internet import defer + +from synapse.util import unwrapFirstError + +from blist import sorteddict +import logging + + +logger = logging.getLogger(__name__) + + +class ReceiptsStore(SQLBaseStore): + def __init__(self, hs): + super(ReceiptsStore, self).__init__(hs) + + self._receipts_stream_cache = _RoomStreamChangeCache() + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, from_key, to_key): + room_ids = set(room_ids) + + if from_key: + room_ids = yield self._receipts_stream_cache.get_rooms_changed( + self, room_ids, from_key + ) + + results = yield defer.gatherResults( + [ + self.get_linearized_receipts_for_room(room_id, from_key, to_key) + for room_id in room_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue([ev for res in results for ev in res]) + + @defer.inlineCallbacks + def get_linearized_receipts_for_room(self, room_id, from_key, to_key): + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id > ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, from_key, to_key) + ) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, to_key) + ) + + rows = self.cursor_to_dict(txn) + + return rows + + rows = yield self.runInteraction( + "get_linearized_receipts_for_room", f + ) + + if not rows: + defer.returnValue([]) + + content = {} + for row in rows: + content.setdefault( + row["event_id"], {} + ).setdefault( + row["receipt_type"], [] + ).append(row["user_id"]) + + defer.returnValue([{ + "type": "m.receipt", + "room_id": room_id, + "content": content, + }]) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_max_token(self) + + @cached + @defer.inlineCallbacks + def get_graph_receipts_for_room(self, room_id): + rows = yield self._simple_select_list( + table="receipts_graph", + keyvalues={"room_id": room_id}, + retcols=["receipt_type", "user_id", "event_id"], + desc="get_linearized_receipts_for_room", + ) + + result = {} + for row in rows: + result.setdefault( + row["user_id"], {} + ).setdefault( + row["receipt_type"], [] + ).append(row["event_id"]) + + defer.returnValue(result) + + def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_id, stream_id): + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + sql = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + + txn.execute(sql, (room_id, receipt_type, user_id)) + results = txn.fetchall() + + if results: + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + ) + topological_ordering = int(res["topological_ordering"]) + stream_ordering = int(res["stream_ordering"]) + + for to, so, _ in results: + if int(to) > topological_ordering: + return False + elif int(to) == topological_ordering and int(so) >= stream_ordering: + return False + + self._simple_delete_txn( + txn, + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + + self._simple_insert_txn( + txn, + table="receipts_linearized", + values={ + "stream_id": stream_id, + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_id": event_id, + } + ) + + return True + + @defer.inlineCallbacks + def insert_receipt(self, room_id, receipt_type, user_id, event_ids): + if not event_ids: + return + + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + def graph_to_linear(txn): + query = ( + "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" + " SELECT max(stream_ordering) WHERE event_id IN (%s)" + ")" + ) % (",".join(["?"] * len(event_ids))) + + txn.execute(query, [room_id] + event_ids) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + # TODO: ARGH?! + return None + + linearized_event_id = yield self.runInteraction( + "insert_receipt_conv", graph_to_linear + ) + + stream_id_manager = yield self._receipts_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self._receipts_stream_cache.room_has_changed( + self, room_id, stream_id + ) + have_persisted = yield self.runInteraction( + "insert_linearized_receipt", + self.insert_linearized_receipt_txn, + room_id, receipt_type, user_id, linearized_event_id, + stream_id=stream_id, + ) + + if not have_persisted: + defer.returnValue(None) + + yield self.insert_graph_receipt( + room_id, receipt_type, user_id, event_ids + ) + + max_persisted_id = yield self._stream_id_gen.get_max_token(self) + defer.returnValue((stream_id, max_persisted_id)) + + def insert_graph_receipt(self, room_id, receipt_type, + user_id, event_ids): + return self.runInteraction( + "insert_graph_receipt", + self.insert_graph_receipt_txn, + room_id, receipt_type, user_id, event_ids, + ) + + def insert_graph_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_ids): + self._simple_delete_txn( + txn, + table="receipts_graph", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + self._simple_insert_many_txn( + txn, + table="receipts_graph", + values=[ + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_id": event_id, + } + for event_id in event_ids + ], + ) + + +class _RoomStreamChangeCache(object): + """Keeps track of the stream_id of the latest change in rooms. + + Given a list of rooms and stream key, it will give a subset of rooms that + may have changed since that key. If the key is too old then the cache + will simply return all rooms. + """ + def __init__(self, size_of_cache=1000): + self._size_of_cache = size_of_cache + self._room_to_key = {} + self._cache = sorteddict() + self._earliest_key = None + + @defer.inlineCallbacks + def get_rooms_changed(self, store, room_ids, key): + if key > (yield self._get_earliest_key(store)): + keys = self._cache.keys() + i = keys.bisect_right(key) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(room_ids) + else: + result = room_ids + + defer.returnValue(result) + + @defer.inlineCallbacks + def room_has_changed(self, store, room_id, key): + if key > (yield self._get_earliest_key(store)): + old_key = self._room_to_key.get(room_id, None) + if old_key: + key = max(key, old_key) + self._cache.pop(old_key, None) + self._cache[key] = room_id + + while len(self._cache) > self._size_of_cache: + k, r = self._cache.popitem() + self._earliest_key = max(k, self._earliest_key) + self._room_to_key.pop(r, None) + + @defer.inlineCallbacks + def _get_earliest_key(self, store): + if self._earliest_key is None: + self._earliest_key = yield store.get_max_receipt_stream_id() + self._earliest_key = int(self._earliest_key) + + defer.returnValue(self._earliest_key) diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql new file mode 100644 index 0000000000..ac7738e371 --- /dev/null +++ b/synapse/storage/schema/delta/21/receipts.sql @@ -0,0 +1,42 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS receipts_graph( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE INDEX receipts_graph_room_tuple ON receipts_graph( + room_id, receipt_type, user_id +); + +CREATE TABLE IF NOT EXISTS receipts_linearized ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE INDEX receipts_linearized_room_tuple ON receipts_linearized( + room_id, receipt_type, user_id +); + +CREATE INDEX receipts_linearized_id ON receipts_linearized( + stream_id +); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 89d1643f10..b39006315d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -72,7 +72,10 @@ class StreamIdGenerator(object): with stream_id_gen.get_next_txn(txn) as stream_id: # ... persist event ... """ - def __init__(self): + def __init__(self, table, column): + self.table = table + self.column = column + self._lock = threading.Lock() self._current_max = None @@ -126,7 +129,7 @@ class StreamIdGenerator(object): def _get_or_compute_current_max(self, txn): with self._lock: - txn.execute("SELECT MAX(stream_ordering) FROM events") + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) rows = txn.fetchall() val, = rows[0] |