summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py116
-rw-r--r--synapse/logging/opentracing.py158
-rw-r--r--synapse/storage/databases/main/events_worker.py60
-rw-r--r--synapse/storage/databases/main/state.py5
-rw-r--r--synapse/storage/state.py9
-rw-r--r--synapse/visibility.py4
6 files changed, 247 insertions, 105 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d827c03ad1..3ca01391c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    FrozenSet,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -89,7 +99,7 @@ class SyncConfig:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
     prev_batch: StreamToken
-    events: List[EventBase]
+    events: Sequence[EventBase]
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
@@ -852,16 +862,26 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """Works out the difference in state between the start of the timeline
-        and the previous sync.
+        """Works out the difference in state between the end of the previous sync and
+        the start of the timeline.
 
         Args:
             room_id:
             batch: The timeline batch for the room that will be sent to the user.
             sync_config:
-            since_token: Token of the end of the previous batch. May be None.
+            since_token: Token of the end of the previous batch. May be `None`.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+                `lazy_load_members` still applies when `full_state` is `True`.
+
+        Returns:
+            The state to return in the sync response for the room.
+
+            Clients will overlay this onto the state at the end of the previous sync to
+            arrive at the state at the start of the timeline.
+
+            Clients will then overlay state events in the timeline to arrive at the
+            state at the end of the timeline, in preparation for the next sync.
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -869,7 +889,8 @@ class SyncHandler:
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
-
+            # The memberships needed for events in the timeline.
+            # Only calculated when `lazy_load_members` is on.
             members_to_fetch = None
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
@@ -897,38 +918,46 @@ class SyncHandler:
             else:
                 state_filter = StateFilter.all()
 
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
             timeline_state = {
                 (event.type, event.state_key): event.event_id
                 for event in batch.events
                 if event.is_state()
             }
 
+            # Now calculate the state to return in the sync response for the room.
+            # This is more or less the change in state between the end of the previous
+            # sync's timeline and the start of the current sync's timeline.
+            # See the docstring above for details.
+            state_ids: StateMap[str]
+
             if full_state:
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
 
-                    state_ids = (
+                    state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[0].event_id, state_filter=state_filter
                         )
                     )
 
                 else:
-                    current_state_ids = await self.get_state_at(
+                    state_at_timeline_end = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
-                    state_ids = current_state_ids
+                    state_at_timeline_start = state_at_timeline_end
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
-                    previous={},
-                    current=current_state_ids,
+                    timeline_start=state_at_timeline_start,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end={},
                     lazy_load_members=lazy_load_members,
                 )
             elif batch.limited:
@@ -968,24 +997,23 @@ class SyncHandler:
                 )
 
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
                 else:
-                    # Its not clear how we get here, but empirically we do
-                    # (#5407). Logging has been added elsewhere to try and
-                    # figure out where this state comes from.
-                    current_state_ids = await self.get_state_at(
+                    # We can get here if the user has ignored the senders of all
+                    # the recent events.
+                    state_at_timeline_end = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
-                    previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end=state_at_previous_sync,
                     # we have to include LL members in case LL initial sync missed them
                     lazy_load_members=lazy_load_members,
                 )
@@ -1010,6 +1038,13 @@ class SyncHandler:
                             ),
                         )
 
+            # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+            # the memberships of all event senders in the timeline. This is because we
+            # may not have sent the memberships in a previous sync.
+
+            # When `include_redundant_members` is on, we send all the lazy-loaded
+            # memberships of event senders. Otherwise we make an effort to limit the set
+            # of memberships we send to those that we have not already sent to this client.
             if lazy_load_members and not include_redundant_members:
                 cache_key = (sync_config.user.to_string(), sync_config.device_id)
                 cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -2216,8 +2251,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
 def _calculate_state(
     timeline_contains: StateMap[str],
     timeline_start: StateMap[str],
-    previous: StateMap[str],
-    current: StateMap[str],
+    timeline_end: StateMap[str],
+    previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
 ) -> StateMap[str]:
     """Works out what state to include in a sync response.
@@ -2225,45 +2260,50 @@ def _calculate_state(
     Args:
         timeline_contains: state in the timeline
         timeline_start: state at the start of the timeline
-        previous: state at the end of the previous sync (or empty dict
+        timeline_end: state at the end of the timeline
+        previous_timeline_end: state at the end of the previous sync (or empty dict
             if this is an initial sync)
-        current: state at the end of the timeline
         lazy_load_members: whether to return members from timeline_start
             or not.  assumes that timeline_start has already been filtered to
             include only the members the client needs to know about.
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
+    event_id_to_state_key = {
+        event_id: state_key
+        for state_key, event_id in itertools.chain(
             timeline_contains.items(),
-            previous.items(),
             timeline_start.items(),
-            current.items(),
+            timeline_end.items(),
+            previous_timeline_end.items(),
         )
     }
 
-    c_ids = set(current.values())
-    ts_ids = set(timeline_start.values())
-    p_ids = set(previous.values())
-    tc_ids = set(timeline_contains.values())
+    timeline_end_ids = set(timeline_end.values())
+    timeline_start_ids = set(timeline_start.values())
+    previous_timeline_end_ids = set(previous_timeline_end.values())
+    timeline_contains_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
     # as we may not have sent them to the client before.  We find these membership
     # events by filtering them out of timeline_start, which has already been filtered
     # to only include membership events for the senders in the timeline.
-    # In practice, we can do this by removing them from the p_ids list,
-    # which is the list of relevant state we know we have already sent to the client.
+    # In practice, we can do this by removing them from the previous_timeline_end_ids
+    # list, which is the list of relevant state we know we have already sent to the
+    # client.
     # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
-        p_ids.difference_update(
+        previous_timeline_end_ids.difference_update(
             e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (
+        (timeline_end_ids | timeline_start_ids)
+        - previous_timeline_end_ids
+        - timeline_contains_ids
+    )
 
-    return {event_id_to_key[e]: e for e in state_ids}
+    return {event_id_to_state_key[e]: e for e in state_ids}
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fa3f76c27f..d1fa2cf8ae 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import (
     Any,
     Callable,
     Collection,
+    ContextManager,
     Dict,
     Generator,
     Iterable,
@@ -823,75 +824,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def _custom_sync_async_decorator(
+    func: Callable[P, R],
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function with a custom opname.
-
-    See the module's doc string for usage examples.
+    Decorates a function that is sync or async (coroutines), or that returns a Twisted
+    `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+    Example usage:
+    ```py
+    # Decorator to time the function and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+            start_ts = time.time()
+            try:
+                yield
+            finally:
+                end_ts = time.time()
+                duration = end_ts - start_ts
+                logger.info("%s took %s seconds", func.__name__, duration)
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+    ```
 
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func: Callable[P, R]) -> Callable[P, R]:
-        if opentracing is None:
-            return func  # type: ignore[unreachable]
+    if inspect.iscoroutinefunction(func):
 
-        if inspect.iscoroutinefunction(func):
+        @wraps(func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)  # type: ignore[misc]
 
-            @wraps(func)
-            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                with start_active_span(opname):
-                    return await func(*args, **kwargs)  # type: ignore[misc]
+    else:
+        # The other case here handles both sync functions and those
+        # decorated with inlineDeferred.
+        @wraps(func)
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
 
-        else:
-            # The other case here handles both sync functions and those
-            # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                scope = start_active_span(opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
+
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
+
+                    result.addCallbacks(call_back, err_back)
+
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
+
+                    scope.__exit__(None, None, None)
 
-                    return result
+                return result
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
 
-        return _trace_inner  # type: ignore[return-value]
+    return _wrapper  # type: ignore[return-value]
 
-    return decorator
+
+def trace_with_opname(
+    opname: str,
+    *,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+    See the module's doc string for usage examples.
+    """
+
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
+        with start_active_span(opname, tracer=tracer):
+            yield
+
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+        if not opentracing:
+            return func
+
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+
+    return _decorator
 
 
 def trace(func: Callable[P, R]) -> Callable[P, R]:
     """
     Decorator to trace a function.
-
     Sets the operation name to that of the function's name.
-
     See the module's doc string for usage examples.
     """
 
@@ -900,7 +943,7 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
 
     Args:
         func: `func` is assumed to be a method taking a `self` parameter, or a
@@ -911,22 +954,25 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     if not opentracing:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
         # We use `[1:]` to skip the `self` object reference and `start=1` to
         # make the index line up with `argspec.args`.
         #
-        # FIXME: We could update this handle any type of function by ignoring the
+        # FIXME: We could update this to handle any type of function by ignoring the
         #   first argument only if it's named `self` or `cls`. This isn't fool-proof
         #   but handles the idiomatic cases.
         for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
             set_tag("ARG_" + argspec.args[i], str(arg))
         set_tag("args", str(args[len(argspec.args) :]))  # type: ignore[index]
         set_tag("kwargs", str(kwargs))
-        return func(*args, **kwargs)
+        yield
 
-    return _tag_args_inner
+    return _custom_sync_async_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e9ff6cfb34..b07d812ae2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2200,3 +2200,63 @@ class EventsWorkerStore(SQLBaseStore):
             (room_id,),
         )
         return [row[0] for row in txn]
+
+    def mark_event_rejected_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        rejection_reason: Optional[str],
+    ) -> None:
+        """Mark an event that was previously accepted as rejected, or vice versa
+
+        This can happen, for example, when resyncing state during a faster join.
+
+        Args:
+            txn:
+            event_id: ID of event to update
+            rejection_reason: reason it has been rejected, or None if it is now accepted
+        """
+        if rejection_reason is None:
+            logger.info(
+                "Marking previously-processed event %s as accepted",
+                event_id,
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                "rejections",
+                keyvalues={"event_id": event_id},
+            )
+        else:
+            logger.info(
+                "Marking previously-processed event %s as rejected(%s)",
+                event_id,
+                rejection_reason,
+            )
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": event_id},
+                values={
+                    "reason": rejection_reason,
+                    "last_check": self._clock.time_msec(),
+                },
+            )
+        self.db_pool.simple_update_txn(
+            txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            updatevalues={"rejection_reason": rejection_reason},
+        )
+
+        self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+        # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+        #   call '_send_invalidation_to_replication', but we actually need the other
+        #   end to call _invalidate_local_get_event_cache() rather than (just)
+        #   _get_event_cache.invalidate().
+        #
+        #   One solution might be to (somehow) get the workers to call
+        #   _invalidate_caches_for_event() (though that will invalidate more than
+        #   strictly necessary).
+        #
+        #   https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index f70705a0af..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -430,6 +430,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             updatevalues={"state_group": state_group},
         )
 
+        # the event may now be rejected where it was not before, or vice versa,
+        # in which case we need to update the rejected flags.
+        if bool(context.rejected) != (event.rejected_reason is not None):
+            self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
         self.db_pool.simple_delete_one_txn(
             txn,
             table="partial_state_events",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index af3bab2c15..0004d955b4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,15 +539,6 @@ class StateFilter:
             is_mine_id: a callable which confirms if a given state_key matches a mxid
                of a local user
         """
-
-        # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
-        #  there may be circumstances in which we return a piece of state that, once we
-        #  resync the state, we discover is invalid. For example: if it turns out that
-        #  the sender of a piece of state wasn't actually in the room, then clearly that
-        #  state shouldn't have been returned.
-        #  We should at least add some tests around this to see what happens.
-        #  https://github.com/matrix-org/synapse/issues/13006
-
         # if we haven't requested membership events, then it depends on the value of
         # 'include_others'
         if EventTypes.Member not in self.types:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index d947edde66..c810a05907 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -73,8 +73,8 @@ async def filter_events_for_client(
           * the user is not currently a member of the room, and:
           * the user has not been a member of the room since the given
             events
-        always_include_ids: set of event ids to specifically
-            include (unless sender is ignored)
+        always_include_ids: set of event ids to specifically include, if present
+            in events (unless sender is ignored)
         filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.