diff options
author | Erik Johnston <erik@matrix.org> | 2020-09-15 13:36:51 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-09-29 14:43:28 +0100 |
commit | 4499d81adf97c3260872f25745655d2e404241eb (patch) | |
tree | 80c4addd06f4850210704258bc424bd8994edffb | |
parent | Reduce usages of RoomStreamToken constructor (diff) | |
download | synapse-4499d81adf97c3260872f25745655d2e404241eb.tar.xz |
Wire up token
-rw-r--r-- | synapse/notifier.py | 4 | ||||
-rw-r--r-- | synapse/python_dependencies.py | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 227 | ||||
-rw-r--r-- | synapse/types.py | 65 |
4 files changed, 252 insertions, 45 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py index 441b3d15e2..59415f6f88 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -163,7 +163,7 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. - if self.last_notified_token.is_after(token): + if self.last_notified_token != token: return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) @@ -470,7 +470,7 @@ class Notifier: async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: - if not after_token.is_after(before_token): + if after_token == before_token: return EventStreamResult([], (from_token, from_token)) events = [] # type: List[EventBase] diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 288631477e..1761fca0ed 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -77,6 +77,7 @@ REQUIREMENTS = [ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", + "cbor2", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 255d363b33..72b6420532 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,11 +35,10 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ - import abc import logging from collections import namedtuple -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from twisted.internet import defer @@ -54,6 +53,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -76,6 +76,18 @@ _EventDictReturn = namedtuple( ) +def _filter_result( + instance_name: str, + stream_id: int, + from_token: RoomStreamToken, + to_token: RoomStreamToken, +) -> bool: + from_id = from_token.instance_map.get(instance_name, from_token.stream) + to_id = to_token.instance_map.get(instance_name, to_token.stream) + + return from_id < stream_id <= to_id + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -209,6 +221,71 @@ def _make_generic_sql_bound( ) +def _make_instance_filter_clause( + direction: str, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[str, List[Any]]: + if from_token and from_token.topological: + from_token = None + if to_token and to_token.topological: + to_token = None + + if not from_token and not to_token: + return "", [] + + from_bound = ">=" if direction == "b" else "<" + to_bound = "<" if direction == "b" else ">=" + + filter_clauses = [] + filter_args = [] # type: List[Any] + + from_map = from_token.instance_map if from_token else {} + to_map = to_token.instance_map if to_token else {} + + default_from = from_token.stream if from_token else None + default_to = to_token.stream if to_token else None + + if default_from and default_to: + filter_clauses.append( + "(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound) + ) + filter_args.extend((default_from, default_to,)) + elif default_from: + filter_clauses.append("(? %s stream_ordering)" % (from_bound,)) + filter_args.extend((default_from,)) + elif default_to: + filter_clauses.append("(? %s stream_ordering)" % (to_bound,)) + filter_args.extend((default_to,)) + + for instance in set(from_map).union(to_map): + from_id = from_map.get(instance, default_from) + to_id = to_map.get(instance, default_to) + + if from_id and to_id: + filter_clauses.append( + "(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)" + % (from_bound, to_bound) + ) + filter_args.extend((instance, from_id, to_id,)) + elif from_id: + filter_clauses.append( + "(instance_name = ? AND ? %s stream_ordering)" % (from_bound,) + ) + filter_args.extend((instance, from_id,)) + elif to_id: + filter_clauses.append( + "(instance_name = ? AND ? %s stream_ordering)" % (to_bound,) + ) + filter_args.extend((instance, to_id,)) + + filter_clause = "" + if filter_clauses: + filter_clause = "(%s)" % (" OR ".join(filter_clauses),) + + return filter_clause, filter_args + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -306,7 +383,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p >= min_pos + } + + if set(positions.values()) == {min_pos}: + positions = {} + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -405,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + filter_clause, filter_args = _make_instance_filter_clause( + "f", from_key, to_key + ) + if filter_clause: + filter_clause = " AND " + filter_clause + + min_from_id = min(from_key.instance_map.values(), default=from_key.stream) + max_to_id = max(to_key.instance_map.values(), default=to_key.stream) + + sql = """ + SELECT event_id, instance_name, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + %s + ORDER BY stream_ordering %s LIMIT ? + """ % ( + filter_clause, + order, + ) + args = [room_id, min_from_id, max_to_id] + args.extend(filter_args) + args.append(limit) + txn.execute(sql, args) + + # rows = [ + # _EventDictReturn(event_id, None, stream_ordering) + # for event_id, instance_name, stream_ordering in txn + # if _filter_result(instance_name, stream_ordering, from_key, to_key) + # ] + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, stream_ordering in txn + ] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -432,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=from_key.stream is None) if order.lower() == "desc": ret.reverse() @@ -449,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" + filter_clause, filter_args = _make_instance_filter_clause( + "f", from_key, to_key + ) + if filter_clause: + filter_clause = " AND " + filter_clause + + min_from_id = min(from_key.instance_map.values(), default=from_key.stream) + max_to_id = max(to_key.instance_map.values(), default=to_key.stream) + + sql = """ + SELECT m.event_id, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + %s + ORDER BY e.stream_ordering ASC + """ % ( + filter_clause, ) - txn.execute(sql, (user_id, from_id, to_id)) + args = [user_id, min_from_id, max_to_id] + args.extend(filter_args) + txn.execute(sql, args) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] @@ -978,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + if from_token.topological is not None: + from_bound = from_token.as_tuple() + elif direction == "b": + from_bound = ( + None, + max(from_token.instance_map.values(), default=from_token.stream), + ) + else: + from_bound = ( + None, + min(from_token.instance_map.values(), default=from_token.stream), + ) + + to_bound = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_tuple() + elif direction == "b": + to_bound = ( + None, + min(to_token.instance_map.values(), default=to_token.stream), + ) + else: + to_bound = ( + None, + max(to_token.instance_map.values(), default=to_token.stream), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -992,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) + stream_filter_clause, stream_filter_args = _make_instance_filter_clause( + direction, from_token, to_token + ) + if stream_filter_clause: + bounds += " AND " + stream_filter_clause + args.extend(stream_filter_args) + args.append(int(limit)) select_keywords = "SELECT" diff --git a/synapse/types.py b/synapse/types.py index ec39f9e1e8..5224b473a9 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -21,8 +21,9 @@ from collections import namedtuple from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar import attr +import cbor2 from signedjson.key import decode_verify_key_bytes -from unpaddedbase64 import decode_base64 +from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError @@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, cmp=False) class RoomStreamToken: """Tokens are positions between events. The token "s1" comes after event 1. @@ -392,6 +393,8 @@ class RoomStreamToken: ) stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + instance_map = attr.ib(type=Dict[str, int], factory=dict) + @classmethod def parse(cls, string: str) -> "RoomStreamToken": try: @@ -400,6 +403,11 @@ class RoomStreamToken: if string[0] == "t": parts = string[1:].split("-", 1) return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + payload = cbor2.loads(decode_base64(string[1:])) + return cls( + topological=None, stream=payload["s"], instance_map=payload["p"], + ) except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -413,15 +421,49 @@ class RoomStreamToken: pass raise SynapseError(400, "Invalid token %r" % (string,)) + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, instance_map) + def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) def __str__(self) -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + return "m" + encode_base64( + cbor2.dumps({"s": self.stream, "p": self.instance_map}), + ) else: return "s%d" % (self.stream,) + def __lt__(self, other: "RoomStreamToken"): + if self.stream != other.stream: + return self.stream < other.stream + + for instance in set(self.instance_map).union(other.instance_map): + if self.instance_map.get(instance, self.stream) != other.instance_map.get( + instance, other.stream + ): + return self.instance_map.get( + instance, self.stream + ) < other.instance_map.get(instance, other.stream) + + return False + @attr.s(slots=True, frozen=True) class StreamToken: @@ -461,7 +503,7 @@ class StreamToken: def is_after(self, other): """Does this token contain events that the other doesn't?""" return ( - (other.room_stream_id < self.room_stream_id) + (other.room_key < self.room_key) or (int(other.presence_key) < int(self.presence_key)) or (int(other.typing_key) < int(self.typing_key)) or (int(other.receipt_key) < int(self.receipt_key)) @@ -476,13 +518,16 @@ class StreamToken: """Advance the given key in the token to a new value if and only if the new value is after the old value. """ - new_token = self.copy_and_replace(key, new_value) if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + if old_id < new_id: return new_token else: @@ -507,7 +552,7 @@ class PersistedEventPosition: stream = attr.ib(type=int) def persisted_after(self, token: RoomStreamToken) -> bool: - return token.stream < self.stream + return token.instance_map.get(self.instance_name, token.stream) < self.stream class ThirdPartyInstanceID( |