diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
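+
+-- For example (illustrative rows only), the table might contain
+-- (1, 'master') and (2, 'event_persister1'), letting stream tokens encode
+-- the small integer 2 instead of the full instance name.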
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a94bec1ac5..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
)


+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
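+
+    For example, a minimal sketch with hypothetical instance names and
+    positions::
+
+        lower = RoomStreamToken(None, 5, {"worker1": 10})
+        _filter_results(lower, None, "worker1", 3, 10)  # False: 10 <= 10
+        _filter_results(lower, None, "worker1", 3, 11)  # True: 11 > 10
+        _filter_results(lower, None, "worker2", 3, 6)   # True: falls back to 5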
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+        # If these are live tokens we compare the stream ordering against the
+        # writer's stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()

    def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
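+
+        For example (hypothetical positions), with a min position of 56 and
+        writers "a" and "b" ahead of it at 58 and 59, this returns
+        `RoomStreamToken(None, 56, {"a": 58, "b": 59})`.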
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (An instance's position can be
+            # behind the min position, as we can sometimes work out that the
+            # minimum position is ahead of the naive minimum across all
+            # current positions. See `MultiWriterIdGenerator` for details.)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return RoomStreamToken(None, min_pos, positions)

    async def get_room_events_stream_for_rooms(
self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
            return [], from_key

- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key

        def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary (up to twice the limit) and then filter
+            # down to at most `limit` rows.
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]

            return rows

        rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
            return []

- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []

        def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+ txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]

            return rows
@@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
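+            # Note that when paginating backwards ("b") the `from_token` is the
+            # upper bound and the `to_token` (if given) the lower bound; it is
+            # always the upper bound that is widened to `get_max_stream_pos()`,
+            # so we only ever over-fetch and then filter.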
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events as we'll filter the result set
+ args.append(int(limit) * 2)
select_keywords = "SELECT"
join_clause = ""
@@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords += "DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]

        if rows:
topo = rows[-1].topological_ordering
@@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
        return (events, token)

+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse
+        worker instance.
+        """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+            # If we don't have an entry, upsert one.
+            #
+            # We could do this before the first check and rely on the cache for
+            # efficiency, but each UPSERT consumes a value from the underlying
+            # SERIAL sequence even when the row already exists, which can
+            # quickly bloat the IDs generated for new instances.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
+

class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/types.py b/synapse/types.py
index bd271f9f16..5bde67cc07 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -22,6 +22,7 @@ from typing import (
TYPE_CHECKING,
Any,
Dict,
+ Iterable,
Mapping,
MutableMapping,
Optional,
@@ -43,7 +44,7 @@ if TYPE_CHECKING:
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Container, Iterable, Sized
+ from typing import Container, Sized
T_co = TypeVar("T_co", covariant=True)
@@ -375,7 +376,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")


-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
@@ -397,6 +398,31 @@ class RoomStreamToken:
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
+
+    There is also a third mode for live tokens, where the token starts with
+    "m". This mode may be used when sharded event persisters are in use: the
+    events stream is then considered to be a set of streams (one per writer),
+    and the token encodes the vector clock of each writer's position in its
+    respective stream.
+
+    In that case the format of the token is an initial integer min position,
+    followed by a mapping of instance ID to position, where each ID is
+    separated from its position by '.' and the entries from each other by '~':
+
+ m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
+
+ The `min_pos` corresponds to the minimum position all writers have persisted
+ up to, and then only writers that are ahead of that position need to be
+ encoded. An example token is:
+
+ m56~2.58~3.59
+
+    This corresponds to a set of three (or more) writers, where instances 2
+    and 3 (instance IDs that can be looked up in the DB to fetch the more
+    commonly used instance names) are at positions 58 and 59 respectively,
+    and all other instances are at position 56.
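+
+    For example (hypothetical instance names), if instance ID 2 maps to
+    "persister1" and ID 3 to "persister2", the token above parses to::
+
+        RoomStreamToken(
+            topological=None,
+            stream=56,
+            instance_map={"persister1": 58, "persister2": 59},
+        )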
+
+ Note: The `RoomStreamToken` cannot have both a topological part and an
+ instance map.
"""

    topological = attr.ib(
@@ -405,6 +431,25 @@ class RoomStreamToken:
)
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))

+    instance_map = attr.ib(
+ type=Dict[str, int],
+ factory=dict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(dict),
+ ),
+ )
+
+ def __attrs_post_init__(self):
+        """Validates that `topological` and `instance_map` aren't both set.
+        """
+
+ if self.instance_map and self.topological:
+ raise ValueError(
+ "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
+ )
+
@classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try:
@@ -413,6 +458,20 @@ class RoomStreamToken:
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_map[instance_name] = pos
+
+ return cls(topological=None, stream=stream, instance_map=instance_map,)
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -436,14 +495,61 @@ class RoomStreamToken:
max_stream = max(self.stream, other.stream)
- return RoomStreamToken(None, max_stream)
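+        # A writer missing from a token's instance_map is, by definition, at
+        # that token's `stream` position. E.g. (hypothetical writers) advancing
+        # m5~a.8 by m6~b.9 gives m6~a.8~b.9: "a" is max(8, 6) = 8 and "b" is
+        # max(5, 9) = 9.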
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return RoomStreamToken(None, max_stream, instance_map)
+
+ def as_historical_tuple(self) -> Tuple[int, int]:
+ """Returns a tuple of `(topological, stream)` for historical tokens.
+
+ Raises if not an historical token (i.e. doesn't have a topological part).
+ """
+ if self.topological is None:
+ raise Exception(
+ "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
+ )
- def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+        """Get the stream position of the given writer at this token.
+
+ This only makes sense for "live" tokens that may have a vector clock
+ component, and so asserts that this is a "live" token.
+ """
+ assert self.topological is None
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+        The corresponding "min" position is, by definition, just `self.stream`.
+
+        This is used to handle tokens that have a non-empty `instance_map`, and
+        so reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
+
async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
+ elif self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ instance_id = await store.get_id_for_instance(name)
+ entries.append("{}.{}".format(instance_id, pos))
+
+ encoded_map = "~".join(entries)
+ return "m{}~{}".format(self.stream, encoded_map)
else:
return "s%d" % (self.stream,)
@@ -535,7 +641,7 @@ class PersistedEventPosition:
stream = attr.ib(type=int)
def persisted_after(self, token: RoomStreamToken) -> bool:
- return token.stream < self.stream
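+        # The event is persisted after the token if its stream ordering is
+        # strictly greater than the token's position for this event's writer.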
+ return token.get_stream_pos_for_instance(self.instance_name) < self.stream
def to_room_stream_token(self) -> RoomStreamToken:
"""Converts the position to a room stream token such that events
|