summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/app/federation_reader.py8
-rw-r--r--synapse/handlers/federation.py44
-rw-r--r--synapse/replication/http/__init__.py3
-rw-r--r--synapse/replication/http/federation.py245
4 files changed, 290 insertions, 10 deletions
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 7af00b8bcf..c512b4be87 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -32,9 +32,13 @@ from synapse.http.site import SynapseSite
 from synapse.metrics import RegistryProxy
 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
@@ -49,6 +53,10 @@ logger = logging.getLogger("synapse.app.federation_reader")
 
 
 class FederationReaderSlavedStore(
+    SlavedApplicationServiceStore,
+    SlavedPusherStore,
+    SlavedPushRuleStore,
+    SlavedReceiptsStore,
     SlavedEventStore,
     SlavedKeyStore,
     RoomStore,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 533b82c783..0524dec942 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -44,6 +44,8 @@ from synapse.crypto.event_signing import (
     compute_event_signature,
 )
 from synapse.events.validator import EventValidator
+from synapse.replication.http.federation import send_federation_events_to_master
+from synapse.replication.http.membership import notify_user_membership_change
 from synapse.state import resolve_events_with_factory
 from synapse.types import UserID, get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
@@ -86,6 +88,8 @@ class FederationHandler(BaseHandler):
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
+        self.config = hs.config
+        self.http_client = hs.get_simple_http_client()
 
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
@@ -2288,7 +2292,7 @@ class FederationHandler(BaseHandler):
                 for revocation.
         """
         try:
-            response = yield self.hs.get_simple_http_client().get_json(
+            response = yield self.http_client.get_json(
                 url,
                 {"public_key": public_key}
             )
@@ -2313,14 +2317,25 @@ class FederationHandler(BaseHandler):
         Returns:
             Deferred
         """
-        max_stream_id = yield self.store.persist_events(
-            event_and_contexts,
-            backfilled=backfilled,
-        )
+        if self.config.worker_app:
+            yield send_federation_events_to_master(
+                clock=self.hs.get_clock(),
+                store=self.store,
+                client=self.http_client,
+                host=self.config.worker_replication_host,
+                port=self.config.worker_replication_http_port,
+                event_and_contexts=event_and_contexts,
+                backfilled=backfilled
+            )
+        else:
+            max_stream_id = yield self.store.persist_events(
+                event_and_contexts,
+                backfilled=backfilled,
+            )
 
-        if not backfilled:  # Never notify for backfilled events
-            for event, _ in event_and_contexts:
-                self._notify_persisted_event(event, max_stream_id)
+            if not backfilled:  # Never notify for backfilled events
+                for event, _ in event_and_contexts:
+                    self._notify_persisted_event(event, max_stream_id)
 
     def _notify_persisted_event(self, event, max_stream_id):
         """Checks to see if notifier/pushers should be notified about the
@@ -2359,9 +2374,20 @@ class FederationHandler(BaseHandler):
         )
 
     def _clean_room_for_join(self, room_id):
+        # TODO move this out to master
         return self.store.clean_room_for_join(room_id)
 
     def user_joined_room(self, user, room_id):
         """Called when a new user has joined the room
         """
-        return user_joined_room(self.distributor, user, room_id)
+        if self.config.worker_app:
+            return notify_user_membership_change(
+                client=self.http_client,
+                host=self.config.worker_replication_host,
+                port=self.config.worker_replication_http_port,
+                room_id=room_id,
+                user_id=user.to_string(),
+                change="joined",
+            )
+        else:
+            return user_joined_room(self.distributor, user, room_id)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 589ee94c66..19f214281e 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.http.server import JsonResource
-from synapse.replication.http import membership, send_event
+from synapse.replication.http import federation, membership, send_event
 
 REPLICATION_PREFIX = "/_synapse/replication"
 
@@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
         membership.register_servlets(hs, self)
+        federation.register_servlets(hs, self)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
new file mode 100644
index 0000000000..f39aaa89be
--- /dev/null
+++ b/synapse/replication/http/federation.py
@@ -0,0 +1,245 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import UserID
+from synapse.util.logcontext import run_in_background
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
+    """Handles events newly received from federation, including persisting and
+    notifying.
+
+    The API looks like:
+
+        POST /_synapse/replication/fed_send_events/:txn_id
+
+        {
+            "events": [{
+                "event": { .. serialized event .. },
+                "internal_metadata": { .. serialized internal_metadata .. },
+                "rejected_reason": ..,   // The event.rejected_reason field
+                "context": { .. serialized event context .. },
+            }],
+            "backfilled": false
+    """
+
+    NAME = "fed_send_events"
+    PATH_ARGS = ()
+
+    def __init__(self, hs):
+        super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.is_mine_id = hs.is_mine_id
+        self.notifier = hs.get_notifier()
+        self.pusher_pool = hs.get_pusherpool()
+
+    @defer.inlineCallbacks
+    @staticmethod
+    def _serialize_payload(store, event_and_contexts, backfilled):
+        """
+        Args:
+            store
+            event_and_contexts (list[tuple[FrozenEvent, EventContext]])
+            backfilled (bool): Whether or not the events are the result of
+                backfilling
+        """
+        event_payloads = []
+        for event, context in event_and_contexts:
+            serialized_context = yield context.serialize(event, store)
+
+            event_payloads.append({
+                "event": event.get_pdu_json(),
+                "internal_metadata": event.internal_metadata.get_dict(),
+                "rejected_reason": event.rejected_reason,
+                "context": serialized_context,
+            })
+
+        payload = {
+            "events": event_payloads,
+            "backfilled": backfilled,
+        }
+
+        defer.returnValue(payload)
+
+    @defer.inlineCallbacks
+    def _handle_request(self, request):
+        with Measure(self.clock, "repl_fed_send_events_parse"):
+            content = parse_json_object_from_request(request)
+
+            backfilled = content["backfilled"]
+
+            event_payloads = content["events"]
+
+            event_and_contexts = []
+            for event_payload in event_payloads:
+                event_dict = event_payload["event"]
+                internal_metadata = event_payload["internal_metadata"]
+                rejected_reason = event_payload["rejected_reason"]
+                event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+                context = yield EventContext.deserialize(
+                    self.store, event_payload["context"],
+                )
+
+                event_and_contexts.append((event, context))
+
+        logger.info(
+            "Got %d events from federation",
+            len(event_and_contexts),
+        )
+
+        max_stream_id = yield self.store.persist_events(
+            event_and_contexts,
+            backfilled=backfilled
+        )
+
+        if not backfilled:
+            for event, _ in event_and_contexts:
+                self._notify_persisted_event(event, max_stream_id)
+
+        defer.returnValue((200, {}))
+
+    def _notify_persisted_event(self, event, max_stream_id):
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+
+            # We notify for memberships if its an invite for one of our
+            # users
+            if event.internal_metadata.is_outlier():
+                if event.membership != Membership.INVITE:
+                    if not self.is_mine_id(target_user_id):
+                        return
+
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
+        elif event.internal_metadata.is_outlier():
+            return
+
+        event_stream_id = event.internal_metadata.stream_ordering
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=extra_users
+        )
+
+        run_in_background(
+            self.pusher_pool.on_new_notifications,
+            event_stream_id, max_stream_id,
+        )
+
+
+class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
+    """Handles EDUs newly received from federation, including persisting and
+    notifying.
+    """
+
+    NAME = "fed_send_edu"
+    PATH_ARGS = ("edu_type",)
+
+    def __init__(self, hs):
+        super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.registry = hs.get_federation_registry()
+
+    @staticmethod
+    def _serialize_payload(edu_type, origin, content):
+        return {
+            "origin": origin,
+            "content": content,
+        }
+
+    @defer.inlineCallbacks
+    def _handle_request(self, request, edu_type):
+        with Measure(self.clock, "repl_fed_send_edu_parse"):
+            content = parse_json_object_from_request(request)
+
+            origin = content["origin"]
+            edu_content = content["content"]
+
+        logger.info(
+            "Got %r edu from $s",
+            edu_type, origin,
+        )
+
+        result = yield self.registry.on_edu(edu_type, origin, edu_content)
+
+        defer.returnValue((200, result))
+
+
+class ReplicationGetQueryRestServlet(ReplicationEndpoint):
+    """Handle responding to queries from federation.
+    """
+
+    NAME = "fed_query"
+    PATH_ARGS = ("query_type",)
+
+    # This is a query, so let's not bother caching
+    CACHE = False
+
+    def __init__(self, hs):
+        super(ReplicationGetQueryRestServlet, self).__init__(hs)
+
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.registry = hs.get_federation_registry()
+
+    @staticmethod
+    def _serialize_payload(query_type, args):
+        """
+        Args:
+            query_type (str)
+            args (dict): The arguments received for the given query type
+        """
+        return {
+            "args": args,
+        }
+
+    @defer.inlineCallbacks
+    def _handle_request(self, request, query_type):
+        with Measure(self.clock, "repl_fed_query_parse"):
+            content = parse_json_object_from_request(request)
+
+            args = content["args"]
+
+        logger.info(
+            "Got %r query",
+            query_type,
+        )
+
+        result = yield self.registry.on_query(query_type, args)
+
+        defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+    ReplicationFederationSendEventsRestServlet(hs).register(http_server)
+    ReplicationFederationSendEduRestServlet(hs).register(http_server)
+    ReplicationGetQueryRestServlet(hs).register(http_server)