diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 4 | ||||
-rw-r--r-- | synapse/storage/events.py | 26 | ||||
-rw-r--r-- | synapse/storage/events_worker.py | 26 | ||||
-rw-r--r-- | synapse/storage/keys.py | 90 | ||||
-rw-r--r-- | synapse/storage/relations.py | 476 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 32 | ||||
-rw-r--r-- | synapse/storage/schema/delta/54/add_validity_to_server_keys.sql | 23 | ||||
-rw-r--r-- | synapse/storage/schema/delta/54/relations.sql | 27 | ||||
-rw-r--r-- | synapse/storage/schema/delta/54/stats.sql | 80 | ||||
-rw-r--r-- | synapse/storage/state_deltas.py | 12 | ||||
-rw-r--r-- | synapse/storage/stats.py | 450 | ||||
-rw-r--r-- | synapse/storage/stream.py | 194 |
12 files changed, 1336 insertions, 104 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c432041b4e..66675d08ae 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,11 +49,13 @@ from .pusher import PusherStore from .receipts import ReceiptsStore from .registration import RegistrationStore from .rejections import RejectionsStore +from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore from .signatures import SignatureStore from .state import StateStore +from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore @@ -99,6 +101,8 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + StatsStore, + RelationsStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7a7f841c6c..2ffc27ff41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -575,10 +575,11 @@ class EventsStore( def _get_events(txn, batch): sql = """ - SELECT prev_event_id + SELECT prev_event_id, internal_metadata FROM event_edges INNER JOIN events USING (event_id) LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) WHERE prev_event_id IN (%s) AND NOT events.outlier @@ -588,7 +589,11 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend(r[0] for r in txn) + results.extend( + r[0] + for r in txn + if not json.loads(r[1]).get("soft_failed") + ) for chunk in batch_iter(event_ids, 100): yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) @@ -1325,6 +1330,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1351,6 +1359,8 @@ class EventsStore( # Insert into the event_search table. self._store_guest_access_txn(txn, event) + self._handle_event_relations(txn, event) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1655,10 +1665,11 @@ class EventsStore( def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1673,11 +1684,12 @@ class EventsStore( sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering DESC" @@ -1698,10 +1710,11 @@ class EventsStore( def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1716,11 +1729,12 @@ class EventsStore( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" " ORDER BY event_stream_ordering DESC" diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index adc6cf26b5..21b353cad3 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -610,4 +610,28 @@ class EventsWorkerStore(SQLBaseStore): return res - return self.runInteraction("get_rejection_reasons", f) + return self.runInteraction("get_seen_events_with_rejections", f) + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, room_id + ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 7036541792..5300720dbb 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -19,6 +19,7 @@ import logging import six +import attr from signedjson.key import decode_verify_key_bytes from synapse.util import batch_iter @@ -36,6 +37,12 @@ else: db_binary_type = memoryview +@attr.s(slots=True, frozen=True) +class FetchKeyResult(object): + verify_key = attr.ib() # VerifyKey: the key itself + valid_until_ts = attr.ib() # int: how long we can use this key for + + class KeyStore(SQLBaseStore): """Persistence for signature verification keys """ @@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore): iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: - map from (server_name, key_id) -> VerifyKey, or None if the key is + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is unknown """ keys = {} @@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore): # batch_iter always returns tuples so it's safe to do len(batch) sql = ( - "SELECT server_name, key_id, verify_key FROM server_signature_keys " - "WHERE 1=0" + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" ) + " OR (server_name=? AND key_id=?)" * len(batch) txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) for row in txn: - server_name, key_id, key_bytes = row - keys[(server_name, key_id)] = decode_verify_key_bytes( - key_id, bytes(key_bytes) + server_name, key_id, key_bytes, ts_valid_until_ms = row + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, ) + keys[(server_name, key_id)] = res def _txn(txn): for batch in batch_iter(server_name_and_key_ids, 50): @@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore): return self.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_key( - self, server_name, from_server, time_now_ms, verify_key - ): - """Stores a NACL verification key for the given server. + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. Args: - server_name (str): The name of the server. - from_server (str): Where the verification key was looked up - time_now_ms (int): The time now in milliseconds - verify_key (nacl.signing.VerifyKey): The NACL verify key. + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) - - # XXX fix this to not need a lock (#3819) - def _txn(txn): - self._simple_upsert_txn( - txn, - table="server_signature_keys", - keyvalues={"server_name": server_name, "key_id": key_id}, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": db_binary_type(verify_key.encode()), - }, + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) ) # invalidate takes a tuple corresponding to the params of # _get_server_verify_key. _get_server_verify_key only takes one # param, which is itself the 2-tuple (server_name, key_id). - txn.call_after( - self._get_server_verify_key.invalidate, ((server_name, key_id),) - ) - - return self.runInteraction("store_server_verify_key", _txn) + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i, )) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py new file mode 100644 index 0000000000..4c83800cca --- /dev/null +++ b/synapse/storage/relations.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from twisted.internet import defer + +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.stream import generate_pagination_where_clause +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +@attr.s +class PaginationChunk(object): + """Returned by relation pagination APIs. + + Attributes: + chunk (list): The rows returned by pagination + next_batch (Any|None): Token to fetch next set of results with, if + None then there are no more results. + prev_batch (Any|None): Token to fetch previous set of results with, if + None then there are no previous results. + """ + + chunk = attr.ib() + next_batch = attr.ib(default=None) + prev_batch = attr.ib(default=None) + + def to_dict(self): + d = {"chunk": self.chunk} + + if self.next_batch: + d["next_batch"] = self.next_batch.to_string() + + if self.prev_batch: + d["prev_batch"] = self.prev_batch.to_string() + + return d + + +@attr.s(frozen=True, slots=True) +class RelationPaginationToken(object): + """Pagination token for relation pagination API. + + As the results are order by topological ordering, we can use the + `topological_ordering` and `stream_ordering` fields of the events at the + boundaries of the chunk as pagination tokens. + + Attributes: + topological (int): The topological ordering of the boundary event + stream (int): The stream ordering of the boundary event. + """ + + topological = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + t, s = string.split("-") + return RelationPaginationToken(int(t), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.topological, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +@attr.s(frozen=True, slots=True) +class AggregationPaginationToken(object): + """Pagination token for relation aggregation pagination API. + + As the results are order by count and then MAX(stream_ordering) of the + aggregation groups, we can just use them as our pagination token. + + Attributes: + count (int): The count of relations in the boundar group. + stream (int): The MAX stream ordering in the boundary group. + """ + + count = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + c, s = string.split("-") + return AggregationPaginationToken(int(c), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.count, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + defer.returnValue(edit_event) + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 57df17bcc2..4bd1669458 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,6 +142,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction("get_room_summary", _get_room_summary_txn) + def _get_user_count_in_room_txn(self, txn, room_id, membership): + """ + See get_user_count_in_room. + """ + sql = ( + "SELECT count(*) FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" + ) + + txn.execute(sql, (room_id, membership)) + row = txn.fetchone() + return row[0] + + def get_user_count_in_room(self, room_id, membership): + """ + Get the user count in a room with a particular membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_users_in_room", self._get_user_count_in_room_txn, room_id, membership + ) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 0000000000..c01aa9d2d9 --- /dev/null +++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. */ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql new file mode 100644 index 0000000000..134862b870 --- /dev/null +++ b/synapse/storage/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql new file mode 100644 index 0000000000..652e58308e --- /dev/null +++ b/synapse/storage/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 31a0279b18..5fdb442104 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self._simple_select_one_onecol_txn( + txn, table="current_state_delta_stream", keyvalues={}, retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py new file mode 100644 index 0000000000..eb0ced5b5e --- /dev/null +++ b/synapse/storage/stats.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.state_deltas import StateDeltasStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "state_events", + ), + "user": ("public_rooms", "private_rooms"), +} + +TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +TEMP_TABLE = "_temp_populate_stats" + + +class StatsStore(StateDeltasStore): + def __init__(self, db_conn, hs): + super(StatsStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.register_background_update_handler( + "populate_stats_createtables", self._populate_stats_createtables + ) + self.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.register_background_update_handler( + "populate_stats_cleanup", self._populate_stats_cleanup + ) + + @defer.inlineCallbacks + def _populate_stats_createtables(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + # Get all the rooms that we want to process. + def _make_staging_area(txn): + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" + ) + txn.execute(sql) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database + sql = """ + SELECT room_id, count(*) FROM current_state_events + GROUP BY room_id + """ + txn.execute(sql) + rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] + self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + del rooms + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.runInteraction("populate_stats_temp_build", _make_staging_area) + yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + self.get_earliest_token_for_room_stats.invalidate_all() + + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + position = yield self._simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_stats_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_process_rooms(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_stats() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s_rooms + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE, + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.runInteraction( + "populate_stats_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + logger.info( + "Processing the next %d rooms of %d remaining", + len(rooms_to_work_on), progress["remaining"], + ) + + # Number of state events we've processed by going through each room + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + + current_state_ids = yield self.get_current_state_ids(room_id) + + join_rules = yield self.get_event( + current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True + ) + history_visibility = yield self.get_event( + current_state_ids.get((EventTypes.RoomHistoryVisibility, "")), + allow_none=True, + ) + encryption = yield self.get_event( + current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True + ) + name = yield self.get_event( + current_state_ids.get((EventTypes.Name, "")), allow_none=True + ) + topic = yield self.get_event( + current_state_ids.get((EventTypes.Topic, "")), allow_none=True + ) + avatar = yield self.get_event( + current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True + ) + canonical_alias = yield self.get_event( + current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True + ) + + def _or_none(x, arg): + if x: + return x.content.get(arg) + return None + + yield self.update_room_state( + room_id, + { + "join_rules": _or_none(join_rules, "join_rule"), + "history_visibility": _or_none( + history_visibility, "history_visibility" + ), + "encryption": _or_none(encryption, "algorithm"), + "name": _or_none(name, "name"), + "topic": _or_none(topic, "topic"), + "avatar": _or_none(avatar, "url"), + "canonical_alias": _or_none(canonical_alias, "alias"), + }, + ) + + now = self.hs.get_reactor().seconds() + + # quantise time to the nearest bucket + now = (now // self.stats_bucket_size) * self.stats_bucket_size + + def _fetch_data(txn): + + # Get the current token of the room + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + current_state_events = len(current_state_ids) + joined_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.JOIN + ) + invited_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.INVITE + ) + left_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.LEAVE + ) + banned_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.BAN + ) + total_state_events = self._get_total_state_event_counts_txn( + txn, room_id + ) + + self._update_stats_txn( + txn, + "room", + room_id, + now, + { + "bucket_size": self.stats_bucket_size, + "current_state_events": current_state_events, + "joined_members": joined_members, + "invited_members": invited_members, + "left_members": left_members, + "banned_members": banned_members, + "state_events": total_state_events, + }, + ) + self._simple_insert_txn( + txn, + "room_stats_earliest_token", + {"room_id": room_id, "token": current_token}, + ) + + yield self.runInteraction("update_room_stats", _fetch_data) + + # We've finished a room. Delete it from the table. + yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_stats", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + defer.returnValue(processed_event_count) + + defer.returnValue(processed_event_count) + + def delete_all_stats(self): + """ + Delete all statistics records. + """ + + def _delete_all_stats_txn(txn): + txn.execute("DELETE FROM room_state") + txn.execute("DELETE FROM room_stats") + txn.execute("DELETE FROM room_stats_earliest_token") + txn.execute("DELETE FROM user_stats") + + return self.runInteraction("delete_all_stats", _delete_all_stats_txn) + + def get_stats_stream_pos(self): + return self._simple_select_one_onecol( + table="stats_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="stats_stream_pos", + ) + + def update_stats_stream_pos(self, stream_id): + return self._simple_update_one( + table="stats_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_stats_stream_pos", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + return self._simple_upsert( + table="room_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_deltas_for_room(self, room_id, start, size=100): + """ + Get statistics deltas for a given room. + + Args: + room_id (str) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS["room"] and "ts". + """ + return self._simple_select_list_paginate( + "room_stats", + {"room_id": room_id}, + "ts", + start, + size, + retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + order_direction="DESC", + ) + + def get_all_room_state(self): + return self._simple_select_list( + "room_state", None, retcols=("name", "topic", "canonical_alias") + ) + + @cached() + def get_earliest_token_for_room_stats(self, room_id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + return self._simple_select_one_onecol( + "room_stats_earliest_token", + {"room_id": room_id}, + retcol="token", + allow_none=True, + ) + + def update_stats(self, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert( + table=table, + keyvalues={id_col: stats_id, "ts": ts}, + values=fields, + desc="update_stats", + ) + + def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert_txn( + txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + ) + + def update_stats_delta(self, ts, stats_type, stats_id, field, value): + def _update_stats_delta(txn): + table, id_col = TYPE_TO_ROOM[stats_type] + + sql = ( + "SELECT * FROM %s" + " WHERE %s=? and ts=(" + " SELECT MAX(ts) FROM %s" + " WHERE %s=?" + ")" + ) % (table, id_col, table, id_col) + txn.execute(sql, (stats_id, stats_id)) + rows = self.cursor_to_dict(txn) + if len(rows) == 0: + # silently skip as we don't have anything to apply a delta to yet. + # this tries to minimise any race between the initial sync and + # subsequent deltas arriving. + return + + current_ts = ts + latest_ts = rows[0]["ts"] + if current_ts < latest_ts: + # This one is in the past, but we're just encountering it now. + # Mark it as part of the current bucket. + current_ts = latest_ts + elif ts != latest_ts: + # we have to copy our absolute counters over to the new entry. + values = { + key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] + } + values[id_col] = stats_id + values["ts"] = ts + values["bucket_size"] = self.stats_bucket_size + + self._simple_insert_txn(txn, table=table, values=values) + + # actually update the new value + if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: + self._simple_update_txn( + txn, + table=table, + keyvalues={id_col: stats_id, "ts": current_ts}, + updatevalues={field: value}, + ) + else: + sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % ( + table, + field, + field, + id_col, + ) + txn.execute(sql, (value, stats_id, current_ts)) + + return self.runInteraction("update_stats_delta", _update_stats_delta) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d105b6b17d..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -64,59 +64,135 @@ _EventDictReturn = namedtuple( ) -def lower_bound(token, engine, inclusive=False): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well - # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) <%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", +def generate_pagination_where_clause( + direction, column_names, from_token, to_token, engine, +): + """Creates an SQL expression to bound the columns by the pagination + tokens. + + For example creates an SQL expression like: + + (6, 7) >= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, ) - return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", - ) - - -def upper_bound(token, engine, inclusive=True): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well - # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) >%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, ) - return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert(bound in (">", "<", ">=", "<=")) + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well + # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we + # use the later form when running against postgres. + return "((%d,%d) %s (%s,%s))" % ( + val1, val2, + bound, + name1, name2, + ) + + # We want to generate queries of e.g. the form: + # + # (val1 < name1 OR (val1 = name1 AND val2 <= name2)) + # + # which is equivalent to (val1, val2) < (name1, name2) + + return """( + {val1:d} {strict_bound} {name1} + OR ({val1:d} = {name1} AND {val2:d} {bound} {name2}) + )""".format( + name1=name1, + val1=val1, + name2=name2, + val2=val2, + strict_bound=bound[0], # The first bound must always be strict equality here + bound=bound, + ) + def filter_to_clause(event_filter): # NB: This may create SQL clauses that don't optimise well (and we don't @@ -762,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args = [False, room_id] if direction == 'b': order = "DESC" - bounds = upper_bound(from_token, self.database_engine) - if to_token: - bounds = "%s AND %s" % ( - bounds, - lower_bound(to_token, self.database_engine), - ) else: order = "ASC" - bounds = lower_bound(from_token, self.database_engine) - if to_token: - bounds = "%s AND %s" % ( - bounds, - upper_bound(to_token, self.database_engine), - ) + + bounds = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token, + to_token=to_token, + engine=self.database_engine, + ) filter_clause, filter_args = filter_to_clause(event_filter) |