summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-11-12 15:44:22 +0000
committerErik Johnston <erik@matrix.org>2018-11-12 15:44:22 +0000
commitdfa830e61aab21aee0edc7b9ffa0c94becf9cdf1 (patch)
tree4f1f5edb28e5b4c36bd430ca9bd1235812dacb54
parentRemove hack to support rejoining rooms (diff)
downloadsynapse-dfa830e61aab21aee0edc7b9ffa0c94becf9cdf1.tar.xz
Store and fetch thread IDs
-rw-r--r--synapse/events/snapshot.py9
-rw-r--r--synapse/handlers/federation.py32
-rw-r--r--synapse/state/__init__.py5
-rw-r--r--synapse/storage/events.py3
-rw-r--r--synapse/storage/events_worker.py11
5 files changed, 46 insertions, 14 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 368b5f6ae4..0c77c3c44b 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -74,6 +74,7 @@ class EventContext(object):
         "delta_ids",
         "prev_state_events",
         "app_service",
+        "thread_id",
         "_current_state_ids",
         "_prev_state_ids",
         "_prev_state_id",
@@ -89,8 +90,9 @@ class EventContext(object):
 
     @staticmethod
     def with_state(state_group, current_state_ids, prev_state_ids,
-                   prev_group=None, delta_ids=None):
+                   thread_id, prev_group=None, delta_ids=None):
         context = EventContext()
+        context.thread_id = thread_id
 
         # The current state including the current event
         context._current_state_ids = current_state_ids
@@ -141,7 +143,8 @@ class EventContext(object):
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
             "prev_state_events": self.prev_state_events,
-            "app_service_id": self.app_service.id if self.app_service else None
+            "app_service_id": self.app_service.id if self.app_service else None,
+            "thread_id": self.thread_id,
         })
 
     @staticmethod
@@ -158,6 +161,8 @@ class EventContext(object):
         """
         context = EventContext()
 
+        context.thread_id = input["thread_input"]
+
         # We use the state_group and prev_state_id stuff to pull the
         # current_state_ids out of the DB and construct prev_state_ids.
         context._prev_state_id = input["prev_state_id"]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a3bb864bb2..f6fec0afdd 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,6 +18,7 @@
 
 import itertools
 import logging
+import random
 
 import six
 from six import iteritems, itervalues
@@ -135,7 +136,7 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def on_receive_pdu(
-            self, origin, pdu, sent_to_us_directly=False,
+            self, origin, pdu, sent_to_us_directly=False, thread_id=None,
     ):
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
@@ -222,6 +223,10 @@ class FederationHandler(BaseHandler):
         state = None
         auth_chain = []
 
+        if thread_id is None:
+            # FIXME: Pick something better?
+            thread_id = random.randint(0, 999999999)
+
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
@@ -259,7 +264,8 @@ class FederationHandler(BaseHandler):
                         )
 
                         yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
+                            origin, pdu, prevs, min_depth,
+                            thread_id=thread_id,
                         )
 
                         # Update the set of things we've seen after trying to
@@ -414,15 +420,24 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
+        now = self.clock.time_msec()
+        if now - pdu.origin_server_ts > 2 * 60 * 1000:
+            pass
+        else:
+            thread_id = 0
+
+        logger.info("Thread ID %r", thread_id)
+
         yield self._process_received_pdu(
             origin,
             pdu,
             state=state,
             auth_chain=auth_chain,
+            thread_id=thread_id,
         )
 
     @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth, thread_id):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -529,6 +544,7 @@ class FederationHandler(BaseHandler):
                         origin,
                         ev,
                         sent_to_us_directly=False,
+                        thread_id=thread_id,
                     )
                 except FederationError as e:
                     if e.code == 403:
@@ -540,7 +556,7 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
+    def _process_received_pdu(self, origin, event, state, auth_chain, thread_id):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
         """
@@ -592,6 +608,7 @@ class FederationHandler(BaseHandler):
                 origin,
                 event,
                 state=state,
+                thread_id=thread_id,
             )
         except AuthError as e:
             raise FederationError(
@@ -1557,11 +1574,12 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _handle_new_event(self, origin, event, state=None, auth_events=None,
-                          backfilled=False):
+                          backfilled=False, thread_id=0):
         context = yield self._prep_event(
             origin, event,
             state=state,
             auth_events=auth_events,
+            thread_id=thread_id,
         )
 
         # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
@@ -1720,7 +1738,7 @@ class FederationHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def _prep_event(self, origin, event, state=None, auth_events=None):
+    def _prep_event(self, origin, event, state=None, auth_events=None, thread_id=0):
         """
 
         Args:
@@ -1733,7 +1751,7 @@ class FederationHandler(BaseHandler):
             Deferred, which resolves to synapse.events.snapshot.EventContext
         """
         context = yield self.state_handler.compute_event_context(
-            event, old_state=state,
+            event, old_state=state, thread_id=thread_id,
         )
 
         if not auth_events:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 70048b0c09..35041028fe 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -178,7 +178,7 @@ class StateHandler(object):
         defer.returnValue(joined_hosts)
 
     @defer.inlineCallbacks
-    def compute_event_context(self, event, old_state=None):
+    def compute_event_context(self, event, old_state=None, thread_id=0):
         """Build an EventContext structure for the event.
 
         This works out what the current state should be for the event, and
@@ -215,6 +215,7 @@ class StateHandler(object):
             # We don't store state for outliers, so we don't generate a state
             # group for it.
             context = EventContext.with_state(
+                thread_id=0,  # outlier, don't care
                 state_group=None,
                 current_state_ids=current_state_ids,
                 prev_state_ids=prev_state_ids,
@@ -251,6 +252,7 @@ class StateHandler(object):
             )
 
             context = EventContext.with_state(
+                thread_id=thread_id,
                 state_group=state_group,
                 current_state_ids=current_state_ids,
                 prev_state_ids=prev_state_ids,
@@ -319,6 +321,7 @@ class StateHandler(object):
             state_group = entry.state_group
 
         context = EventContext.with_state(
+            thread_id=thread_id,
             state_group=state_group,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2047110b1d..855f859115 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1282,8 +1282,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                         "url" in event.content
                         and isinstance(event.content["url"], text_type)
                     ),
+                    "thread_id": ctx.thread_id,
                 }
-                for event, _ in events_and_contexts
+                for event, ctx in events_and_contexts
             ],
         )
 
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index a8326f5296..5b9cad5522 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -352,7 +352,7 @@ class EventsWorkerStore(SQLBaseStore):
                 run_in_background(
                     self._get_event_from_row,
                     row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
+                    rejected_reason=row["rejects"], thread_id=row["thread_id"],
                 )
                 for row in rows
             ],
@@ -378,8 +378,10 @@ class EventsWorkerStore(SQLBaseStore):
                 " e.internal_metadata,"
                 " e.json,"
                 " r.redacts as redacts,"
-                " rej.event_id as rejects "
+                " rej.event_id as rejects, "
+                " ev.thread_id as thread_id"
                 " FROM event_json as e"
+                " INNER JOIN events as ev USING (event_id)"
                 " LEFT JOIN rejections as rej USING (event_id)"
                 " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                 " WHERE e.event_id IN (%s)"
@@ -392,10 +394,11 @@ class EventsWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
+                            thread_id, rejected_reason=None):
         with Measure(self._clock, "_get_event_from_row"):
             d = json.loads(js)
             internal_metadata = json.loads(internal_metadata)
+            internal_metadata["thread_id"] = thread_id
 
             if rejected_reason:
                 rejected_reason = yield self._simple_select_one_onecol(
@@ -411,6 +414,8 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected_reason=rejected_reason,
             )
 
+            original_ev.unsigned["thread_id"] = thread_id
+
             redacted_event = None
             if redacted:
                 redacted_event = prune_event(original_ev)