summary refs log tree commit diff
diff options
context:
space:
mode:
author: Erik Johnston <erik@matrix.org>, 2020-09-15 13:36:51 +0100
committer: Erik Johnston <erik@matrix.org>, 2020-09-29 14:43:28 +0100
commit: 4499d81adf97c3260872f25745655d2e404241eb (patch)
tree: 80c4addd06f4850210704258bc424bd8994edffb
parent: Reduce usages of RoomStreamToken constructor (diff)
download: synapse-4499d81adf97c3260872f25745655d2e404241eb.tar.xz
Wire up token
-rw-r--r--  synapse/notifier.py                        4
-rw-r--r--  synapse/python_dependencies.py             1
-rw-r--r--  synapse/storage/databases/main/stream.py   227
-rw-r--r--  synapse/types.py                           65
4 files changed, 252 insertions, 45 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 441b3d15e2..59415f6f88 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -163,7 +163,7 @@ class _NotifierUserStream:
         """
         # Immediately wake up stream if something has already since happened
         # since their last token.
-        if self.last_notified_token.is_after(token):
+        if self.last_notified_token != token:
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
@@ -470,7 +470,7 @@ class Notifier:
         async def check_for_updates(
             before_token: StreamToken, after_token: StreamToken
         ) -> EventStreamResult:
-            if not after_token.is_after(before_token):
+            if after_token == before_token:
                 return EventStreamResult([], (from_token, from_token))
 
             events = []  # type: List[EventBase]
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 288631477e..1761fca0ed 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -77,6 +77,7 @@ REQUIREMENTS = [
     "Jinja2>=2.9",
     "bleach>=1.4.3",
     "typing-extensions>=3.7.4",
+    "cbor2",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 255d363b33..72b6420532 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,11 +35,10 @@ what sort order was used:
     - topological tokems: "t%d-%d", where the integers map to the topological
       and stream ordering columns respectively.
 """
-
 import abc
 import logging
 from collections import namedtuple
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
@@ -54,6 +53,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -76,6 +76,18 @@ _EventDictReturn = namedtuple(
 )
 
 
+def _filter_result(
+    instance_name: str,
+    stream_id: int,
+    from_token: RoomStreamToken,
+    to_token: RoomStreamToken,
+) -> bool:
+    from_id = from_token.instance_map.get(instance_name, from_token.stream)
+    to_id = to_token.instance_map.get(instance_name, to_token.stream)
+
+    return from_id < stream_id <= to_id
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -209,6 +221,71 @@ def _make_generic_sql_bound(
     )
 
 
+def _make_instance_filter_clause(
+    direction: str,
+    from_token: Optional[RoomStreamToken],
+    to_token: Optional[RoomStreamToken],
+) -> Tuple[str, List[Any]]:
+    if from_token and from_token.topological:
+        from_token = None
+    if to_token and to_token.topological:
+        to_token = None
+
+    if not from_token and not to_token:
+        return "", []
+
+    from_bound = ">=" if direction == "b" else "<"
+    to_bound = "<" if direction == "b" else ">="
+
+    filter_clauses = []
+    filter_args = []  # type: List[Any]
+
+    from_map = from_token.instance_map if from_token else {}
+    to_map = to_token.instance_map if to_token else {}
+
+    default_from = from_token.stream if from_token else None
+    default_to = to_token.stream if to_token else None
+
+    if default_from and default_to:
+        filter_clauses.append(
+            "(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound)
+        )
+        filter_args.extend((default_from, default_to,))
+    elif default_from:
+        filter_clauses.append("(? %s stream_ordering)" % (from_bound,))
+        filter_args.extend((default_from,))
+    elif default_to:
+        filter_clauses.append("(? %s stream_ordering)" % (to_bound,))
+        filter_args.extend((default_to,))
+
+    for instance in set(from_map).union(to_map):
+        from_id = from_map.get(instance, default_from)
+        to_id = to_map.get(instance, default_to)
+
+        if from_id and to_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)"
+                % (from_bound, to_bound)
+            )
+            filter_args.extend((instance, from_id, to_id,))
+        elif from_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering)" % (from_bound,)
+            )
+            filter_args.extend((instance, from_id,))
+        elif to_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering)" % (to_bound,)
+            )
+            filter_args.extend((instance, to_id,))
+
+    filter_clause = ""
+    if filter_clauses:
+        filter_clause = "(%s)" % (" OR ".join(filter_clauses),)
+
+    return filter_clause, filter_args
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -306,7 +383,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     def get_room_max_token(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p >= min_pos
+            }
+
+            if set(positions.values()) == {min_pos}:
+                positions = {}
+
+        return RoomStreamToken(None, min_pos, positions)
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -405,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            filter_clause, filter_args = _make_instance_filter_clause(
+                "f", from_key, to_key
+            )
+            if filter_clause:
+                filter_clause = " AND " + filter_clause
+
+            min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
+            max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+            sql = """
+                SELECT event_id, instance_name, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                    %s
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                filter_clause,
+                order,
+            )
+            args = [room_id, min_from_id, max_to_id]
+            args.extend(filter_args)
+            args.append(limit)
+            txn.execute(sql, args)
+
+            # rows = [
+            #     _EventDictReturn(event_id, None, stream_ordering)
+            #     for event_id, instance_name, stream_ordering in txn
+            #     if _filter_result(instance_name, stream_ordering, from_key, to_key)
+            # ]
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, stream_ordering in txn
+            ]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -432,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -449,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+            filter_clause, filter_args = _make_instance_filter_clause(
+                "f", from_key, to_key
+            )
+            if filter_clause:
+                filter_clause = " AND " + filter_clause
+
+            min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
+            max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+            sql = """
+                SELECT m.event_id, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                    %s
+                ORDER BY e.stream_ordering ASC
+            """ % (
+                filter_clause,
             )
-            txn.execute(sql, (user_id, from_id, to_id))
+            args = [user_id, min_from_id, max_to_id]
+            args.extend(filter_args)
+            txn.execute(sql, args)
 
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
 
@@ -978,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"
 
+        if from_token.topological is not None:
+            from_bound = from_token.as_tuple()
+        elif direction == "b":
+            from_bound = (
+                None,
+                max(from_token.instance_map.values(), default=from_token.stream),
+            )
+        else:
+            from_bound = (
+                None,
+                min(from_token.instance_map.values(), default=from_token.stream),
+            )
+
+        to_bound = None
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    min(to_token.instance_map.values(), default=to_token.stream),
+                )
+            else:
+                to_bound = (
+                    None,
+                    max(to_token.instance_map.values(), default=to_token.stream),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -992,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
+        stream_filter_clause, stream_filter_args = _make_instance_filter_clause(
+            direction, from_token, to_token
+        )
+        if stream_filter_clause:
+            bounds += " AND " + stream_filter_clause
+            args.extend(stream_filter_args)
+
         args.append(int(limit))
 
         select_keywords = "SELECT"
diff --git a/synapse/types.py b/synapse/types.py
index ec39f9e1e8..5224b473a9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -21,8 +21,9 @@ from collections import namedtuple
 from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
 
 import attr
+import cbor2
 from signedjson.key import decode_verify_key_bytes
-from unpaddedbase64 import decode_base64
+from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.errors import Codes, SynapseError
 
@@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
     return username.decode("ascii")
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
 class RoomStreamToken:
     """Tokens are positions between events. The token "s1" comes after event 1.
 
@@ -392,6 +393,8 @@ class RoomStreamToken:
     )
     stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
 
+    instance_map = attr.ib(type=Dict[str, int], factory=dict)
+
     @classmethod
     def parse(cls, string: str) -> "RoomStreamToken":
         try:
@@ -400,6 +403,11 @@ class RoomStreamToken:
             if string[0] == "t":
                 parts = string[1:].split("-", 1)
                 return cls(topological=int(parts[0]), stream=int(parts[1]))
+            if string[0] == "m":
+                payload = cbor2.loads(decode_base64(string[1:]))
+                return cls(
+                    topological=None, stream=payload["s"], instance_map=payload["p"],
+                )
         except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
@@ -413,15 +421,49 @@ class RoomStreamToken:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
 
+    def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+        if self.topological or other.topological:
+            raise Exception("Can't advance topological tokens")
+
+        max_stream = max(self.stream, other.stream)
+
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return RoomStreamToken(None, max_stream, instance_map)
+
     def as_tuple(self) -> Tuple[Optional[int], int]:
         return (self.topological, self.stream)
 
     def __str__(self) -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
+        elif self.instance_map:
+            return "m" + encode_base64(
+                cbor2.dumps({"s": self.stream, "p": self.instance_map}),
+            )
         else:
             return "s%d" % (self.stream,)
 
+    def __lt__(self, other: "RoomStreamToken"):
+        if self.stream != other.stream:
+            return self.stream < other.stream
+
+        for instance in set(self.instance_map).union(other.instance_map):
+            if self.instance_map.get(instance, self.stream) != other.instance_map.get(
+                instance, other.stream
+            ):
+                return self.instance_map.get(
+                    instance, self.stream
+                ) < other.instance_map.get(instance, other.stream)
+
+        return False
+
 
 @attr.s(slots=True, frozen=True)
 class StreamToken:
@@ -461,7 +503,7 @@ class StreamToken:
     def is_after(self, other):
         """Does this token contain events that the other doesn't?"""
         return (
-            (other.room_stream_id < self.room_stream_id)
+            (other.room_key < self.room_key)
             or (int(other.presence_key) < int(self.presence_key))
             or (int(other.typing_key) < int(self.typing_key))
             or (int(other.receipt_key) < int(self.receipt_key))
@@ -476,13 +518,16 @@ class StreamToken:
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
-        new_token = self.copy_and_replace(key, new_value)
         if key == "room_key":
-            new_id = new_token.room_stream_id
-            old_id = self.room_stream_id
-        else:
-            new_id = int(getattr(new_token, key))
-            old_id = int(getattr(self, key))
+            new_token = self.copy_and_replace(
+                "room_key", self.room_key.copy_and_advance(new_value)
+            )
+            return new_token
+
+        new_token = self.copy_and_replace(key, new_value)
+        new_id = int(getattr(new_token, key))
+        old_id = int(getattr(self, key))
+
         if old_id < new_id:
             return new_token
         else:
@@ -507,7 +552,7 @@ class PersistedEventPosition:
     stream = attr.ib(type=int)
 
     def persisted_after(self, token: RoomStreamToken) -> bool:
-        return token.stream < self.stream
+        return token.instance_map.get(self.instance_name, token.stream) < self.stream
 
 
 class ThirdPartyInstanceID(