summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py141
-rw-r--r--synapse/handlers/message.py14
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/storage/event_federation.py66
-rw-r--r--synapse/storage/events.py15
-rw-r--r--synapse/storage/stream.py138
-rw-r--r--synapse/streams/events.py6
-rw-r--r--synapse/types.py52
8 files changed, 310 insertions, 126 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 85e2757227..1093112587 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,7 +18,7 @@
 from ._base import BaseHandler
 
 from synapse.api.errors import (
-    AuthError, FederationError, StoreError,
+    AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.util.logutils import log_function
@@ -29,6 +29,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID
 
+from synapse.util.retryutils import NotRetryingDestination
+
 from twisted.internet import defer
 
 import itertools
@@ -218,10 +220,11 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit):
+    def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`
         """
-        extremities = yield self.store.get_oldest_events_in_room(room_id)
+        if not extremities:
+            extremities = yield self.store.get_oldest_events_in_room(room_id)
 
         pdus = yield self.replication_layer.backfill(
             dest,
@@ -249,6 +252,138 @@ class FederationHandler(BaseHandler):
         defer.returnValue(events)
 
     @defer.inlineCallbacks
+    def maybe_backfill(self, room_id, current_depth):
+        """Checks the database to see if we should backfill before paginating,
+        and if so do.
+        """
+        extremities = yield self.store.get_oldest_events_with_depth_in_room(
+            room_id
+        )
+
+        if not extremities:
+            logger.debug("Not backfilling as no extremeties found.")
+            return
+
+        # Check if we reached a point where we should start backfilling.
+        sorted_extremeties_tuple = sorted(
+            extremities.items(),
+            key=lambda e: -int(e[1])
+        )
+        max_depth = sorted_extremeties_tuple[0][1]
+
+        if current_depth > max_depth:
+            logger.debug(
+                "Not backfilling as we don't need to. %d < %d",
+                max_depth, current_depth,
+            )
+            return
+
+        # Now we need to decide which hosts to hit first.
+
+        # First we try hosts that are already in the room
+        # TODO: HEURISTIC ALERT.
+
+        curr_state = yield self.state_handler.get_current_state(room_id)
+
+        def get_domains_from_state(state):
+            joined_users = [
+                (state_key, int(event.depth))
+                for (e_type, state_key), event in state.items()
+                if e_type == EventTypes.Member
+                and event.membership == Membership.JOIN
+            ]
+
+            joined_domains = {}
+            for u, d in joined_users:
+                try:
+                    dom = UserID.from_string(u).domain
+                    old_d = joined_domains.get(dom)
+                    if old_d:
+                        joined_domains[dom] = min(d, old_d)
+                    else:
+                        joined_domains[dom] = d
+                except:
+                    pass
+
+            return sorted(joined_domains.items(), key=lambda d: d[1])
+
+        curr_domains = get_domains_from_state(curr_state)
+
+        likely_domains = [
+            domain for domain, depth in curr_domains
+        ]
+
+        @defer.inlineCallbacks
+        def try_backfill(domains):
+            # TODO: Should we try multiple of these at a time?
+            for dom in domains:
+                try:
+                    events = yield self.backfill(
+                        dom, room_id,
+                        limit=100,
+                        extremities=[e for e in extremities.keys()]
+                    )
+                except SynapseError:
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except CodeMessageException as e:
+                    if 400 <= e.code < 500:
+                        raise
+
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(e.message)
+                    continue
+                except Exception as e:
+                    logger.warn(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+
+                if events:
+                    defer.returnValue(True)
+            defer.returnValue(False)
+
+        success = yield try_backfill(likely_domains)
+        if success:
+            defer.returnValue(True)
+
+        # Huh, well *those* domains didn't work out. Lets try some domains
+        # from the time.
+
+        tried_domains = set(likely_domains)
+
+        event_ids = list(extremities.keys())
+
+        states = yield defer.gatherResults([
+            self.state_handler.resolve_state_groups([e])
+            for e in event_ids
+        ])
+        states = dict(zip(event_ids, [s[1] for s in states]))
+
+        for e_id, _ in sorted_extremeties_tuple:
+            likely_domains = get_domains_from_state(states[e_id])
+
+            success = yield try_backfill([
+                dom for dom in likely_domains
+                if dom not in tried_domains
+            ])
+            if success:
+                defer.returnValue(True)
+
+            tried_domains.update(likely_domains)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
     def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 22e19af17f..1809a44a99 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,7 +21,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import UserID
+from synapse.types import UserID, RoomStreamToken
 
 from ._base import BaseHandler
 
@@ -89,9 +89,19 @@ class MessageHandler(BaseHandler):
 
         if not pagin_config.from_token:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token()
+                yield self.hs.get_event_sources().get_current_token(
+                    direction='b'
+                )
             )
 
+        room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
+        if room_token.topological is None:
+            raise SynapseError(400, "Invalid token")
+
+        yield self.hs.get_handlers().federation_handler.maybe_backfill(
+            room_id, room_token.topological
+        )
+
         user = UserID.from_string(user_id)
 
         events, next_key = yield data_source.get_pagination_rows(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cfa2e38ed2..29b6d52757 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -577,8 +577,8 @@ class RoomEventSource(object):
 
         defer.returnValue((events, end_key))
 
-    def get_current_key(self):
-        return self.store.get_room_events_max_id()
+    def get_current_key(self, direction='f'):
+        return self.store.get_room_events_max_id(direction)
 
     @defer.inlineCallbacks
     def get_pagination_rows(self, user, config, key):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 74b4e23590..a1982dfbb5 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
+    def get_oldest_events_with_depth_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_with_depth_in_room",
+            self.get_oldest_events_with_depth_in_room_txn,
+            room_id,
+        )
+
+    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+        sql = (
+            "SELECT b.event_id, MAX(e.depth) FROM events as e"
+            " INNER JOIN event_edges as g"
+            " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+            " INNER JOIN event_backward_extremities as b"
+            " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+            " WHERE b.room_id = ? AND g.is_state is ?"
+            " GROUP BY b.event_id"
+        )
+
+        txn.execute(sql, (room_id, False,))
+
+        return dict(txn.fetchall())
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
@@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
         do_insert = depth < min_depth if min_depth else True
 
         if do_insert:
-            self._simple_insert_txn(
+            self._simple_upsert_txn(
                 txn,
                 table="room_depth",
-                values={
+                keyvalues={
                     "room_id": room_id,
+                },
+                values={
                     "min_depth": depth,
                 },
             )
@@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):
 
                 txn.execute(query, (event_id, room_id))
 
-            # Insert all the prev_events as a backwards thing, they'll get
-            # deleted in a second if they're incorrect anyway.
-            self._simple_insert_many_txn(
-                txn,
-                table="event_backward_extremities",
-                values=[
-                    {
-                        "event_id": e_id,
-                        "room_id": room_id,
-                    }
-                    for e_id, _ in prev_events
-                ],
+            query = (
+                "INSERT INTO event_backward_extremities (event_id, room_id)"
+                " SELECT ?, ? WHERE NOT EXISTS ("
+                " SELECT 1 FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
+                " )"
+                " AND NOT EXISTS ("
+                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+                " )"
             )
 
-            # Also delete from the backwards extremities table all ones that
-            # reference events that we have already seen
+            txn.executemany(query, [
+                (e_id, room_id, e_id, room_id, e_id, room_id, )
+                for e_id, _ in prev_events
+            ])
+
             query = (
-                "DELETE FROM event_backward_extremities WHERE EXISTS ("
-                "SELECT 1 FROM events "
-                "WHERE "
-                "event_backward_extremities.event_id = events.event_id "
-                "AND not events.outlier "
-                ")"
+                "DELETE FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
             )
-            txn.execute(query)
+            txn.execute(query, (event_id, room_id))
 
             txn.call_after(
                 self.get_latest_event_ids_in_room.invalidate, room_id
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 626a5eaf6e..a5a6869079 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()
 
         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
             self._update_min_depth_for_room_txn(
                 txn,
                 event.room_id,
                 event.depth
             )
 
-        have_persisted = self._simple_select_one_onecol_txn(
+        have_persisted = self._simple_select_one_txn(
             txn,
-            table="event_json",
+            table="events",
             keyvalues={"event_id": event.event_id},
-            retcol="event_id",
+            retcols=["event_id", "outlier"],
             allow_none=True,
         )
 
@@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
         # if we are persisting an event that we had persisted as an outlier,
         # but is no longer one.
         if have_persisted:
-            if not outlier:
+            if not outlier and have_persisted["outlier"]:
+                self._store_state_groups_txn(txn, event, context)
+
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
                     " WHERE event_id = ?"
@@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
                 )
             return
 
+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
         self._handle_prev_events(
             txn,
             outlier=outlier,
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 280d4ad605..8045e17fd7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,11 +37,9 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
 
-from collections import namedtuple
-
 import logging
 
 
@@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-    __slots__ = []
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == 't':
-                parts = string[1:].split('-', 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+def lower_bound(token):
+    if token.topological is None:
+        return "(%d < %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d < %s OR (%d = %s AND %d < %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
-    def lower_bound(self):
-        if self.topological is None:
-            return "(%d < %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d < %s OR (%d = %s AND %d < %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
 
-    def upper_bound(self):
-        if self.topological is None:
-            return "(%d >= %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
+def upper_bound(token):
+    if token.topological is None:
+        return "(%d >= %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
 
 class StreamStore(SQLBaseStore):
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             return defer.succeed(([], to_key))
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = _StreamToken.parse(from_key).upper_bound()
+            bounds = upper_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).lower_bound()
+                    bounds, lower_bound(RoomStreamToken.parse(to_key))
                 )
         else:
             order = "ASC"
-            bounds = _StreamToken.parse(from_key).lower_bound()
+            bounds = lower_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).upper_bound()
+                    bounds, upper_bound(RoomStreamToken.parse(to_key))
                 )
 
         if int(limit) > 0:
@@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
                     # when we are going backwards so we subtract one from the
                     # stream part.
                     toke -= 1
-                next_token = str(_StreamToken(topo, toke))
+                next_token = str(RoomStreamToken(topo, toke))
             else:
                 # TODO (erikj): We should work out what to do here instead.
                 next_token = to_key if to_key else from_key
@@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
                                    with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback
 
-        end_token = _StreamToken.parse_stream_token(end_token)
+        end_token = RoomStreamToken.parse_stream_token(end_token)
 
         if from_token is None:
             sql = (
@@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
                 " LIMIT ?"
             )
         else:
-            from_token = _StreamToken.parse_stream_token(from_token)
+            from_token = RoomStreamToken.parse_stream_token(from_token)
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
                 # stream part.
                 topo = rows[0]["topological_ordering"]
                 toke = rows[0]["stream_ordering"] - 1
-                start_token = str(_StreamToken(topo, toke))
+                start_token = str(RoomStreamToken(topo, toke))
 
                 token = (start_token, str(end_token))
             else:
@@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_room_events_max_id(self):
+    def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token(self)
-        defer.returnValue("s%d" % (token,))
+        if direction != 'b':
+            defer.returnValue("s%d" % (token,))
+        else:
+            topo = yield self.runInteraction(
+                "_get_max_topological_txn", self._get_max_topological_txn
+            )
+            defer.returnValue("t%d-%d" % (topo, token))
+
+    def _get_max_topological_txn(self, txn):
+        txn.execute(
+            "SELECT MAX(topological_ordering) FROM events"
+            " WHERE outlier = ?",
+            (False,)
+        )
+
+        rows = txn.fetchall()
+        return rows[0][0] if rows else 0
 
     @defer.inlineCallbacks
     def _get_min_token(self):
@@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
             stream = row["stream_ordering"]
             topo = event.depth
             internal = event.internal_metadata
-            internal.before = str(_StreamToken(topo, stream - 1))
-            internal.after = str(_StreamToken(topo, stream))
+            internal.before = str(RoomStreamToken(topo, stream - 1))
+            internal.after = str(RoomStreamToken(topo, stream))
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 5c8e54b78b..dff7970bea 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -31,7 +31,7 @@ class NullSource(object):
     def get_new_events_for_user(self, user, from_key, limit):
         return defer.succeed(([], from_key))
 
-    def get_current_key(self):
+    def get_current_key(self, direction='f'):
         return defer.succeed(0)
 
     def get_pagination_rows(self, user, pagination_config, key):
@@ -52,10 +52,10 @@ class EventSources(object):
         }
 
     @defer.inlineCallbacks
-    def get_current_token(self):
+    def get_current_token(self, direction='f'):
         token = StreamToken(
             room_key=(
-                yield self.sources["room"].get_current_key()
+                yield self.sources["room"].get_current_key(direction)
             ),
             presence_key=(
                 yield self.sources["presence"].get_current_key()
diff --git a/synapse/types.py b/synapse/types.py
index f6a1b0bbcf..0f16867d75 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -121,4 +121,56 @@ class StreamToken(
         return StreamToken(**d)
 
 
+class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
+    """Tokens are positions between events. The token "s1" comes after event 1.
+
+            s0    s1
+            |     |
+        [0] V [1] V [2]
+
+    Tokens can either be a point in the live event stream or a cursor going
+    through historic events.
+
+    When traversing the live event stream events are ordered by when they
+    arrived at the homeserver.
+
+    When traversing historic events the events are ordered by their depth in
+    the event graph "topological_ordering" and then by when they arrived at the
+    homeserver "stream_ordering".
+
+    Live tokens start with an "s" followed by the "stream_ordering" id of the
+    event it comes after. Historic tokens start with a "t" followed by the
+    "topological_ordering" id of the event it comes after, follewed by "-",
+    followed by the "stream_ordering" id of the event it comes after.
+    """
+    __slots__ = []
+
+    @classmethod
+    def parse(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+            if string[0] == 't':
+                parts = string[1:].split('-', 1)
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    @classmethod
+    def parse_stream_token(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    def __str__(self):
+        if self.topological is not None:
+            return "t%d-%d" % (self.topological, self.stream)
+        else:
+            return "s%d" % (self.stream,)
+
+
 ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))