summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py102
-rw-r--r--synapse/storage/monthly_active_users.py5
-rw-r--r--synapse/storage/state.py30
-rw-r--r--synapse/storage/stream.py16
-rw-r--r--synapse/storage/transactions.py23
5 files changed, 132 insertions, 44 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e7487311ce..03cedf3a75 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util import batch_iter
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                             )
 
                         for room_id, ev_ctx_rm in iteritems(events_by_room):
-                            # Work out new extremities by recursively adding and removing
-                            # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
                                 room_id
                             )
-                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                            new_latest_event_ids = yield self._calculate_new_extremities(
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
@@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                                 # No change in extremities, so no change in state
                                 continue
 
+                            # there should always be at least one forward extremity.
+                            # (except during the initial persistence of the send_join
+                            # results, in which case there will be no existing
+                            # extremities, so we'll `continue` above and skip this bit.)
+                            assert new_latest_event_ids, "No forward extremities left!"
+
                             new_forward_extremeties[room_id] = new_latest_event_ids
 
                             len_1 = (
@@ -517,44 +522,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                     )
 
     @defer.inlineCallbacks
-    def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
-        """Calculates the new forward extremeties for a room given events to
+    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+        """Calculates the new forward extremities for a room given events to
         persist.
 
         Assumes that we are only persisting events for one room at a time.
         """
-        new_latest_event_ids = set(latest_event_ids)
-        # First, add all the new events to the list
-        new_latest_event_ids.update(
-            event.event_id for event, ctx in event_contexts
+
+        # we're only interested in new events which aren't outliers and which aren't
+        # being rejected.
+        new_events = [
+            event for event, ctx in event_contexts
             if not event.internal_metadata.is_outlier() and not ctx.rejected
+        ]
+
+        # start with the existing forward extremities
+        result = set(latest_event_ids)
+
+        # add all the new events to the list
+        result.update(
+            event.event_id for event in new_events
         )
-        # Now remove all events that are referenced by the to-be-added events
-        new_latest_event_ids.difference_update(
+
+        # Now remove all events which are prev_events of any of the new events
+        result.difference_update(
             e_id
-            for event, ctx in event_contexts
+            for event in new_events
             for e_id, _ in event.prev_events
-            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
 
-        # And finally remove any events that are referenced by previously added
-        # events.
-        rows = yield self._simple_select_many_batch(
-            table="event_edges",
-            column="prev_event_id",
-            iterable=list(new_latest_event_ids),
-            retcols=["prev_event_id"],
-            keyvalues={
-                "is_state": False,
-            },
-            desc="_calculate_new_extremeties",
-        )
+        # Finally, remove any events which are prev_events of any existing events.
+        existing_prevs = yield self._get_events_which_are_prevs(result)
+        result.difference_update(existing_prevs)
 
-        new_latest_event_ids.difference_update(
-            row["prev_event_id"] for row in rows
-        )
+        defer.returnValue(result)
 
-        defer.returnValue(new_latest_event_ids)
+    @defer.inlineCallbacks
+    def _get_events_which_are_prevs(self, event_ids):
+        """Filter the supplied list of event_ids to get those which are prev_events of
+        existing (non-outlier/rejected) events.
+
+        Args:
+            event_ids (Iterable[str]): event ids to filter
+
+        Returns:
+            Deferred[List[str]]: filtered event ids
+        """
+        results = []
+
+        def _get_events(txn, batch):
+            sql = """
+            SELECT prev_event_id
+            FROM event_edges
+                INNER JOIN events USING (event_id)
+                LEFT JOIN rejections USING (event_id)
+            WHERE
+                prev_event_id IN (%s)
+                AND NOT events.outlier
+                AND rejections.event_id IS NULL
+            """ % (
+                ",".join("?" for _ in batch),
+            )
+
+            txn.execute(sql, batch)
+            results.extend(r[0] for r in txn)
+
+        for chunk in batch_iter(event_ids, 100):
+            yield self.runInteraction(
+                "_get_events_which_are_prevs",
+                _get_events,
+                chunk,
+            )
+
+        defer.returnValue(results)
 
     @defer.inlineCallbacks
     def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
@@ -586,10 +626,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             the new current state is only returned if we've already calculated
             it.
         """
-
-        if not new_latest_event_ids:
-            return
-
         # map from state_group to ((type, key) -> event_id) state map
         state_groups_map = {}
 
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 59580949f1..0fe8c8e24c 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -172,6 +172,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             Deferred[bool]: True if a new entry was created, False if an
                 existing one was updated.
         """
+        # Am consciously deciding to lock the table on the basis that is ought
+        # never be a big table and alternative approaches (batching multiple
+        # upserts into a single txn) introduced a lot of extra complexity.
+        # See https://github.com/matrix-org/synapse/issues/3854 for more
         is_insert = yield self._simple_upsert(
             desc="upsert_monthly_active_user",
             table="monthly_active_users",
@@ -181,7 +185,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             values={
                 "timestamp": int(self._clock.time_msec()),
             },
-            lock=False,
         )
         if is_insert:
             self.user_last_seen_monthly_active.invalidate((user_id,))
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 4b971efdba..3f4cbd61c4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_state_groups_ids(self, room_id, event_ids):
+    def get_state_groups_ids(self, _room_id, event_ids):
+        """Get the event IDs of all the state for the state groups for the given events
+
+        Args:
+            _room_id (str): id of the room for these events
+            event_ids (iterable[str]): ids of the events
+
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
         if not event_ids:
             defer.returnValue({})
 
@@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_state_ids_for_group(self, state_group):
-        """Get the state IDs for the given state group
+        """Get the event IDs of all the state in the given state group
 
         Args:
             state_group (int)
@@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
-        The return value is a dict mapping group names to lists of events.
+        Returns:
+            Deferred[dict[int, list[EventBase]]]:
+                dict of state_group_id -> list of state events.
         """
         if not event_ids:
             defer.returnValue({})
@@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 member events (if True), or to exclude member events (if False)
 
         Returns:
-            dictionary state_group -> (dict of (type, state_key) -> event id)
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         results = {}
 
@@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 If None, `types` filtering is applied to all events.
 
         Returns:
-            Deferred[dict[int, dict[(type, state_key), EventBase]]]
-                a dictionary mapping from state group to state dictionary.
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if types is not None:
             non_member_types = [t for t in types if t[0] != EventTypes.Member]
@@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 If None, `types` filtering is applied to all events.
 
         Returns:
-            Deferred[dict[int, dict[(type, state_key), EventBase]]]
-                a dictionary mapping from state group to state dictionary.
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if types:
             types = frozenset(types)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 4c296d72c0..d6cfdba519 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -630,7 +630,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_all_new_events_stream(self, from_id, current_id, limit):
-        """Get all new events"""
+        """Get all new events
+
+         Returns all events with from_id < stream_ordering <= current_id.
+
+         Args:
+             from_id (int):  the stream_ordering of the last event we processed
+             current_id (int):  the stream_ordering of the most recently processed event
+             limit (int): the maximum number of events to return
+
+         Returns:
+             Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
+             `next_id` is the next value to pass as `from_id` (it will either be the
+             stream_ordering of the last returned event, or, if fewer than `limit` events
+             were found, `current_id`.
+         """
 
         def get_all_new_events_stream_txn(txn):
             sql = (
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index baf0379a68..a3032cdce9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches.expiringcache import ExpiringCache
 
 from ._base import SQLBaseStore, db_to_json
 
@@ -49,6 +50,8 @@ _UpdateTransactionRow = namedtuple(
     )
 )
 
+SENTINEL = object()
+
 
 class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
@@ -59,6 +62,12 @@ class TransactionStore(SQLBaseStore):
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
+        self._destination_retry_cache = ExpiringCache(
+            cache_name="get_destination_retry_timings",
+            clock=self._clock,
+            expiry_ms=5 * 60 * 1000,
+        )
+
     def get_received_txn_response(self, transaction_id, origin):
         """For an incoming transaction from a given origin, check if we have
         already responded to it. If so, return the response code and response
@@ -155,6 +164,7 @@ class TransactionStore(SQLBaseStore):
         """
         pass
 
+    @defer.inlineCallbacks
     def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.
 
@@ -165,10 +175,20 @@ class TransactionStore(SQLBaseStore):
             None if not retrying
             Otherwise a dict for the retry scheme
         """
-        return self.runInteraction(
+
+        result = self._destination_retry_cache.get(destination, SENTINEL)
+        if result is not SENTINEL:
+            defer.returnValue(result)
+
+        result = yield self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)
 
+        # We don't hugely care about race conditions between getting and
+        # invalidating the cache, since we time out fairly quickly anyway.
+        self._destination_retry_cache[destination] = result
+        defer.returnValue(result)
+
     def _get_destination_retry_timings(self, txn, destination):
         result = self._simple_select_one_txn(
             txn,
@@ -196,6 +216,7 @@ class TransactionStore(SQLBaseStore):
             retry_interval (int) - how long until next retry in ms
         """
 
+        self._destination_retry_cache.pop(destination, None)
         return self.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,