diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d4dd967c60..591d0026bf 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
-import simplejson as json
+from canonicaljson import json
+import six
from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
+from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import (
@@ -27,6 +31,7 @@ from synapse.federation.federation_base import (
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
@@ -74,6 +79,9 @@ class FederationServer(FederationBase):
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -134,6 +142,8 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus))
+ origin_host, _ = parse_server_name(transaction.origin)
+
pdus_by_room = {}
for p in transaction.pdus:
@@ -154,9 +164,21 @@ class FederationServer(FederationBase):
# we can process different rooms in parallel (which is useful if they
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
+
@defer.inlineCallbacks
def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
+ try:
+ yield self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warn(
+ "Ignoring PDUs for room %s from banned server", room_id,
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
try:
@@ -211,6 +233,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -234,6 +259,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -298,7 +326,9 @@ class FederationServer(FederationBase):
defer.returnValue((200, resp))
@defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -306,6 +336,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = event_from_pdu_json(content)
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -314,6 +346,10 @@ class FederationServer(FederationBase):
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -325,7 +361,9 @@ class FederationServer(FederationBase):
}))
@defer.inlineCallbacks
- def on_make_leave_request(self, room_id, user_id):
+ def on_make_leave_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -334,6 +372,10 @@ class FederationServer(FederationBase):
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -341,6 +383,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {
@@ -369,6 +414,9 @@ class FederationServer(FederationBase):
Deferred: Results in `dict` with the same format as `content`
"""
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
auth_chain = [
event_from_pdu_json(e)
for e in content["auth_chain"]
@@ -442,6 +490,9 @@ class FederationServer(FederationBase):
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d",
@@ -549,7 +600,9 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+ yield self.handler.on_receive_pdu(
+ origin, pdu, get_missing=True, sent_to_us_directly=True,
+ )
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
@@ -577,6 +630,101 @@ class FederationServer(FederationBase):
)
defer.returnValue(ret)
+ @defer.inlineCallbacks
+ def check_server_matches_acl(self, server_name, room_id):
+ """Check if the given server is allowed by the server ACLs in the room
+
+ Args:
+ server_name (str): name of server, *without any port part*
+ room_id (str): ID of the room to check
+
+ Raises:
+ AuthError if the server does not match the ACL
+ """
+ state_ids = yield self.store.get_current_state_ids(room_id)
+ acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+ if not acl_event_id:
+ return
+
+ acl_event = yield self.store.get_event(acl_event_id)
+ if server_matches_acl_event(server_name, acl_event):
+ return
+
+ raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+ """Check if the given server is allowed by the ACL event
+
+ Args:
+ server_name (str): name of server, without any port part
+ acl_event (EventBase): m.room.server_acl event
+
+ Returns:
+ bool: True if this server is allowed by the ACLs
+ """
+ logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+ # first of all, check if literal IPs are blocked, and if so, whether the
+ # server name is a literal IP
+ allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+ if not isinstance(allow_ip_literals, bool):
+ logger.warn("Ignorning non-bool allow_ip_literals flag")
+ allow_ip_literals = True
+ if not allow_ip_literals:
+ # check for ipv6 literals. These start with '['.
+ if server_name[0] == '[':
+ return False
+
+ # check for ipv4 literals. We can just lift the routine from twisted.
+ if isIPAddress(server_name):
+ return False
+
+ # next, check the deny list
+ deny = acl_event.content.get("deny", [])
+ if not isinstance(deny, (list, tuple)):
+ logger.warn("Ignorning non-list deny ACL %s", deny)
+ deny = []
+ for e in deny:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched deny rule %s", server_name, e)
+ return False
+
+ # then the allow list.
+ allow = acl_event.content.get("allow", [])
+ if not isinstance(allow, (list, tuple)):
+ logger.warn("Ignorning non-list allow ACL %s", allow)
+ allow = []
+ for e in allow:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched allow rule %s", server_name, e)
+ return True
+
+ # everything else should be rejected.
+ # logger.info("%s fell through", server_name)
+ return False
+
+
+def _acl_entry_matches(server_name, acl_entry):
+ if not isinstance(acl_entry, six.string_types):
+ logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ return False
+ regex = _glob_to_regex(acl_entry)
+ return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+ return re.compile(res + "\\Z", re.IGNORECASE)
+
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..c6d98d35cb 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -99,26 +100,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +108,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,6 +146,48 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith(b"\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+
+ # ensure that the origin is a valid server name
+ parse_and_validate_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
@@ -362,7 +385,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_join_request(context, user_id)
+ content = yield self.handler.on_make_join_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
@@ -371,7 +396,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(context, user_id)
+ content = yield self.handler.on_make_leave_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
|