#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2018 Vector Creations Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import logging
from typing import List, Optional, Tuple

import attr

from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import _filter_results_by_stream
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache

logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class StateDelta:
    stream_id: int
    room_id: str
    event_type: str
    state_key: str

    event_id: Optional[str]
    """new event_id for this state key. None if the state has been deleted."""
    prev_event_id: Optional[str]
    """previous event_id for this state key. None if it's new state."""


class StateDeltasStore(SQLBaseStore):
    # This class must be mixed in with a child class which provides the following
    # attribute. TODO: can we get static analysis to enforce this?
    _curr_state_delta_stream_cache: StreamChangeCache

    async def get_partial_current_state_deltas(
        self, prev_stream_id: int, max_stream_id: int
    ) -> Tuple[int, List[StateDelta]]:
        """Fetch a list of room state changes since the given stream id

        This may be the partial state if we're lazy joining the room.

        Args:
            prev_stream_id: point to get changes since (exclusive)
            max_stream_id: the point that we know has been correctly persisted
               - ie, an upper limit to return changes from.

        Returns:
            A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
        """
        prev_stream_id = int(prev_stream_id)

        # check we're not going backwards
        assert (
            prev_stream_id <= max_stream_id
        ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"

        if not self._curr_state_delta_stream_cache.has_any_entity_changed(
            prev_stream_id
        ):
            # if the CSDs haven't changed between prev_stream_id and now, we
            # know for certain that they haven't changed between prev_stream_id and
            # max_stream_id.
            return max_stream_id, []

        def get_current_state_deltas_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, List[StateDelta]]:
            # First we calculate the max stream id that will give us less than
            # N results.
            # We arbitrarily limit to 100 stream_id entries to ensure we don't
            # select toooo many.
            sql = """
                SELECT stream_id, count(*)
                FROM current_state_delta_stream
                WHERE stream_id > ? AND stream_id <= ?
                GROUP BY stream_id
                ORDER BY stream_id ASC
                LIMIT 100
            """
            txn.execute(sql, (prev_stream_id, max_stream_id))

            total = 0

            for stream_id, count in txn:
                total += count
                if total > 100:
                    # We arbitrarily limit to 100 entries to ensure we don't
                    # select toooo many.
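                    # Clipping here means the returned stream id may fall short of
                    # max_stream_id; rows beyond it are left for a subsequent call
                    # starting from the returned position.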
                    logger.debug(
                        "Clipping current_state_delta_stream rows to stream_id %i",
                        stream_id,
                    )
                    clipped_stream_id = stream_id
                    break
            else:
                # if there's no problem, we may as well go right up to the max_stream_id
                clipped_stream_id = max_stream_id

            # Now actually get the deltas
            sql = """
                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
                FROM current_state_delta_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            txn.execute(sql, (prev_stream_id, clipped_stream_id))
            return clipped_stream_id, [
                StateDelta(
                    stream_id=row[0],
                    room_id=row[1],
                    event_type=row[2],
                    state_key=row[3],
                    event_id=row[4],
                    prev_event_id=row[5],
                )
                for row in txn.fetchall()
            ]

        return await self.db_pool.runInteraction(
            "get_current_state_deltas", get_current_state_deltas_txn
        )

    def _get_max_stream_id_in_current_state_deltas_txn(
        self, txn: LoggingTransaction
    ) -> int:
        return self.db_pool.simple_select_one_onecol_txn(
            txn,
            table="current_state_delta_stream",
            keyvalues={},
            retcol="COALESCE(MAX(stream_id), -1)",
        )

    async def get_max_stream_id_in_current_state_deltas(self) -> int:
        return await self.db_pool.runInteraction(
            "get_max_stream_id_in_current_state_deltas",
            self._get_max_stream_id_in_current_state_deltas_txn,
        )

    @trace
    async def get_current_state_deltas_for_room(
        self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken
    ) -> List[StateDelta]:
        """Get the state deltas between two tokens."""

        if not self._curr_state_delta_stream_cache.has_entity_changed(
            room_id, from_token.stream
        ):
            return []

        def get_current_state_deltas_for_room_txn(
            txn: LoggingTransaction,
        ) -> List[StateDelta]:
            sql = """
                SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id
                FROM current_state_delta_stream
                WHERE room_id = ? AND ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            txn.execute(
                sql, (room_id, from_token.stream, to_token.get_max_stream_pos())
            )
            return [
                StateDelta(
                    stream_id=row[1],
                    room_id=room_id,
                    event_type=row[2],
                    state_key=row[3],
                    event_id=row[4],
                    prev_event_id=row[5],
                )
                for row in txn
                if _filter_results_by_stream(from_token, to_token, row[0], row[1])
            ]

        return await self.db_pool.runInteraction(
            "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn
        )
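
# Usage sketch (illustrative only, not part of the store itself): `store` is assumed
# to be a database store class that mixes StateDeltasStore in alongside a class
# providing `_curr_state_delta_stream_cache`; `previously_persisted_position` and
# `max_stream_id` are assumed to be tracked by the caller.
#
#     position = previously_persisted_position
#     while position < max_stream_id:
#         position, deltas = await store.get_partial_current_state_deltas(
#             position, max_stream_id
#         )
#         for delta in deltas:
#             if delta.event_id is None:
#                 ...  # state at (delta.event_type, delta.state_key) was deleted
#             else:
#                 ...  # state changed from delta.prev_event_id to delta.event_id
#
# Because results may be clipped, the returned position is fed back in until it
# reaches max_stream_id.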