summary refs log tree commit diff
path: root/synapse/replication/tcp/streams
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams')
-rw-r--r--synapse/replication/tcp/streams/__init__.py1
-rw-r--r--synapse/replication/tcp/streams/_base.py21
-rw-r--r--synapse/replication/tcp/streams/events.py130
3 files changed, 118 insertions, 34 deletions
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5c715e3bfa..634f636dc9 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -44,7 +44,6 @@ STREAMS_MAP = {
         federation.FederationStream,
         _base.TagAccountDataStream,
         _base.AccountDataStream,
-        _base.CurrentStateDeltaStream,
         _base.GroupServerStream,
     )
 }
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 13ab1bee05..8971a6a22e 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -91,12 +91,6 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
     "data_type",  # str
     "data",  # dict
 ))
-CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
-    "room_id",  # str
-    "type",  # str
-    "state_key",  # str
-    "event_id",  # str, optional
-))
 GroupsStreamRow = namedtuple("GroupsStreamRow", (
     "group_id",  # str
     "user_id",  # str
@@ -428,21 +422,6 @@ class AccountDataStream(Stream):
         defer.returnValue(results)
 
 
-class CurrentStateDeltaStream(Stream):
-    """Current state for a room was changed
-    """
-    NAME = "current_state_deltas"
-    ROW_TYPE = CurrentStateDeltaStreamRow
-
-    def __init__(self, hs):
-        store = hs.get_datastore()
-
-        self.current_token = store.get_max_current_state_delta_stream_id
-        self.update_function = store.get_all_updated_current_state_deltas
-
-        super(CurrentStateDeltaStream, self).__init__(hs)
-
-
 class GroupServerStream(Stream):
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 511dd6bcc7..e0f6e29248 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,28 +13,134 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections import namedtuple
+import heapq
+
+import attr
+
+from twisted.internet import defer
 
 from ._base import Stream
 
-EventStreamRow = namedtuple("EventStreamRow", (
-    "event_id",  # str
-    "room_id",  # str
-    "type",  # str
-    "state_key",  # str, optional
-    "redacts",  # str, optional
-))
+
+"""Handling of the 'events' replication stream
+
+This stream contains rows of various types. Each row therefore contains a 'type'
+identifier before the real data. For example::
+
+    RDATA events batch ["state", ["!room:id", "m.type", "", "$event:id"]]
+    RDATA events 12345 ["ev", ["$event:id", "!room:id", "m.type", null, null]]
+
+An "ev" row is sent for each new event. The fields in the data part are:
+
+ * The new event id
+ * The room id for the event
+ * The type of the new event
+ * The state key of the event, for state events
+ * The event id of an event which is redacted by this event.
+
+A "state" row is sent whenever the "current state" in a room changes. The fields in the
+data part are:
+
+ * The room id for the state change
+ * The event type of the state which has changed
+ * The state_key of the state which has changed
+ * The event id of the new state
+
+"""
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamRow(object):
+    """A parsed row from the events replication stream"""
+    type = attr.ib()  # str: the TypeId of one of the *EventsStreamRows
+    data = attr.ib()  # BaseEventsStreamRow
+
+
+class BaseEventsStreamRow(object):
+    """Base class for rows to be sent in the events stream.
+
+    Specifies how to identify, serialize and deserialize the different types.
+    """
+
+    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+
+    @classmethod
+    def from_data(cls, data):
+        """Parse the data from the replication stream into a row.
+
+        By default we just call the constructor with the data list as arguments
+
+        Args:
+            data: The value of the data object from the replication stream
+        """
+        return cls(*data)
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamEventRow(BaseEventsStreamRow):
+    TypeId = "ev"
+
+    event_id = attr.ib()   # str
+    room_id = attr.ib()    # str
+    type = attr.ib()       # str
+    state_key = attr.ib()  # str, optional
+    redacts = attr.ib()    # str, optional
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamCurrentStateRow(BaseEventsStreamRow):
+    TypeId = "state"
+
+    room_id = attr.ib()    # str
+    type = attr.ib()       # str
+    state_key = attr.ib()  # str
+    event_id = attr.ib()   # str, optional
+
+
+TypeToRow = {
+    Row.TypeId: Row
+    for Row in (
+        EventsStreamEventRow,
+        EventsStreamCurrentStateRow,
+    )
+}
 
 
 class EventsStream(Stream):
     """We received a new event, or an event went from being an outlier to not
     """
     NAME = "events"
-    ROW_TYPE = EventStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
-        self.current_token = store.get_current_events_token
-        self.update_function = store.get_all_new_forward_event_rows
+        self._store = hs.get_datastore()
+        self.current_token = self._store.get_current_events_token
 
         super(EventsStream, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def update_function(self, from_token, current_token, limit=None):
+        event_rows = yield self._store.get_all_new_forward_event_rows(
+            from_token, current_token, limit,
+        )
+        event_updates = (
+            (row[0], EventsStreamEventRow.TypeId, row[1:])
+            for row in event_rows
+        )
+
+        state_rows = yield self._store.get_all_updated_current_state_deltas(
+            from_token, current_token, limit
+        )
+        state_updates = (
+            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
+            for row in state_rows
+        )
+
+        all_updates = heapq.merge(event_updates, state_updates)
+
+        defer.returnValue(all_updates)
+
+    @classmethod
+    def parse_row(cls, row):
+        (typ, data) = row
+        data = TypeToRow[typ].from_data(data)
+        return EventsStreamRow(typ, data)