diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
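+
+-- Illustrative example (not part of the schema): after two workers have
+-- been assigned IDs, the table might contain rows such as:
+--   instance_id | instance_name
+--   ------------+-------------------
+--             1 | master
+--             2 | event_persister-1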
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a94bec1ac5..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
)
+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we deal with by fetching more rows than necessary
+    from the DB and then filtering (rather than attempting to construct a
+    complicated SQL query).
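+
+    Illustrative example: given a live lower token with stream position 5 and
+    instance map {"worker1": 10}, an event persisted by "worker1" at
+    stream_ordering 8 is filtered out (8 <= 10), while an event persisted by
+    "master" at stream_ordering 8 is kept (8 > 5), since "master" is not in
+    the map and so falls back to the token's stream position.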
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+ # If these are live tokens we compare the stream ordering against the
+            # writer's stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind
+            # the min position, as at times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions; see MultiWriterIdGenerator for details.)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
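+        # Illustrative example: with min_pos == 5 and a single writer
+        # "event_persister-1" at position 8, the returned token is equivalent
+        # to RoomStreamToken(None, 5, {"event_persister-1": 8}).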
+ return RoomStreamToken(None, min_pos, positions)
async def get_room_events_stream_for_rooms(
self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map, we fetch more
+            # results than necessary and then filter down.
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
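+            # We request up to twice the limit (and trim after filtering
+            # below), on the assumption that filtering rarely discards more
+            # than half of the fetched rows; this is a heuristic, not a
+            # guarantee.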
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
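+        # The rows here are always in stream ordering (topological ordering is
+        # not returned for this query), so never use topological ordering for
+        # the before/after tokens.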
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map, we fetch more
+            # results than necessary and then filter down.
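+            #
+            # (Unlike the room stream query above there is no LIMIT here, so
+            # we fetch the full range and filter without an over-fetch
+            # factor.)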
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+            txn.execute(sql, (user_id, min_from_id, max_to_id))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]
return rows
@@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
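+
+        # Note: when paginating backwards ("b") the from_token is the upper
+        # end of the window and the to_token the lower end, hence the
+        # get_max_stream_pos()/stream choices swap between the two branches
+        # above.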
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events as we'll filter the result set
+ args.append(int(limit) * 2)
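+        # (The 2x factor is the same over-fetch heuristic as in
+        # get_room_events_stream_for_room.)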
select_keywords = "SELECT"
join_clause = ""
@@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords += " DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]
if rows:
topo = rows[-1].topological_ordering
@@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return (events, token)
+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+ """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+            # If we don't have an entry, upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment, which
+            # can quickly bloat the size of the generated IDs for new instances.
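+            #
+            # (On Postgres the upsert is an INSERT ... ON CONFLICT, which
+            # consumes a value from the SERIAL sequence even when the row
+            # already exists.)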
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
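+            # The upsert above guarantees the row now exists, so this second
+            # lookup cannot return None.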
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
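+
+        Raises StoreError if the ID is not in the map. (Rows in `instance_map`
+        are never deleted, so any ID previously handed out should always
+        resolve.)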
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
+
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
|