diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b313720a4b..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
-)
 from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
 from synapse.storage.data_stores.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -62,11 +56,6 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
-        )
-
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -92,81 +81,3 @@ class SlavedEventStore(
 
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
-
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
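
The methods removed above all follow the same worker-side pattern: advance the
local stream token, then invalidate every cache keyed on an entity the row
mentions. A minimal, self-contained sketch of that shape (illustrative only;
EventRow, SimpleCache and WorkerSketch are made-up names, not Synapse classes):

    from collections import namedtuple

    EventRow = namedtuple("EventRow", ("event_id", "room_id", "redacts"))

    class SimpleCache:
        """A dict-backed stand-in for Synapse's per-entity caches."""

        def __init__(self):
            self._entries = {}

        def set(self, key, value):
            self._entries[key] = value

        def get(self, key):
            return self._entries.get(key)

        def invalidate(self, key):
            # forget the cached value; the next reader re-queries the database
            self._entries.pop(key, None)

    class WorkerSketch:
        def __init__(self):
            self.token = 0
            self.event_cache = SimpleCache()
            self.room_cache = SimpleCache()

        def process_replication_rows(self, token, rows):
            # advance the token first, then drop whatever the rows made stale
            self.token = token
            for row in rows:
                self.event_cache.invalidate(row.event_id)
                self.room_cache.invalidate(row.room_id)
                if row.redacts:
                    self.event_cache.invalidate(row.redacts)

    worker = WorkerSketch()
    worker.room_cache.set("!room:example.org", ["$latest"])
    worker.process_replication_rows(7, [EventRow("$new", "!room:example.org", None)])
    assert worker.room_cache.get("!room:example.org") is None

The real code fans out to many more caches (push actions, memberships,
relations), but the shape is the same.
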
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 5d5816d7eb..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
 
-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),
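
Both files also stop hand-constructing SlavedIdTracker. Broadly, such a tracker
seeds itself with the database's current maximum stream id and is then advanced
by replication traffic. A sketch of that behaviour under those assumptions
(IdTrackerSketch is a hypothetical name, and the real class reads its seed from
a db connection rather than taking an int):

    class IdTrackerSketch:
        """Follows the furthest-advanced stream id seen so far."""

        def __init__(self, current_id: int, step: int = 1):
            # step=-1 models the backfill stream, whose ids grow downwards
            self._step = step
            self._current = current_id

        def advance(self, new_id: int) -> None:
            # never move the token backwards relative to the stream direction
            if self._step > 0:
                self._current = max(self._current, new_id)
            else:
                self._current = min(self._current, new_id)

        def get_current_token(self) -> int:
            return self._current

With the explicit __init__ gone, SlavedPushRuleStore relies on its base classes
to set up _push_rules_stream_id_gen before get_push_rules_stream_token reads it.
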
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
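
Making StreamRow a bound TypeVar rather than a plain alias for Tuple lets
annotations stay generic over each stream's concrete row type. A small
illustration of the idea (parse_rows is a made-up helper, not part of this
patch):

    from typing import Callable, List, Tuple, TypeVar

    RowT = TypeVar("RowT", bound=Tuple)

    def parse_rows(raw_rows: List[list], parser: Callable[[list], RowT]) -> List[RowT]:
        # the result type follows whichever concrete row type `parser`
        # returns, e.g. a list of AccountDataStreamRow for the stream below
        return [parser(raw) for raw in raw_rows]
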
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
- db_query_to_update_function(self._update_function),
+ self._update_function,
+ )
+
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(room_rows, global_rows))
+ return updates, to_token, limited
class GroupServerStream(Stream):
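
The merge at the end of the new _update_function relies on heapq.merge, which
lazily interleaves two already-sorted iterables into one sorted sequence
without re-sorting. A quick, self-contained demonstration (the rows here are
invented for illustration):

    import heapq

    # two result sets, each already ordered by stream_id and capped at to_token
    global_rows = [
        (1, ("@alice:example.org", None, "m.direct")),
        (4, ("@alice:example.org", None, "m.push_rules")),
    ]
    room_rows = [
        (2, ("@alice:example.org", "!a:example.org", "m.tag")),
        (3, ("@bob:example.org", "!b:example.org", "m.tag")),
    ]

    # tuples compare element-wise, so the leading stream_id decides the order
    updates = list(heapq.merge(room_rows, global_rows))
    assert [stream_id for stream_id, _ in updates] == [1, 2, 3, 4]

Pulling to_token back to the last row actually returned whenever either query
hits `limit` is what keeps the two capped result sets consistent with each
other, at the cost of reporting the batch as `limited`.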