diff options
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r-- | synapse/storage/databases/main/stream.py | 280 |
1 files changed, 243 insertions, 37 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a94bec1ac5..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -53,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -208,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. + """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? + """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? AND %(bounds)s @@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. + # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. + self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: |