diff --git a/synapse/notifier.py b/synapse/notifier.py
index 441b3d15e2..59415f6f88 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -163,7 +163,7 @@ class _NotifierUserStream:
"""
# Immediately wake up stream if something has already since happened
# since their last token.
- if self.last_notified_token.is_after(token):
+ if self.last_notified_token != token:
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
@@ -470,7 +470,7 @@ class Notifier:
async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
- if not after_token.is_after(before_token):
+ if after_token == before_token:
return EventStreamResult([], (from_token, from_token))
events = [] # type: List[EventBase]
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 288631477e..1761fca0ed 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -77,6 +77,7 @@ REQUIREMENTS = [
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
+ "cbor2",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 255d363b33..72b6420532 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,11 +35,10 @@ what sort order was used:
- topological tokems: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
-
import abc
import logging
from collections import namedtuple
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@@ -54,6 +53,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -76,6 +76,18 @@ _EventDictReturn = namedtuple(
)
+def _filter_result(
+ instance_name: str,
+ stream_id: int,
+ from_token: RoomStreamToken,
+ to_token: RoomStreamToken,
+) -> bool:
+ from_id = from_token.instance_map.get(instance_name, from_token.stream)
+ to_id = to_token.instance_map.get(instance_name, to_token.stream)
+
+ return from_id < stream_id <= to_id
+
+
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -209,6 +221,71 @@ def _make_generic_sql_bound(
)
+def _make_instance_filter_clause(
+ direction: str,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
+) -> Tuple[str, List[Any]]:
+ if from_token and from_token.topological:
+ from_token = None
+ if to_token and to_token.topological:
+ to_token = None
+
+ if not from_token and not to_token:
+ return "", []
+
+ from_bound = ">=" if direction == "b" else "<"
+ to_bound = "<" if direction == "b" else ">="
+
+ filter_clauses = []
+ filter_args = [] # type: List[Any]
+
+ from_map = from_token.instance_map if from_token else {}
+ to_map = to_token.instance_map if to_token else {}
+
+ default_from = from_token.stream if from_token else None
+ default_to = to_token.stream if to_token else None
+
+ if default_from and default_to:
+ filter_clauses.append(
+ "(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound)
+ )
+ filter_args.extend((default_from, default_to,))
+ elif default_from:
+ filter_clauses.append("(? %s stream_ordering)" % (from_bound,))
+ filter_args.extend((default_from,))
+ elif default_to:
+ filter_clauses.append("(? %s stream_ordering)" % (to_bound,))
+ filter_args.extend((default_to,))
+
+ for instance in set(from_map).union(to_map):
+ from_id = from_map.get(instance, default_from)
+ to_id = to_map.get(instance, default_to)
+
+ if from_id and to_id:
+ filter_clauses.append(
+ "(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)"
+ % (from_bound, to_bound)
+ )
+ filter_args.extend((instance, from_id, to_id,))
+ elif from_id:
+ filter_clauses.append(
+ "(instance_name = ? AND ? %s stream_ordering)" % (from_bound,)
+ )
+ filter_args.extend((instance, from_id,))
+ elif to_id:
+ filter_clauses.append(
+ "(instance_name = ? AND ? %s stream_ordering)" % (to_bound,)
+ )
+ filter_args.extend((instance, to_id,))
+
+ filter_clause = ""
+ if filter_clauses:
+ filter_clause = "(%s)" % (" OR ".join(filter_clauses),)
+
+ return filter_clause, filter_args
+
+
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -306,7 +383,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.get_room_max_stream_ordering())
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p >= min_pos
+ }
+
+ if set(positions.values()) == {min_pos}:
+ positions = {}
+
+ return RoomStreamToken(None, min_pos, positions)
async def get_room_events_stream_for_rooms(
self,
@@ -405,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ filter_clause, filter_args = _make_instance_filter_clause(
+ "f", from_key, to_key
+ )
+ if filter_clause:
+ filter_clause = " AND " + filter_clause
+
+            min_from_id = from_key.stream
+ max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+ sql = """
+ SELECT event_id, instance_name, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ %s
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ filter_clause,
+ order,
+ )
+ args = [room_id, min_from_id, max_to_id]
+ args.extend(filter_args)
+ args.append(limit)
+ txn.execute(sql, args)
+
+            # Filter out rows where the event's position for its writer
+            # instance falls outside the (from_key, to_key] range.
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, stream_ordering in txn
+                if _filter_result(
+                    instance_name, stream_ordering, from_key, to_key
+                )
+            ]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -432,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)
if order.lower() == "desc":
ret.reverse()
@@ -449,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
+ filter_clause, filter_args = _make_instance_filter_clause(
+ "f", from_key, to_key
+ )
+ if filter_clause:
+ filter_clause = " AND " + filter_clause
+
+            min_from_id = from_key.stream
+ max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+ sql = """
+ SELECT m.event_id, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ %s
+ ORDER BY e.stream_ordering ASC
+ """ % (
+ filter_clause,
)
- txn.execute(sql, (user_id, from_id, to_id))
+ args = [user_id, min_from_id, max_to_id]
+ args.extend(filter_args)
+ txn.execute(sql, args)
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@@ -978,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
+ if from_token.topological is not None:
+ from_bound = from_token.as_tuple()
+ elif direction == "b":
+ from_bound = (
+ None,
+ max(from_token.instance_map.values(), default=from_token.stream),
+ )
+ else:
+ from_bound = (
+ None,
+ min(from_token.instance_map.values(), default=from_token.stream),
+ )
+
+ to_bound = None
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ min(to_token.instance_map.values(), default=to_token.stream),
+ )
+ else:
+ to_bound = (
+ None,
+ max(to_token.instance_map.values(), default=to_token.stream),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -992,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
+ stream_filter_clause, stream_filter_args = _make_instance_filter_clause(
+ direction, from_token, to_token
+ )
+ if stream_filter_clause:
+ bounds += " AND " + stream_filter_clause
+ args.extend(stream_filter_args)
+
args.append(int(limit))
select_keywords = "SELECT"
diff --git a/synapse/types.py b/synapse/types.py
index ec39f9e1e8..5224b473a9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -21,8 +21,9 @@ from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr
+import cbor2
from signedjson.key import decode_verify_key_bytes
-from unpaddedbase64 import decode_base64
+from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
@@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
@@ -392,6 +393,8 @@ class RoomStreamToken:
)
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ instance_map = attr.ib(type=Dict[str, int], factory=dict)
+
@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
try:
@@ -400,6 +403,11 @@ class RoomStreamToken:
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
+ if string[0] == "m":
+ payload = cbor2.loads(decode_base64(string[1:]))
+ return cls(
+ topological=None, stream=payload["s"], instance_map=payload["p"],
+ )
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -413,15 +421,49 @@ class RoomStreamToken:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
+ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+ if self.topological or other.topological:
+ raise Exception("Can't advance topological tokens")
+
+ max_stream = max(self.stream, other.stream)
+
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return RoomStreamToken(None, max_stream, instance_map)
+
def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
def __str__(self) -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
+ elif self.instance_map:
+ return "m" + encode_base64(
+ cbor2.dumps({"s": self.stream, "p": self.instance_map}),
+ )
else:
return "s%d" % (self.stream,)
+ def __lt__(self, other: "RoomStreamToken"):
+ if self.stream != other.stream:
+ return self.stream < other.stream
+
+ for instance in set(self.instance_map).union(other.instance_map):
+ if self.instance_map.get(instance, self.stream) != other.instance_map.get(
+ instance, other.stream
+ ):
+ return self.instance_map.get(
+ instance, self.stream
+ ) < other.instance_map.get(instance, other.stream)
+
+ return False
+
@attr.s(slots=True, frozen=True)
class StreamToken:
@@ -461,7 +503,7 @@ class StreamToken:
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
return (
- (other.room_stream_id < self.room_stream_id)
+ (other.room_key < self.room_key)
or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
@@ -476,13 +518,16 @@ class StreamToken:
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
- new_token = self.copy_and_replace(key, new_value)
if key == "room_key":
- new_id = new_token.room_stream_id
- old_id = self.room_stream_id
- else:
- new_id = int(getattr(new_token, key))
- old_id = int(getattr(self, key))
+ new_token = self.copy_and_replace(
+ "room_key", self.room_key.copy_and_advance(new_value)
+ )
+ return new_token
+
+ new_token = self.copy_and_replace(key, new_value)
+ new_id = int(getattr(new_token, key))
+ old_id = int(getattr(self, key))
+
if old_id < new_id:
return new_token
else:
@@ -507,7 +552,7 @@ class PersistedEventPosition:
stream = attr.ib(type=int)
def persisted_after(self, token: RoomStreamToken) -> bool:
- return token.stream < self.stream
+ return token.instance_map.get(self.instance_name, token.stream) < self.stream
class ThirdPartyInstanceID(
|