diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
""" This package includes all the federation specific logic.
"""
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
- transport = hs.get_federation_transport_client()
-
- return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..c11798093d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import six
from twisted.internet import defer
-from synapse.events.utils import prune_event
-
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
-
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_dict
+from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.hs = hs
+
+ self.server_name = hs.hostname
+ self.keyring = hs.get_keyring()
+ self.spam_checker = hs.get_spam_checker()
+ self.store = hs.get_datastore()
+ self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +61,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
-
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
+
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -114,15 +114,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
- """Throws a SynapseError if a PDU does not have the correct
- signatures.
+ """Checks that each of the received events is correctly signed by the
+ sending server.
+
+ Args:
+ pdus (list[FrozenEvent]): the events to be checked
Returns:
- FrozenEvent: Either the given event or it redacted if it failed the
- content hash check.
+ list[Deferred]: for each input event, a deferred which:
+ * returns the original event if the checks pass
+ * returns a redacted version of the event (if the signature
+ matched but the hash did not)
+ * throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,26 +139,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
+ ctx = logcontext.LoggingContext.current_context()
+
def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
- return pdu
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if self.spam_checker.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
@@ -160,3 +181,40 @@ class FederationBase(object):
)
return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+ """Construct a FrozenEvent from an event json received over federation
+
+ Args:
+ pdu_json (object): pdu as received over federation
+ outlier (bool): True to mark this event as an outlier
+
+ Returns:
+ FrozenEvent
+
+ Raises:
+ SynapseError: if the pdu is missing required fields or is otherwise
+ not a valid matrix event
+ """
+ # we could probably enforce a bunch of other fields here (room_id, sender,
+ # origin, etc etc)
+ assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
+
+ depth = pdu_json['depth']
+ if not isinstance(depth, six.integer_types):
+ raise SynapseError(400, "Depth %r not an intger" % (depth, ),
+ Codes.BAD_JSON)
+
+ if depth < 0:
+ raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+ elif depth > MAX_DEPTH:
+ raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 861441708b..62d7ed13cf 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,36 +14,35 @@
# limitations under the License.
+import copy
+import itertools
+import logging
+import random
+
+from six.moves import range
+
+from prometheus_client import Counter
+
from twisted.internet import defer
-from .federation_base import FederationBase
from synapse.api.constants import Membership
-
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError,
+ CodeMessageException,
+ FederationDeniedError,
+ HttpResponseException,
+ SynapseError,
)
-from synapse.util import unwrapFirstError
+from synapse.events import builder
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
from synapse.util.retryutils import NotRetryingDestination
-import copy
-import itertools
-import logging
-import random
-
-
logger = logging.getLogger(__name__)
-
-# synapse.federation.federation_client is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-
-sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
@@ -58,6 +57,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000,
)
self.state = hs.get_state_handler()
+ self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
@@ -105,7 +105,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc(query_type)
+ sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
@@ -124,7 +124,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_device_keys")
+ sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
@@ -134,7 +134,7 @@ class FederationClient(FederationBase):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
- sent_queries_counter.inc("user_devices")
+ sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@@ -151,7 +151,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_one_time_keys")
+ sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@@ -184,15 +184,15 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
- self.event_from_pdu_json(p, outlier=False)
+ event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+ pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(pdus)
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
+ event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ signed_pdu = yield self._check_sigs_and_hash(pdu)
break
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e.message)
+ continue
except Exception as e:
pdu_attempts[destination] = now
@@ -336,11 +339,11 @@ class FederationClient(FederationBase):
)
pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -388,9 +391,9 @@ class FederationClient(FederationBase):
"""
if return_local:
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = seen_events.values()
+ signed_events = list(seen_events.values())
else:
- seen_events = yield self.store.have_events(event_ids)
+ seen_events = yield self.store.have_seen_events(event_ids)
signed_events = []
failed_to_fetch = set()
@@ -409,18 +412,19 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
- for i in xrange(0, len(missing_events), batch_size):
+ for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- preserve_fn(self.get_pdu)(
+ run_in_background(
+ self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
- res = yield preserve_context_over_deferred(
+ res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
@@ -441,7 +445,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
@@ -570,12 +574,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -585,7 +589,7 @@ class FederationClient(FederationBase):
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus.values(),
+ destination, list(pdus.values()),
outlier=True,
)
@@ -650,7 +654,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict)
- pdu = self.event_from_pdu_json(pdu_dict)
+ pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +744,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -788,7 +792,7 @@ class FederationClient(FederationBase):
)
events = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content.get("events", [])
]
@@ -805,15 +809,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea06..e501251b6e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,92 +13,72 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import re
+import six
+from six import iteritems
-from twisted.internet import defer
-
-from .federation_base import FederationBase
-from .units import Transaction, Edu
-
-from synapse.util.async import Linearizer
-from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
-import synapse.metrics
+from canonicaljson import json
+from prometheus_client import Counter
-from synapse.api.errors import AuthError, FederationError, SynapseError
+from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
+from twisted.python import failure
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
+from synapse.types import get_domain_from_id
+from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logutils import log_function
-import simplejson as json
-import logging
-
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
-# synapse.federation.federation_server is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
-
-received_pdus_counter = metrics.register_counter("received_pdus")
+received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")
-received_edus_counter = metrics.register_counter("received_edus")
+received_edus_counter = Counter("synapse_federation_server_received_edus", "")
-received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+received_queries_counter = Counter(
+ "synapse_federation_server_received_queries", "", ["type"]
+)
class FederationServer(FederationBase):
+
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
+ self.handler = hs.get_handlers().federation_handler
- self._server_linearizer = Linearizer("fed_server")
+ self._server_linearizer = async.Linearizer("fed_server")
+ self._transaction_linearizer = async.Linearizer("fed_txn_handler")
- # We cache responses to state queries, as they take a while and often
- # come in waves.
- self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+ self.transaction_actions = TransactionActions(self.store)
- def set_handler(self, handler):
- """Sets the handler that the replication layer will use to communicate
- receipt of new PDUs from other home servers. The required methods are
- documented on :py:class:`.ReplicationHandler`.
- """
- self.handler = handler
+ self.registry = hs.get_federation_registry()
- def register_edu_handler(self, edu_type, handler):
- if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
- self.edu_handlers[edu_type] = handler
-
- def register_query_handler(self, query_type, handler):
- """Sets the handler callable that will be used to handle an incoming
- federation Query of the given type.
-
- Args:
- query_type (str): Category name of the query, which should match
- the string used by make_query.
- handler (callable): Invoked to handle incoming queries of this type
-
- handler is invoked as:
- result = handler(args)
-
- where 'args' is a dict mapping strings to strings of the query
- arguments. It should return a Deferred that will eventually yield an
- object to encode as JSON.
- """
- if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
-
- self.query_handlers[query_type] = handler
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
+ self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -109,25 +90,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
+ # keep this as early as possible to make the calculated origin ts as
+ # accurate as possible.
+ request_time = self._clock.time_msec()
+
transaction = Transaction(**transaction_data)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ if not transaction.transaction_id:
+ raise Exception("Transaction missing transaction_id")
+ if not transaction.origin:
+ raise Exception("Transaction missing origin")
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
+ # use a linearizer to ensure that we don't process the same transaction
+ # multiple times in parallel.
+ with (yield self._transaction_linearizer.queue(
+ (transaction.origin, transaction.transaction_id),
+ )):
+ result = yield self._handle_incoming_transaction(
+ transaction, request_time,
+ )
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _handle_incoming_transaction(self, transaction, request_time):
+ """ Process an incoming transaction and return the HTTP response
+
+ Args:
+ transaction (Transaction): incoming transaction
+ request_time (int): timestamp that the HTTP request arrived at
+ Returns:
+ Deferred[(int, object)]: http response code and body
+ """
response = yield self.transaction_actions.have_responded(transaction)
if response:
@@ -140,42 +137,67 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- results = []
-
- for pdu in pdu_list:
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if transaction.origin != get_domain_from_id(pdu.event_id):
- # We continue to accept join events from any server; this is
- # necessary for the federation join dance to work correctly.
- # (When we join over federation, the "helper" server is
- # responsible for sending out the join event, rather than the
- # origin. See bug #1893).
- if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) == 'join'
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, transaction.origin
+ received_pdus_counter.inc(len(transaction.pdus))
+
+ origin_host, _ = parse_server_name(transaction.origin)
+
+ pdus_by_room = {}
+
+ for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = request_time - int(p["age"])
+ del p["age"]
+
+ event = event_from_pdu_json(p)
+ room_id = event.room_id
+ pdus_by_room.setdefault(room_id, []).append(event)
+
+ pdu_results = {}
+
+ # we can process different rooms in parallel (which is useful if they
+ # require callouts to other servers to fetch missing events), but
+ # impose a limit to avoid going too crazy with ram/cpu.
+
+ @defer.inlineCallbacks
+ def process_pdus_for_room(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
+ try:
+ yield self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warn(
+ "Ignoring PDUs for room %s from banned server", room_id,
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ try:
+ yield self._handle_received_pdu(
+ transaction.origin, pdu
)
- continue
- else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, transaction.origin
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warn("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ pdu_results[event_id] = {"error": str(e)}
+ logger.error(
+ "Failed to handle PDU %s: %s",
+ event_id, f.getTraceback().rstrip(),
)
- try:
- yield self._handle_received_pdu(transaction.origin, pdu)
- results.append({})
- except FederationError as e:
- self.send_failure(e, transaction.origin)
- results.append({"error": str(e)})
- except Exception as e:
- results.append({"error": str(e)})
- logger.exception("Failed to handle PDU")
+ yield async.concurrently_execute(
+ process_pdus_for_room, pdus_by_room.keys(),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
@@ -185,17 +207,16 @@ class FederationServer(FederationBase):
edu.content
)
- for failure in getattr(transaction, "pdu_failures", []):
- logger.info("Got failure %r", failure)
-
- logger.debug("Returning: %s", str(results))
+ pdu_failures = getattr(transaction, "pdu_failures", [])
+ for fail in pdu_failures:
+ logger.info("Got failure %r", fail)
response = {
- "pdus": dict(zip(
- (p.event_id for p in pdu_list), results
- )),
+ "pdus": pdu_results,
}
+ logger.debug("Returning: %s", str(response))
+
yield self.transaction_actions.set_response(
transaction,
200, response
@@ -205,16 +226,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
-
- if edu_type in self.edu_handlers:
- try:
- yield self.edu_handlers[edu_type](origin, content)
- except SynapseError as e:
- logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception as e:
- logger.exception("Failed to handle edu %r", edu_type)
- else:
- logger.warn("Received EDU of type %s with no handler", edu_type)
+ yield self.registry.on_edu(edu_type, origin, content)
@defer.inlineCallbacks
@log_function
@@ -222,19 +234,24 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- result = self._state_resp_cache.get((room_id, event_id))
- if not result:
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.set(
- (room_id, event_id),
- self._on_context_state_request_compute(room_id, event_id)
- )
- else:
- resp = yield result
+ # we grab the linearizer to protect ourselves from servers which hammer
+ # us. In theory we might already have the response to this query
+ # in the cache so we could return it without waiting for the linearizer
+ # - but that's non-trivial to get right, and anyway somewhat defeats
+ # the point of the linearizer.
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id, event_id,
+ )
defer.returnValue((200, resp))
@@ -243,6 +260,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -286,7 +306,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
+ pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -302,25 +322,23 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
- received_queries_counter.inc(query_type)
-
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue(
- (404, "No handler for Query type '%s'" % (query_type,))
- )
+ received_queries_counter.labels(query_type).inc()
+ resp = yield self.registry.on_query(query_type, args)
+ defer.returnValue((200, resp))
@defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -328,7 +346,11 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -340,7 +362,9 @@ class FederationServer(FederationBase):
}))
@defer.inlineCallbacks
- def on_make_leave_request(self, room_id, user_id):
+ def on_make_leave_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -348,7 +372,11 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -356,6 +384,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {
@@ -384,8 +415,11 @@ class FederationServer(FederationBase):
Deferred: Results in `dict` with the same format as `content`
"""
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -444,9 +478,9 @@ class FederationServer(FederationBase):
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -457,6 +491,9 @@ class FederationServer(FederationBase):
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d",
@@ -485,17 +522,6 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
@@ -513,13 +539,57 @@ class FederationServer(FederationBase):
def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
+ If the event is invalid, then this method throws a FederationError.
+ (The error will then be logged and sent back to the sender (which
+ probably won't do anything with it), and other events in the
+ transaction will be processed as normal).
+
+ It is likely that we'll then receive other events which refer to
+ this rejected_event in their prev_events, etc. When that happens,
+ we'll attempt to fetch the rejected event again, which will presumably
+ fail, so those second-generation events will also get rejected.
+
+ Eventually, we get to the point where there are more than 10 events
+ between any new events and the original rejected event. Since we
+ only try to backfill 10 events deep on received pdu, we then accept the
+ new event, possibly introducing a discontinuity in the DAG, with new
+ forward extremities, so normal service is approximately returned,
+ until we try to backfill across the discontinuity.
+
Args:
origin (str): server which sent the pdu
pdu (FrozenEvent): received pdu
Returns (Deferred): completes with None
- Raises: FederationError if the signatures / hash do not match
- """
+
+ Raises: FederationError if the signatures / hash do not match, or
+ if the event was unacceptable for any other reason (eg, too large,
+ too many prev_events, couldn't find the prev_events)
+ """
+ # check that it's actually being sent from a valid destination to
+ # workaround bug #1753 in 0.18.5 and 0.18.6
+ if origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
+ if not (
+ pdu.type == 'm.room.member' and
+ pdu.content and
+ pdu.content.get("membership", None) == 'join'
+ ):
+ logger.info(
+ "Discarding PDU %s from invalid origin %s",
+ pdu.event_id, origin
+ )
+ return
+ else:
+ logger.info(
+ "Accepting join PDU %s from %s",
+ pdu.event_id, origin
+ )
+
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
@@ -531,20 +601,13 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+ yield self.handler.on_receive_pdu(
+ origin, pdu, get_missing=True, sent_to_us_directly=True,
+ )
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def exchange_third_party_invite(
self,
@@ -567,3 +630,161 @@ class FederationServer(FederationBase):
origin, room_id, event_dict
)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def check_server_matches_acl(self, server_name, room_id):
+ """Check if the given server is allowed by the server ACLs in the room
+
+ Args:
+ server_name (str): name of server, *without any port part*
+ room_id (str): ID of the room to check
+
+ Raises:
+ AuthError if the server does not match the ACL
+ """
+ state_ids = yield self.store.get_current_state_ids(room_id)
+ acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+ if not acl_event_id:
+ return
+
+ acl_event = yield self.store.get_event(acl_event_id)
+ if server_matches_acl_event(server_name, acl_event):
+ return
+
+ raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+ """Check if the given server is allowed by the ACL event
+
+ Args:
+ server_name (str): name of server, without any port part
+ acl_event (EventBase): m.room.server_acl event
+
+ Returns:
+ bool: True if this server is allowed by the ACLs
+ """
+ logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+ # first of all, check if literal IPs are blocked, and if so, whether the
+ # server name is a literal IP
+ allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+ if not isinstance(allow_ip_literals, bool):
+ logger.warn("Ignorning non-bool allow_ip_literals flag")
+ allow_ip_literals = True
+ if not allow_ip_literals:
+ # check for ipv6 literals. These start with '['.
+ if server_name[0] == '[':
+ return False
+
+ # check for ipv4 literals. We can just lift the routine from twisted.
+ if isIPAddress(server_name):
+ return False
+
+ # next, check the deny list
+ deny = acl_event.content.get("deny", [])
+ if not isinstance(deny, (list, tuple)):
+ logger.warn("Ignorning non-list deny ACL %s", deny)
+ deny = []
+ for e in deny:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched deny rule %s", server_name, e)
+ return False
+
+ # then the allow list.
+ allow = acl_event.content.get("allow", [])
+ if not isinstance(allow, (list, tuple)):
+ logger.warn("Ignorning non-list allow ACL %s", allow)
+ allow = []
+ for e in allow:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched allow rule %s", server_name, e)
+ return True
+
+ # everything else should be rejected.
+ # logger.info("%s fell through", server_name)
+ return False
+
+
+def _acl_entry_matches(server_name, acl_entry):
+ if not isinstance(acl_entry, six.string_types):
+ logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ return False
+ regex = _glob_to_regex(acl_entry)
+ return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+ return re.compile(res + "\\Z", re.IGNORECASE)
+
+
+class FederationHandlerRegistry(object):
+ """Allows classes to register themselves as handlers for a given EDU or
+ query type for incoming federation traffic.
+ """
+ def __init__(self):
+ self.edu_handlers = {}
+ self.query_handlers = {}
+
+ def register_edu_handler(self, edu_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation EDU of the given type.
+
+ Args:
+ edu_type (str): The type of the incoming EDU to register handler for
+ handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+ of the given type. The arguments are the origin server name and
+ the EDU contents.
+ """
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+ self.edu_handlers[edu_type] = handler
+
+ def register_query_handler(self, query_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation query of the given type.
+
+ Args:
+ query_type (str): Category name of the query, which should match
+ the string used by make_query.
+ handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+ incoming queries of this type. The return will be yielded
+ on and the result used as the response to the query request.
+ """
+ if query_type in self.query_handlers:
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
+
+ self.query_handlers[query_type] = handler
+
+ @defer.inlineCallbacks
+ def on_edu(self, edu_type, origin, content):
+ handler = self.edu_handlers.get(edu_type)
+ if not handler:
+ logger.warn("No handler registered for EDU type %s", edu_type)
+
+ try:
+ yield handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception as e:
+ logger.exception("Failed to handle edu %r", edu_type)
+
+ def on_query(self, query_type, args):
+ handler = self.query_handlers.get(query_type)
+ if not handler:
+ logger.warn("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+ return handler(args)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 84dc606673..9146215c21 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -19,13 +19,12 @@ package.
These actions are mostly only used by the :py:mod:`.replication` module.
"""
+import logging
+
from twisted.internet import defer
from synapse.util.logutils import log_function
-import logging
-
-
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
- """This layer is responsible for replicating with remote home servers over
- the given transport. I.e., does the sending and receiving of PDUs to
- remote home servers.
-
- The layer communicates with the rest of the server via a registered
- ReplicationHandler.
-
- In more detail, the layer:
- * Receives incoming data and processes it into transactions and pdus.
- * Fetches any PDUs it thinks it might have missed.
- * Keeps the current state for contexts up to date by applying the
- suitable conflict resolution.
- * Sends outgoing pdus wrapped in transactions.
- * Fills out the references to previous pdus/transactions appropriately
- for outgoing data.
- """
-
- def __init__(self, hs, transport_layer):
- self.server_name = hs.hostname
-
- self.keyring = hs.get_keyring()
-
- self.transport_layer = transport_layer
-
- self.federation_client = self
-
- self.store = hs.get_datastore()
-
- self.handler = None
- self.edu_handlers = {}
- self.query_handlers = {}
-
- self._clock = hs.get_clock()
-
- self.transaction_actions = TransactionActions(self.store)
-
- self.hs = hs
-
- super(ReplicationLayer, self).__init__(hs)
-
- def __str__(self):
- return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 93e5acebc1..5157c3860d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -29,23 +29,22 @@ dead worker doesn't cause the queues to grow limitlessly.
Events are replicated via a separate events stream.
"""
-from .units import Edu
+import logging
+from collections import namedtuple
+from six import iteritems, itervalues
+
+from sortedcontainers import SortedDict
+
+from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
-import synapse.metrics
-
-from blist import sorteddict
-from collections import namedtuple
-import logging
+from .units import Edu
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-
class FederationRemoteSendQueue(object):
"""A drop in replacement for TransactionQueue"""
@@ -56,29 +55,27 @@ class FederationRemoteSendQueue(object):
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = sorteddict() # Stream position -> user_id
+ self.presence_changed = SortedDict() # Stream position -> user_id
self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = sorteddict() # stream position -> (destination, key)
+ self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
- self.edus = sorteddict() # stream position -> Edu
+ self.edus = SortedDict() # stream position -> Edu
- self.failures = sorteddict() # stream position -> (destination, Failure)
+ self.failures = SortedDict() # stream position -> (destination, Failure)
- self.device_messages = sorteddict() # stream position -> destination
+ self.device_messages = SortedDict() # stream position -> destination
self.pos = 1
- self.pos_time = sorteddict()
+ self.pos_time = SortedDict()
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- metrics.register_callback(
- queue_name + "_size",
- lambda: len(queue),
- )
+ LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
+ "", [], lambda: len(queue))
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
@@ -101,7 +98,7 @@ class FederationRemoteSendQueue(object):
now = self.clock.time_msec()
keys = self.pos_time.keys()
- time = keys.bisect_left(now - FIVE_MINUTES_AGO)
+ time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
if not keys[:time]:
return
@@ -116,13 +113,13 @@ class FederationRemoteSendQueue(object):
with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps
keys = self.presence_changed.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.presence_changed.bisect_left(position_to_delete)
for key in keys[:i]:
del self.presence_changed[key]
user_ids = set(
user_id
- for uids in self.presence_changed.itervalues()
+ for uids in itervalues(self.presence_changed)
for user_id in uids
)
@@ -134,7 +131,7 @@ class FederationRemoteSendQueue(object):
# Delete things out of keyed edus
keys = self.keyed_edu_changed.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.keyed_edu_changed.bisect_left(position_to_delete)
for key in keys[:i]:
del self.keyed_edu_changed[key]
@@ -148,19 +145,19 @@ class FederationRemoteSendQueue(object):
# Delete things out of edu map
keys = self.edus.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.edus.bisect_left(position_to_delete)
for key in keys[:i]:
del self.edus[key]
# Delete things out of failure map
keys = self.failures.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.failures.bisect_left(position_to_delete)
for key in keys[:i]:
del self.failures[key]
# Delete things out of device map
keys = self.device_messages.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.device_messages.bisect_left(position_to_delete)
for key in keys[:i]:
del self.device_messages[key]
@@ -200,7 +197,7 @@ class FederationRemoteSendQueue(object):
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+ local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states]
@@ -253,13 +250,12 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(federation_ack)
# Fetch changed presence
- keys = self.presence_changed.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
+ i = self.presence_changed.bisect_right(from_token)
+ j = self.presence_changed.bisect_right(to_token) + 1
dest_user_ids = [
(pos, user_id)
- for pos in keys[i:j]
- for user_id in self.presence_changed[pos]
+ for pos, user_id_list in self.presence_changed.items()[i:j]
+ for user_id in user_id_list
]
for (key, user_id) in dest_user_ids:
@@ -268,34 +264,31 @@ class FederationRemoteSendQueue(object):
)))
# Fetch changes keyed edus
- keys = self.keyed_edu_changed.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
+ i = self.keyed_edu_changed.bisect_right(from_token)
+ j = self.keyed_edu_changed.bisect_right(to_token) + 1
# We purposefully clobber based on the key here, python dict comprehensions
# always use the last value, so this will correctly point to the last
# stream position.
- keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+ keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in keyed_edus.iteritems():
+ for ((destination, edu_key), pos) in iteritems(keyed_edus):
rows.append((pos, KeyedEduRow(
key=edu_key,
edu=self.keyed_edu[(destination, edu_key)],
)))
# Fetch changed edus
- keys = self.edus.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- edus = ((k, self.edus[k]) for k in keys[i:j])
+ i = self.edus.bisect_right(from_token)
+ j = self.edus.bisect_right(to_token) + 1
+ edus = self.edus.items()[i:j]
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
# Fetch changed failures
- keys = self.failures.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- failures = ((k, self.failures[k]) for k in keys[i:j])
+ i = self.failures.bisect_right(from_token)
+ j = self.failures.bisect_right(to_token) + 1
+ failures = self.failures.items()[i:j]
for (pos, (destination, failure)) in failures:
rows.append((pos, FailureRow(
@@ -304,12 +297,11 @@ class FederationRemoteSendQueue(object):
)))
# Fetch changed device messages
- keys = self.device_messages.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- device_messages = {self.device_messages[k]: k for k in keys[i:j]}
+ i = self.device_messages.bisect_right(from_token)
+ j = self.device_messages.bisect_right(to_token) + 1
+ device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
- for (destination, pos) in device_messages.iteritems():
+ for (destination, pos) in iteritems(device_messages):
rows.append((pos, DeviceRow(
destination=destination,
)))
@@ -528,19 +520,19 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
- for destination, edu_map in buff.keyed_edus.iteritems():
+ for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=key,
)
- for destination, edu_list in buff.edus.iteritems():
+ for destination, edu_list in iteritems(buff.edus):
for edu in edu_list:
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in buff.failures.iteritems():
+ for destination, failure_list in iteritems(buff.failures):
for failure in failure_list:
transaction_queue.send_failure(destination, failure)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba893..6996d6b695 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -13,34 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
+import logging
-from twisted.internet import defer
+from six import itervalues
-from .persistence import TransactionActions
-from .units import Transaction, Edu
+from prometheus_client import Counter
-from synapse.api.errors import HttpResponseException
-from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
-from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
-from synapse.util.metrics import measure_func
-from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
-import synapse.metrics
+from twisted.internet import defer
-import logging
+import synapse.metrics
+from synapse.api.errors import FederationDeniedError, HttpResponseException
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
+from synapse.metrics import (
+ LaterGauge,
+ events_processed_counter,
+ sent_edus_counter,
+ sent_transactions_counter,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import logcontext
+from synapse.util.metrics import measure_func
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+from .persistence import TransactionActions
+from .units import Edu, Transaction
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = client_metrics.register_distribution(
- "sent_pdu_destinations"
+sent_pdus_destination_dist_count = Counter(
+ "synapse_federation_client_sent_pdu_destinations:count", ""
+)
+sent_pdus_destination_dist_total = Counter(
+ "synapse_federation_client_sent_pdu_destinations:total", ""
)
-sent_edus_counter = client_metrics.register_counter("sent_edus")
-
-sent_transactions_counter = client_metrics.register_counter("sent_transactions")
class TransactionQueue(object):
@@ -67,8 +72,10 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
- metrics.register_callback(
- "pending_destinations",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_destinations",
+ "",
+ [],
lambda: len(self.pending_transactions),
)
@@ -92,12 +99,16 @@ class TransactionQueue(object):
# Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
- metrics.register_callback(
- "pending_pdus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_pdus",
+ "",
+ [],
lambda: sum(map(len, pdus.values())),
)
- metrics.register_callback(
- "pending_edus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_edus",
+ "",
+ [],
lambda: (
sum(map(len, edus.values()))
+ sum(map(len, presence.values()))
@@ -146,7 +157,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- @defer.inlineCallbacks
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
@@ -156,12 +166,20 @@ class TransactionQueue(object):
if self._is_processing:
return
+ # fire off a processing loop in the background
+ run_as_background_process(
+ "process_event_queue_for_federation",
+ self._process_event_queue_loop,
+ )
+
+ @defer.inlineCallbacks
+ def _process_event_queue_loop(self):
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=20,
+ last_token, self._last_poked_id, limit=100,
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -169,24 +187,33 @@ class TransactionQueue(object):
if not events and next_token >= self._last_poked_id:
break
- for event in events:
+ @defer.inlineCallbacks
+ def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
if not is_mine and send_on_behalf_of is None:
- continue
-
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=[
- prev_id for prev_id, _ in event.prev_events
- ],
- )
+ return
+
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = yield self.state.get_current_hosts_in_room(
+ event.room_id, latest_event_ids=[
+ prev_id for prev_id, _ in event.prev_events
+ ],
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
+
destinations = set(destinations)
if send_on_behalf_of is not None:
@@ -199,10 +226,41 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ events_by_room = {}
+ for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True
+ ))
+
yield self.store.update_federation_out_pos(
"events", next_token
)
+ if events:
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_lag.labels(
+ "federation_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "federation_sender").set(ts)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_positions.labels(
+ "federation_sender").set(next_token)
+
finally:
self._is_processing = False
@@ -224,18 +282,17 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc_by(len(destinations))
+ sent_pdus_destination_dist_total.inc(len(destinations))
+ sent_pdus_destination_dist_count.inc()
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
- @preserve_fn # the caller should not yield on this
+ @logcontext.preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
"""Send the new presence states to the appropriate destinations.
@@ -273,7 +330,9 @@ class TransactionQueue(object):
if not states_map:
break
- yield self._process_presence_inner(states_map.values())
+ yield self._process_presence_inner(list(states_map.values()))
+ except Exception:
+ logger.exception("Error sending presence states to servers")
finally:
self._processing_pending_presence = False
@@ -299,7 +358,7 @@ class TransactionQueue(object):
state.user_id: state for state in states
})
- preserve_fn(self._attempt_new_transaction)(destination)
+ self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
@@ -321,9 +380,7 @@ class TransactionQueue(object):
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
@@ -336,9 +393,7 @@ class TransactionQueue(object):
destination, []
).append(failure)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
@@ -347,15 +402,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination):
return
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def get_current_token(self):
return 0
- @defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
+ """Try to start a new transaction to this destination
+
+ If there is already a transaction in progress to this destination,
+ returns immediately. Otherwise kicks off the process of sending a
+ transaction in the background.
+
+ Args:
+ destination (str):
+
+ Returns:
+ None
+ """
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +432,16 @@ class TransactionQueue(object):
)
return
+ logger.debug("TX [%s] Starting transaction loop", destination)
+
+ run_as_background_process(
+ "federation_transaction_transmission_loop",
+ self._transaction_transmission_loop,
+ destination,
+ )
+
+ @defer.inlineCallbacks
+ def _transaction_transmission_loop(self, destination):
pending_pdus = []
try:
self.pending_transactions[destination] = 1
@@ -377,9 +451,6 @@ class TransactionQueue(object):
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
- # XXX: what's this for?
- yield run_on_reactor()
-
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
@@ -464,6 +535,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+ except FederationDeniedError as e:
+ logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 52b2a717d2..4529d454af 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -49,7 +50,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state/%s/" % room_id
+ path = _create_path(PREFIX, "/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -71,7 +72,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state_ids/%s/" % room_id
+ path = _create_path(PREFIX, "/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -93,7 +94,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = PREFIX + "/event/%s/" % (event_id, )
+ path = _create_path(PREFIX, "/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -119,7 +120,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = PREFIX + "/backfill/%s/" % (room_id,)
+ path = _create_path(PREFIX, "/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -157,9 +158,11 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
+ path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
response = yield self.client.put_json(
transaction.destination,
- path=PREFIX + "/send/%s/" % transaction.transaction_id,
+ path=path,
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
@@ -177,7 +180,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
- path = PREFIX + "/query/%s" % query_type
+ path = _create_path(PREFIX, "/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -212,6 +215,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if the remote destination
+ is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
@@ -219,7 +225,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+ path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -245,7 +251,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -258,7 +264,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -277,7 +283,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
- path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -319,7 +325,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+ path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -332,7 +338,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -344,7 +350,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
- path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -406,7 +412,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
- path = PREFIX + "/user/devices/" + user_id
+ path = _create_path(PREFIX, "/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -456,7 +462,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
- path = PREFIX + "/get_missing_events/%s" % (room_id,)
+ path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -471,3 +477,475 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ @log_function
+ def get_group_profile(self, destination, group_id, requester_user_id):
+ """Get a group profile
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_profile(self, destination, group_id, requester_user_id, content):
+ """Update a remote group profile
+
+ Args:
+ destination (str)
+ group_id (str)
+ requester_user_id (str)
+ content (dict): The new profile of the group
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_summary(self, destination, group_id, requester_user_id):
+ """Get a group summary
+ """
+ path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ """Get all rooms in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+ content):
+ """Add a room to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
+ config_key, content):
+ """Update room in group
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/room/%s/config/%s",
+ group_id, room_id, config_key,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ """Remove a room from a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users that have been invited to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def accept_group_invite(self, destination, group_id, user_id, content):
+ """Accept a group invite
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/users/%s/accept_invite",
+ group_id, user_id,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def join_group(self, destination, group_id, user_id, content):
+ """Attempts to join a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ """Invite a user to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group_notification(self, destination, group_id, user_id, content):
+ """Sent by group server to inform a user's server that they have been
+ invited.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group(self, destination, group_id, requester_user_id,
+ user_id, content):
+ """Remove a user fron a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group_notification(self, destination, group_id, user_id,
+ content):
+ """Sent by group server to inform a user's server that they have been
+ kicked from the group.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def renew_group_attestation(self, destination, group_id, user_id, content):
+ """Sent by either a group server or a user's server to periodically update
+ the attestations
+ """
+
+ path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id, content):
+ """Update a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id):
+ """Delete a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_categories(self, destination, group_id, requester_user_id):
+ """Get all categories in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ """Get category info in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_category(self, destination, group_id, requester_user_id, category_id,
+ content):
+ """Update a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_category(self, destination, group_id, requester_user_id,
+ category_id):
+ """Delete a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_roles(self, destination, group_id, requester_user_id):
+ """Get all roles in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Get a roles info
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_role(self, destination, group_id, requester_user_id, role_id,
+ content):
+ """Update a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Delete a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id, content):
+ """Update a users entry in a group
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def set_group_join_policy(self, destination, group_id, requester_user_id,
+ content):
+ """Sets the join policy for a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+ return self.client.put_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id):
+ """Delete a users entry in a group
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def bulk_get_publicised_groups(self, destination, user_ids):
+ """Get the groups a list of users are publicising
+ """
+
+ path = PREFIX + "/get_groups_publicised"
+
+ content = {"user_ids": user_ids}
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+
+def _create_path(prefix, path, *args):
+ """Creates a path from the prefix, path template and args. Ensures that
+ all args are url encoded.
+
+ Example:
+
+ _create_path(PREFIX, "/event/%s/", event_id)
+
+ Args:
+ prefix (str)
+ path (str): String template for the path
+ args: ([str]): Args to insert into path. Each arg will be url encoded
+
+ Returns:
+ str
+ """
+ return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a78f01e442..c9beca27c2 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,25 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import logging
+import re
+
from twisted.internet import defer
+import synapse
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
- parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
+ parse_integer_from_args,
+ parse_json_object_from_request,
+ parse_string_from_args,
)
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import preserve_fn
-from synapse.types import ThirdPartyInstanceID
-
-import functools
-import logging
-import re
-import synapse
-
logger = logging.getLogger(__name__)
@@ -81,6 +84,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -97,26 +101,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -125,11 +109,17 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
+ if (
+ self.federation_domain_whitelist is not None and
+ origin not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(origin)
+
if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -144,11 +134,60 @@ class Authenticator(object):
# alive
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
- logger.info("Marking origin %r as up", origin)
- preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
+ run_in_background(self._reset_retry_timings, origin)
defer.returnValue(origin)
+ @defer.inlineCallbacks
+ def _reset_retry_timings(self, origin):
+ try:
+ logger.info("Marking origin %r as up", origin)
+ yield self.store.set_destination_retry_timings(origin, 0, 0)
+ except Exception:
+ logger.exception("Error resetting retry timings on %s", origin)
+
+
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith(b"\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+
+ # ensure that the origin is a valid server name
+ parse_and_validate_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
@@ -177,7 +216,7 @@ class BaseFederationServlet(object):
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
- except:
+ except Exception:
logger.exception("authenticate_request failed")
raise
@@ -270,7 +309,7 @@ class FederationSendServlet(BaseFederationServlet):
code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
- except:
+ except Exception:
logger.exception("on_incoming_transaction failed")
raise
@@ -347,7 +386,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_join_request(context, user_id)
+ content = yield self.handler.on_make_join_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
@@ -356,7 +397,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(context, user_id)
+ content = yield self.handler.on_make_leave_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
@@ -609,6 +652,549 @@ class FederationVersionServlet(BaseFederationServlet):
}))
+class FederationGroupsProfileServlet(BaseFederationServlet):
+ """Get/set the basic profile of a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_profile(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.update_group_profile(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+ PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_summary(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+ """Get the rooms in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+ """Add/remove room from group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+ """Update room config in group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+ """Get the users in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+ """Get the users that have been invited to a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_invited_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+ """Ask a group server to invite someone to the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+ """Accept an invitation from the group server
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.accept_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsJoinServlet(BaseFederationServlet):
+ """Attempt to join a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.join_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+ """Leave or kick a user from the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+ """A group server has invited a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "group_id doesn't match origin")
+
+ new_content = yield self.handler.on_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+ """A group server has removed a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.user_removed_from_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+ """A group or user's server renews their attestation
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ # We don't need to check auth here as we check the attestation signatures
+
+ new_content = yield self.handler.on_renew_attestation(
+ group_id, user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+ """Add/remove a room from the group summary, with optional category.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+ """Get all categories for a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+ """Add/remove/get a category in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_category(
+ group_id, requester_user_id, category_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.upsert_group_category(
+ group_id, requester_user_id, category_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_category(
+ group_id, requester_user_id, category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+ """Add/remove/get a role in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_role(
+ group_id, requester_user_id, role_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_role(
+ group_id, requester_user_id, role_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_role(
+ group_id, requester_user_id, role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+ """Add/remove a user from the group summary, with optional role.
+
+ Matches both:
+ - /groups/:group/summary/users/:user_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/get_groups_publicised$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ resp = yield self.handler.bulk_get_publicised_groups(
+ content["user_ids"], proxy=False,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+ """Sets whether a group is joinable without an invite or knock
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.set_group_join_policy(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -635,15 +1221,49 @@ FEDERATION_SERVLET_CLASSES = (
FederationVersionServlet,
)
+
ROOM_LIST_CLASSES = (
PublicRoomList,
)
+GROUP_SERVER_SERVLET_CLASSES = (
+ FederationGroupsProfileServlet,
+ FederationGroupsSummaryServlet,
+ FederationGroupsRoomsServlet,
+ FederationGroupsUsersServlet,
+ FederationGroupsInvitedUsersServlet,
+ FederationGroupsInviteServlet,
+ FederationGroupsAcceptInviteServlet,
+ FederationGroupsJoinServlet,
+ FederationGroupsRemoveUserServlet,
+ FederationGroupsSummaryRoomsServlet,
+ FederationGroupsCategoriesServlet,
+ FederationGroupsCategoryServlet,
+ FederationGroupsRolesServlet,
+ FederationGroupsRoleServlet,
+ FederationGroupsSummaryUsersServlet,
+ FederationGroupsAddRoomsServlet,
+ FederationGroupsAddRoomsConfigServlet,
+ FederationGroupsSettingJoinPolicyServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+ FederationGroupsLocalInviteServlet,
+ FederationGroupsRemoveLocalUserServlet,
+ FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+ FederationGroupsRenewAttestaionServlet,
+)
+
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_replication_layer(),
+ handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -656,3 +1276,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
+
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_attestation_renewer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 3f645acc43..bb1b3b13f7 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,10 +17,9 @@
server protocol.
"""
-from synapse.util.jsonobject import JsonEncodedObject
-
import logging
+from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
@@ -74,8 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "transaction_id",
- "destination",
"pdu_failures",
]
|