diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 4cc98a3fe8..c11798093d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,14 +16,15 @@ import logging
import six
+from twisted.internet import defer
+
from synapse.api.constants import MAX_DEPTH
-from synapse.api.errors import SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
-from synapse.http.servlet import assert_params_in_request
-from synapse.util import unwrapFirstError, logcontext
-from twisted.internet import defer
+from synapse.http.servlet import assert_params_in_dict
+from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -198,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+ assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 87a92f6ea9..c9f3c2d352 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,25 +21,25 @@ import random
from six.moves import range
+from prometheus_client import Counter
+
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
+ CodeMessageException,
+ FederationDeniedError,
+ HttpResponseException,
+ SynapseError,
)
from synapse.events import builder
-from synapse.federation.federation_base import (
- FederationBase,
- event_from_pdu_json,
-)
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
-from prometheus_client import Counter
-
logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000
+class InvalidResponseError(RuntimeError):
+ """Helper for _try_destination_list: indicates that the server returned a response
+ we couldn't parse
+ """
+ pass
+
+
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
@@ -458,8 +465,63 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth)
@defer.inlineCallbacks
+ def _try_destination_list(self, description, destinations, callback):
+ """Try an operation on a series of servers, until it succeeds
+
+ Args:
+ description (unicode): description of the operation we're doing, for logging
+
+ destinations (Iterable[unicode]): list of server_names to try
+
+ callback (callable): Function to run for each server. Passed a single
+ argument: the server_name to try. May return a deferred.
+
+ If the callback raises a CodeMessageException with a 300/400 code,
+ attempts to perform the operation stop immediately and the exception is
+ reraised.
+
+ Otherwise, if the callback raises an Exception the error is logged and the
+ next server tried. Normally the stacktrace is logged but this is
+ suppressed if the exception is an InvalidResponseError.
+
+ Returns:
+ The [Deferred] result of callback, if it succeeds
+
+ Raises:
+ SynapseError if the chosen remote server returns a 300/400 code.
+
+ RuntimeError if no servers were reachable.
+ """
+ for destination in destinations:
+ if destination == self.server_name:
+ continue
+
+ try:
+ res = yield callback(destination)
+ defer.returnValue(res)
+ except InvalidResponseError as e:
+ logger.warn(
+ "Failed to %s via %s: %s",
+ description, destination, e,
+ )
+ except HttpResponseException as e:
+ if not 500 <= e.code < 600:
+ raise e.to_synapse_error()
+ else:
+ logger.warn(
+ "Failed to %s via %s: %i %s",
+ description, destination, e.code, e.message,
+ )
+ except Exception:
+ logger.warn(
+ "Failed to %s via %s",
+ description, destination, exc_info=1,
+ )
+
+ raise RuntimeError("Failed to %s via any server" % (description, ))
+
def make_membership_event(self, destinations, room_id, user_id, membership,
- content={},):
+ content, params):
"""
Creates an m.room.member event, with context, without participating in the room.
@@ -475,13 +537,15 @@ class FederationClient(FederationBase):
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
- content (object): Any additional data to put into the content field
+ content (dict): Any additional data to put into the content field
of the event.
+ params (dict[str, str|Iterable[str]]): Query parameters to include in the
+ request.
Return:
Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event.
- Fails with a ``CodeMessageException`` if the chosen remote server
+ Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
@@ -492,50 +556,37 @@ class FederationClient(FederationBase):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- for destination in destinations:
- if destination == self.server_name:
- continue
- try:
- ret = yield self.transport_layer.make_membership_event(
- destination, room_id, user_id, membership
- )
+ @defer.inlineCallbacks
+ def send_request(destination):
+ ret = yield self.transport_layer.make_membership_event(
+ destination, room_id, user_id, membership, params,
+ )
- pdu_dict = ret["event"]
+ pdu_dict = ret.get("event", None)
+ if not isinstance(pdu_dict, dict):
+ raise InvalidResponseError("Bad 'event' field in response")
- logger.debug("Got response to make_%s: %s", membership, pdu_dict)
+ logger.debug("Got response to make_%s: %s", membership, pdu_dict)
- pdu_dict["content"].update(content)
+ pdu_dict["content"].update(content)
- # The protoevent received over the JSON wire may not have all
- # the required fields. Lets just gloss over that because
- # there's some we never care about
- if "prev_state" not in pdu_dict:
- pdu_dict["prev_state"] = []
+ # The protoevent received over the JSON wire may not have all
+ # the required fields. Lets just gloss over that because
+ # there's some we never care about
+ if "prev_state" not in pdu_dict:
+ pdu_dict["prev_state"] = []
- ev = builder.EventBuilder(pdu_dict)
+ ev = builder.EventBuilder(pdu_dict)
- defer.returnValue(
- (destination, ev)
- )
- break
- except CodeMessageException as e:
- if not 500 <= e.code < 600:
- raise
- else:
- logger.warn(
- "Failed to make_%s via %s: %s",
- membership, destination, e.message
- )
- except Exception as e:
- logger.warn(
- "Failed to make_%s via %s: %s",
- membership, destination, e.message
- )
+ defer.returnValue(
+ (destination, ev)
+ )
- raise RuntimeError("Failed to send to any server.")
+ return self._try_destination_list(
+ "make_" + membership, destinations, send_request,
+ )
- @defer.inlineCallbacks
def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers.
@@ -552,103 +603,111 @@ class FederationClient(FederationBase):
giving the serer the event was sent to, ``state`` (?) and
``auth_chain``.
- Fails with a ``CodeMessageException`` if the chosen remote server
+ Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
"""
- for destination in destinations:
- if destination == self.server_name:
- continue
-
- try:
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
+ def check_authchain_validity(signed_auth_chain):
+ for e in signed_auth_chain:
+ if e.type == EventTypes.Create:
+ create_event = e
+ break
+ else:
+ raise InvalidResponseError(
+ "no %s in auth chain" % (EventTypes.Create,),
)
- logger.debug("Got content: %s", content)
+ # the room version should be sane.
+ room_version = create_event.content.get("room_version", "1")
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ # This shouldn't be possible, because the remote server should have
+ # rejected the join attempt during make_join.
+ raise InvalidResponseError(
+ "room appears to have unsupported version %s" % (
+ room_version,
+ ))
+
+ @defer.inlineCallbacks
+ def send_request(destination):
+ time_now = self._clock.time_msec()
+ _, content = yield self.transport_layer.send_join(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
- state = [
- event_from_pdu_json(p, outlier=True)
- for p in content.get("state", [])
- ]
+ logger.debug("Got content: %s", content)
- auth_chain = [
- event_from_pdu_json(p, outlier=True)
- for p in content.get("auth_chain", [])
- ]
+ state = [
+ event_from_pdu_json(p, outlier=True)
+ for p in content.get("state", [])
+ ]
- pdus = {
- p.event_id: p
- for p in itertools.chain(state, auth_chain)
- }
+ auth_chain = [
+ event_from_pdu_json(p, outlier=True)
+ for p in content.get("auth_chain", [])
+ ]
- valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, list(pdus.values()),
- outlier=True,
- )
+ pdus = {
+ p.event_id: p
+ for p in itertools.chain(state, auth_chain)
+ }
- valid_pdus_map = {
- p.event_id: p
- for p in valid_pdus
- }
-
- # NB: We *need* to copy to ensure that we don't have multiple
- # references being passed on, as that causes... issues.
- signed_state = [
- copy.copy(valid_pdus_map[p.event_id])
- for p in state
- if p.event_id in valid_pdus_map
- ]
+ valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+ destination, list(pdus.values()),
+ outlier=True,
+ )
- signed_auth = [
- valid_pdus_map[p.event_id]
- for p in auth_chain
- if p.event_id in valid_pdus_map
- ]
+ valid_pdus_map = {
+ p.event_id: p
+ for p in valid_pdus
+ }
- # NB: We *need* to copy to ensure that we don't have multiple
- # references being passed on, as that causes... issues.
- for s in signed_state:
- s.internal_metadata = copy.deepcopy(s.internal_metadata)
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ signed_state = [
+ copy.copy(valid_pdus_map[p.event_id])
+ for p in state
+ if p.event_id in valid_pdus_map
+ ]
- auth_chain.sort(key=lambda e: e.depth)
+ signed_auth = [
+ valid_pdus_map[p.event_id]
+ for p in auth_chain
+ if p.event_id in valid_pdus_map
+ ]
- defer.returnValue({
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- })
- except CodeMessageException as e:
- if not 500 <= e.code < 600:
- raise
- else:
- logger.exception(
- "Failed to send_join via %s: %s",
- destination, e.message
- )
- except Exception as e:
- logger.exception(
- "Failed to send_join via %s: %s",
- destination, e.message
- )
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ for s in signed_state:
+ s.internal_metadata = copy.deepcopy(s.internal_metadata)
- raise RuntimeError("Failed to send to any server.")
+ check_authchain_validity(signed_auth)
+
+ defer.returnValue({
+ "state": signed_state,
+ "auth_chain": signed_auth,
+ "origin": destination,
+ })
+ return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
- code, content = yield self.transport_layer.send_invite(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ try:
+ code, content = yield self.transport_layer.send_invite(
+ destination=destination,
+ room_id=room_id,
+ event_id=event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+ except HttpResponseException as e:
+ if e.code == 403:
+ raise e.to_synapse_error()
+ raise
pdu_dict = content["event"]
@@ -663,7 +722,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu)
- @defer.inlineCallbacks
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@@ -680,35 +738,25 @@ class FederationClient(FederationBase):
Return:
Deferred: resolves to None.
- Fails with a ``CodeMessageException`` if the chosen remote server
- returns a non-200 code.
+ Fails with a ``SynapseError`` if the chosen remote server
+ returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
"""
- for destination in destinations:
- if destination == self.server_name:
- continue
-
- try:
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_leave(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ @defer.inlineCallbacks
+ def send_request(destination):
+ time_now = self._clock.time_msec()
+ _, content = yield self.transport_layer.send_leave(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
- logger.debug("Got content: %s", content)
- defer.returnValue(None)
- except CodeMessageException:
- raise
- except Exception as e:
- logger.exception(
- "Failed to send_leave via %s: %s",
- destination, e.message
- )
+ logger.debug("Got content: %s", content)
+ defer.returnValue(None)
- raise RuntimeError("Failed to send to any server.")
+ return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2d420a58a2..3e0cd294a1 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,28 +14,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
-import simplejson as json
-from twisted.internet import defer
+import six
+from six import iteritems
-from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
-from synapse.crypto.event_signing import compute_event_signature
-from synapse.federation.federation_base import (
- FederationBase,
- event_from_pdu_json,
-)
+from canonicaljson import json
+from prometheus_client import Counter
+from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
+from twisted.python import failure
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import (
+ AuthError,
+ FederationError,
+ IncompatibleRoomVersionError,
+ NotFoundError,
+ SynapseError,
+)
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
+from synapse.replication.http.federation import (
+ ReplicationFederationSendEduRestServlet,
+ ReplicationGetQueryRestServlet,
+)
from synapse.types import get_domain_from_id
-from synapse.util import async
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logutils import log_function
-from prometheus_client import Counter
-
-from six import iteritems
-
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
@@ -59,8 +71,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
- self._server_linearizer = async.Linearizer("fed_server")
- self._transaction_linearizer = async.Linearizer("fed_txn_handler")
+ self._server_linearizer = Linearizer("fed_server")
+ self._transaction_linearizer = Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
@@ -74,6 +86,9 @@ class FederationServer(FederationBase):
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -134,6 +149,8 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus))
+ origin_host, _ = parse_server_name(transaction.origin)
+
pdus_by_room = {}
for p in transaction.pdus:
@@ -154,9 +171,21 @@ class FederationServer(FederationBase):
# we can process different rooms in parallel (which is useful if they
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
+
@defer.inlineCallbacks
def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
+ try:
+ yield self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warn(
+ "Ignoring PDUs for room %s from banned server", room_id,
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
try:
@@ -168,10 +197,14 @@ class FederationServer(FederationBase):
logger.warn("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
+ f = failure.Failure()
pdu_results[event_id] = {"error": str(e)}
- logger.exception("Failed to handle PDU %s", event_id)
+ logger.error(
+ "Failed to handle PDU %s: %s",
+ event_id, f.getTraceback().rstrip(),
+ )
- yield async.concurrently_execute(
+ yield concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT,
)
@@ -184,10 +217,6 @@ class FederationServer(FederationBase):
edu.content
)
- pdu_failures = getattr(transaction, "pdu_failures", [])
- for failure in pdu_failures:
- logger.info("Got failure %r", failure)
-
response = {
"pdus": pdu_results,
}
@@ -211,6 +240,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -234,6 +266,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -277,7 +312,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
+ pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -298,14 +333,27 @@ class FederationServer(FederationBase):
defer.returnValue((200, resp))
@defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = yield self.store.get_room_version(room_id)
+ if room_version not in supported_versions:
+ logger.warn("Room version %s not in %s", room_version, supported_versions)
+ raise IncompatibleRoomVersionError(room_version=room_version)
+
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+ defer.returnValue({
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version,
+ })
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = event_from_pdu_json(content)
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -314,6 +362,10 @@ class FederationServer(FederationBase):
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -325,7 +377,9 @@ class FederationServer(FederationBase):
}))
@defer.inlineCallbacks
- def on_make_leave_request(self, room_id, user_id):
+ def on_make_leave_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -334,6 +388,10 @@ class FederationServer(FederationBase):
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -341,6 +399,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {
@@ -369,6 +430,9 @@ class FederationServer(FederationBase):
Deferred: Results in `dict` with the same format as `content`
"""
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
auth_chain = [
event_from_pdu_json(e)
for e in content["auth_chain"]
@@ -381,6 +445,7 @@ class FederationServer(FederationBase):
ret = yield self.handler.on_query_auth(
origin,
event_id,
+ room_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
@@ -442,6 +507,9 @@ class FederationServer(FederationBase):
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d",
@@ -470,17 +538,6 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
@@ -560,7 +617,9 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+ yield self.handler.on_receive_pdu(
+ origin, pdu, get_missing=True, sent_to_us_directly=True,
+ )
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
@@ -588,6 +647,101 @@ class FederationServer(FederationBase):
)
defer.returnValue(ret)
+ @defer.inlineCallbacks
+ def check_server_matches_acl(self, server_name, room_id):
+ """Check if the given server is allowed by the server ACLs in the room
+
+ Args:
+ server_name (str): name of server, *without any port part*
+ room_id (str): ID of the room to check
+
+ Raises:
+ AuthError if the server does not match the ACL
+ """
+ state_ids = yield self.store.get_current_state_ids(room_id)
+ acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+ if not acl_event_id:
+ return
+
+ acl_event = yield self.store.get_event(acl_event_id)
+ if server_matches_acl_event(server_name, acl_event):
+ return
+
+ raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+ """Check if the given server is allowed by the ACL event
+
+ Args:
+ server_name (str): name of server, without any port part
+ acl_event (EventBase): m.room.server_acl event
+
+ Returns:
+ bool: True if this server is allowed by the ACLs
+ """
+ logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+ # first of all, check if literal IPs are blocked, and if so, whether the
+ # server name is a literal IP
+ allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+ if not isinstance(allow_ip_literals, bool):
+ logger.warn("Ignorning non-bool allow_ip_literals flag")
+ allow_ip_literals = True
+ if not allow_ip_literals:
+ # check for ipv6 literals. These start with '['.
+ if server_name[0] == '[':
+ return False
+
+ # check for ipv4 literals. We can just lift the routine from twisted.
+ if isIPAddress(server_name):
+ return False
+
+ # next, check the deny list
+ deny = acl_event.content.get("deny", [])
+ if not isinstance(deny, (list, tuple)):
+ logger.warn("Ignorning non-list deny ACL %s", deny)
+ deny = []
+ for e in deny:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched deny rule %s", server_name, e)
+ return False
+
+ # then the allow list.
+ allow = acl_event.content.get("allow", [])
+ if not isinstance(allow, (list, tuple)):
+ logger.warn("Ignorning non-list allow ACL %s", allow)
+ allow = []
+ for e in allow:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched allow rule %s", server_name, e)
+ return True
+
+ # everything else should be rejected.
+ # logger.info("%s fell through", server_name)
+ return False
+
+
+def _acl_entry_matches(server_name, acl_entry):
+ if not isinstance(acl_entry, six.string_types):
+ logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ return False
+ regex = _glob_to_regex(acl_entry)
+ return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+ return re.compile(res + "\\Z", re.IGNORECASE)
+
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
@@ -610,6 +764,8 @@ class FederationHandlerRegistry(object):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+ logger.info("Registering federation EDU handler for %r", edu_type)
+
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
@@ -628,6 +784,8 @@ class FederationHandlerRegistry(object):
"Already have a Query handler for %s" % (query_type,)
)
+ logger.info("Registering federation query handler for %r", query_type)
+
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@@ -650,3 +808,49 @@ class FederationHandlerRegistry(object):
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
+
+
+class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
+ """A FederationHandlerRegistry for worker processes.
+
+ When receiving EDU or queries it will check if an appropriate handler has
+ been registered on the worker, if there isn't one then it calls off to the
+ master process.
+ """
+
+ def __init__(self, hs):
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+ self.clock = hs.get_clock()
+
+ self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+ self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+ super(ReplicationFederationHandlerRegistry, self).__init__()
+
+ def on_edu(self, edu_type, origin, content):
+ """Overrides FederationHandlerRegistry
+ """
+ handler = self.edu_handlers.get(edu_type)
+ if handler:
+ return super(ReplicationFederationHandlerRegistry, self).on_edu(
+ edu_type, origin, content,
+ )
+
+ return self._send_edu(
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
+ )
+
+ def on_query(self, query_type, args):
+ """Overrides FederationHandlerRegistry
+ """
+ handler = self.query_handlers.get(query_type)
+ if handler:
+ return handler(args)
+
+ return self._get_query_client(
+ query_type=query_type,
+ args=args,
+ )
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 84dc606673..9146215c21 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -19,13 +19,12 @@ package.
These actions are mostly only used by the :py:mod:`.replication` module.
"""
+import logging
+
from twisted.internet import defer
from synapse.util.logutils import log_function
-import logging
-
-
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 1d5c0f3797..6f5995735a 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -29,18 +29,18 @@ dead worker doesn't cause the queues to grow limitlessly.
Events are replicated via a separate events stream.
"""
-from .units import Edu
+import logging
+from collections import namedtuple
-from synapse.storage.presence import UserPresenceState
-from synapse.util.metrics import Measure
-from synapse.metrics import LaterGauge
+from six import iteritems
from sortedcontainers import SortedDict
-from collections import namedtuple
-import logging
+from synapse.metrics import LaterGauge
+from synapse.storage.presence import UserPresenceState
+from synapse.util.metrics import Measure
-from six import itervalues, iteritems
+from .units import Edu
logger = logging.getLogger(__name__)
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu
- self.failures = SortedDict() # stream position -> (destination, Failure)
-
self.device_messages = SortedDict() # stream position -> destination
self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "failures", "device_messages", "pos_time",
+ "edus", "device_messages", "pos_time",
]:
register(queue_name, getattr(self, queue_name))
@@ -119,7 +117,7 @@ class FederationRemoteSendQueue(object):
user_ids = set(
user_id
- for uids in itervalues(self.presence_changed)
+ for uids in self.presence_changed.values()
for user_id in uids
)
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.edus[key]
- # Delete things out of failure map
- keys = self.failures.keys()
- i = self.failures.bisect_left(position_to_delete)
- for key in keys[:i]:
- del self.failures[key]
-
# Delete things out of device map
keys = self.device_messages.keys()
i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
- def send_failure(self, failure, destination):
- """As per TransactionQueue"""
- pos = self._next_pos()
-
- self.failures[pos] = (destination, str(failure))
- self.notifier.on_new_replication_data()
-
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
- # Fetch changed failures
- i = self.failures.bisect_right(from_token)
- j = self.failures.bisect_right(to_token) + 1
- failures = self.failures.items()[i:j]
-
- for (pos, (destination, failure)) in failures:
- rows.append((pos, FailureRow(
- destination=destination,
- failure=failure,
- )))
-
# Fetch changed device messages
i = self.device_messages.bisect_right(from_token)
j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
- "destination", # str
- "failure",
-))):
- """Streams failures to a remote server. Failures are issued when there was
- something wrong with a transaction the remote sent us, e.g. it included
- an event that was invalid.
- """
-
- TypeId = "f"
-
- @staticmethod
- def from_data(data):
- return FailureRow(
- destination=data["destination"],
- failure=data["failure"],
- )
-
- def to_data(self):
- return {
- "destination": self.destination,
- "failure": self.failure,
- }
-
- def add_to_buffer(self, buff):
- buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str
))):
@@ -471,7 +417,6 @@ TypeToRow = {
PresenceRow,
KeyedEduRow,
EduRow,
- FailureRow,
DeviceRow,
)
}
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
- "failures", # dict of destination -> [failures]
"device_destinations", # set of destinations
))
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
presence=[],
keyed_edus={},
edus={},
- failures={},
device_destinations=set(),
)
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in iteritems(buff.failures):
- for failure in failure_list:
- transaction_queue.send_failure(destination, failure)
-
for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f0aeb5a0d3..94d7423d01 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -13,37 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
+import logging
-from twisted.internet import defer
+from six import itervalues
-from .persistence import TransactionActions
-from .units import Transaction, Edu
+from prometheus_client import Counter
+
+from twisted.internet import defer
-from synapse.api.errors import HttpResponseException, FederationDeniedError
-from synapse.util import logcontext, PreserveLoggingContext
-from synapse.util.async import run_on_reactor
-from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
-from synapse.util.metrics import measure_func
-from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
import synapse.metrics
-from synapse.metrics import LaterGauge
+from synapse.api.errors import FederationDeniedError, HttpResponseException
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import (
+ LaterGauge,
+ event_processing_loop_counter,
+ event_processing_loop_room_count,
+ events_processed_counter,
sent_edus_counter,
sent_transactions_counter,
- events_processed_counter,
)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import logcontext
+from synapse.util.metrics import measure_func
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
-from prometheus_client import Counter
-
-from six import itervalues
-
-import logging
-
+from .persistence import TransactionActions
+from .units import Edu, Transaction
logger = logging.getLogger(__name__)
-sent_pdus_destination_dist = Counter(
- "synapse_federation_transaction_queue_sent_pdu_destinations", ""
+sent_pdus_destination_dist_count = Counter(
+ "synapse_federation_client_sent_pdu_destinations:count", ""
+)
+sent_pdus_destination_dist_total = Counter(
+ "synapse_federation_client_sent_pdu_destinations:total", ""
)
@@ -55,6 +58,7 @@ class TransactionQueue(object):
"""
def __init__(self, hs):
+ self.hs = hs
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -115,9 +119,6 @@ class TransactionQueue(object):
),
)
- # destination -> list of tuple(failure, deferred)
- self.pending_failures_by_dest = {}
-
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
@@ -165,10 +166,11 @@ class TransactionQueue(object):
if self._is_processing:
return
- # fire off a processing loop in the background. It's likely it will
- # outlast the current request, so run it in the sentinel logcontext.
- with PreserveLoggingContext():
- self._process_event_queue_loop()
+ # fire off a processing loop in the background
+ run_as_background_process(
+ "process_event_queue_for_federation",
+ self._process_event_queue_loop,
+ )
@defer.inlineCallbacks
def _process_event_queue_loop(self):
@@ -254,7 +256,13 @@ class TransactionQueue(object):
synapse.metrics.event_processing_last_ts.labels(
"federation_sender").set(ts)
- events_processed_counter.inc(len(events))
+ events_processed_counter.inc(len(events))
+
+ event_processing_loop_room_count.labels(
+ "federation_sender"
+ ).inc(len(events_by_room))
+
+ event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels(
"federation_sender").set(next_token)
@@ -280,7 +288,8 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc(len(destinations))
+ sent_pdus_destination_dist_total.inc(len(destinations))
+ sent_pdus_destination_dist_count.inc()
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
@@ -300,6 +309,9 @@ class TransactionQueue(object):
Args:
states (list(UserPresenceState))
"""
+ if not self.hs.config.use_presence:
+ # No-op if presence is disabled.
+ return
# First we queue up the new presence by user ID, so multiple presence
# updates in quick successtion are correctly handled
@@ -379,19 +391,6 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination)
- def send_failure(self, failure, destination):
- if destination == self.server_name or destination == "localhost":
- return
-
- if not self.can_send_to(destination):
- return
-
- self.pending_failures_by_dest.setdefault(
- destination, []
- ).append(failure)
-
- self._attempt_new_transaction(destination)
-
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -431,14 +430,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Starting transaction loop", destination)
- # Drop the logcontext before starting the transaction. It doesn't
- # really make sense to log all the outbound transactions against
- # whatever path led us to this point: that's pretty arbitrary really.
- #
- # (this also means we can fire off _perform_transaction without
- # yielding)
- with logcontext.PreserveLoggingContext():
- self._transaction_transmission_loop(destination)
+ run_as_background_process(
+ "federation_transaction_transmission_loop",
+ self._transaction_transmission_loop,
+ destination,
+ )
@defer.inlineCallbacks
def _transaction_transmission_loop(self, destination):
@@ -451,9 +447,6 @@ class TransactionQueue(object):
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
- # XXX: what's this for?
- yield run_on_reactor()
-
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
@@ -472,7 +465,6 @@ class TransactionQueue(object):
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -500,7 +492,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
+ if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
@@ -510,7 +502,7 @@ class TransactionQueue(object):
# END CRITICAL SECTION
success = yield self._send_new_transaction(
- destination, pending_pdus, pending_edus, pending_failures,
+ destination, pending_pdus, pending_edus,
)
if success:
sent_transactions_counter.inc()
@@ -587,14 +579,12 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
- def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures):
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
- failures = [x.get_dict() for x in pending_failures]
success = True
@@ -604,11 +594,10 @@ class TransactionQueue(object):
logger.debug(
"TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
+ " (pdus: %d, edus: %d)",
destination, txn_id,
len(pdus),
len(edus),
- len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
@@ -620,7 +609,6 @@ class TransactionQueue(object):
destination=destination,
pdus=pdus,
edus=edus,
- pdu_failures=failures,
)
self._next_txn_id += 1
@@ -630,12 +618,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
+ " (PDUs: %d, EDUs: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
- len(failures),
)
# Actually send the transaction
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 6db8efa6dd..1054441ca5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -14,16 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
-import logging
-import urllib
-
-
logger = logging.getLogger(__name__)
@@ -107,7 +106,7 @@ class TransportLayerClient(object):
dest (str)
room_id (str)
event_tuples (list)
- limt (int)
+ limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
@@ -196,7 +195,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def make_membership_event(self, destination, room_id, user_id, membership):
+ def make_membership_event(self, destination, room_id, user_id, membership, params):
"""Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs.
@@ -206,6 +205,8 @@ class TransportLayerClient(object):
room_id (str): room to join/leave
user_id (str): user to be joined/left
membership (str): one of join/leave
+ params (dict[str, str|Iterable[str]]): Query parameters to include in the
+ request.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -242,6 +243,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(
destination=destination,
path=path,
+ args=params,
retry_on_dns_fail=retry_on_dns_fail,
timeout=20000,
ignore_backoff=ignore_backoff,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..7a993fd1cf 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -14,25 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import logging
+import re
+
from twisted.internet import defer
+import synapse
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
- parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
+ parse_integer_from_args,
+ parse_json_object_from_request,
+ parse_string_from_args,
)
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import run_in_background
-from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-
-import functools
-import logging
-import re
-import synapse
-
logger = logging.getLogger(__name__)
@@ -99,26 +101,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +109,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,7 +147,84 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith("\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+
+ # ensure that the origin is a valid server name
+ parse_and_validate_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
class BaseFederationServlet(object):
+ """Abstract base class for federation servlet classes.
+
+ The servlet object should have a PATH attribute which takes the form of a regexp to
+ match against the request path (excluding the /federation/v1 prefix).
+
+ The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+ the appropriate HTTP method. These methods have the signature:
+
+ on_<METHOD>(self, origin, content, query, **kwargs)
+
+ With arguments:
+
+ origin (unicode|None): The authenticated server_name of the calling server,
+ unless REQUIRE_AUTH is set to False and authentication failed.
+
+ content (unicode|None): decoded json body of the request. None if the
+ request was a GET.
+
+ query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+ (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+ yet.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+
+ Raises:
+ SynapseError: to return an error code
+
+ Exception: other exceptions will be caught, logged, and a 500 will be
+ returned.
+ """
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name):
@@ -180,6 +239,18 @@ class BaseFederationServlet(object):
@defer.inlineCallbacks
@functools.wraps(func)
def new_func(request, *args, **kwargs):
+ """ A callback which can be passed to HttpServer.RegisterPaths
+
+ Args:
+ request (twisted.web.http.Request):
+ *args: unused?
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: (response code, response object) as returned
+ by the callback method. None if the request has already been handled.
+ """
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
@@ -190,10 +261,10 @@ class BaseFederationServlet(object):
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.exception("authenticate_request failed")
+ logger.warn("authenticate_request failed: missing authentication")
raise
- except Exception:
- logger.exception("authenticate_request failed")
+ except Exception as e:
+ logger.warn("authenticate_request failed: %s", e)
raise
if origin:
@@ -259,11 +330,10 @@ class FederationSendServlet(BaseFederationServlet):
)
logger.info(
- "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+ "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
- len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
@@ -361,8 +431,32 @@ class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks
- def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_join_request(context, user_id)
+ def on_GET(self, origin, _content, query, context, user_id):
+ """
+ Args:
+ origin (unicode): The authenticated server_name of the calling server
+
+ _content (None): (GETs don't have bodies)
+
+ query (dict[bytes, list[bytes]]): Query params from the request.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+ """
+ versions = query.get(b'ver')
+ if versions is not None:
+ supported_versions = [v.decode("utf-8") for v in versions]
+ else:
+ supported_versions = ["1"]
+
+ content = yield self.handler.on_make_join_request(
+ origin, context, user_id,
+ supported_versions=supported_versions,
+ )
defer.returnValue((200, content))
@@ -371,15 +465,17 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(context, user_id)
+ content = yield self.handler.on_make_leave_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
- PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id, txid):
+ def on_PUT(self, origin, content, query, room_id, event_id):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 01c5b8fe17..c5ab14314e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,10 +17,9 @@
server protocol.
"""
-from synapse.util.jsonobject import JsonEncodedObject
-
import logging
+from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
@@ -74,7 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "pdu_failures",
]
internal_keys = [
|