diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/federation_base.py | 11 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 14 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 196 | ||||
-rw-r--r-- | synapse/federation/persistence.py | 5 | ||||
-rw-r--r-- | synapse/federation/send_queue.py | 77 | ||||
-rw-r--r-- | synapse/federation/transaction_queue.py | 67 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 9 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 100 | ||||
-rw-r--r-- | synapse/federation/units.py | 3 |
9 files changed, 319 insertions, 163 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 4cc98a3fe8..c11798093d 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -16,14 +16,15 @@ import logging import six +from twisted.internet import defer + from synapse.api.constants import MAX_DEPTH -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event -from synapse.http.servlet import assert_params_in_request -from synapse.util import unwrapFirstError, logcontext -from twisted.internet import defer +from synapse.http.servlet import assert_params_in_dict +from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) @@ -198,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) + assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth')) depth = pdu_json['depth'] if not isinstance(depth, six.integer_types): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 87a92f6ea9..62d7ed13cf 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,25 +21,25 @@ import random from six.moves import range +from prometheus_client import Counter + from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.errors import ( - CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError + CodeMessageException, + FederationDeniedError, + HttpResponseException, + SynapseError, ) from synapse.events import builder -from synapse.federation.federation_base import ( - FederationBase, - event_from_pdu_json, -) +from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.util import logcontext, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logutils import log_function from synapse.util.retryutils import NotRetryingDestination -from prometheus_client import Counter - logger = logging.getLogger(__name__) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2d420a58a2..e501251b6e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,28 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re + +import six +from six import iteritems + +from canonicaljson import json +from prometheus_client import Counter -import simplejson as json from twisted.internet import defer +from twisted.internet.abstract import isIPAddress +from twisted.python import failure -from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError from synapse.crypto.event_signing import compute_event_signature -from synapse.federation.federation_base import ( - FederationBase, - event_from_pdu_json, -) - +from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction +from synapse.http.endpoint import parse_server_name from synapse.types import get_domain_from_id from synapse.util import async from synapse.util.caches.response_cache import ResponseCache from synapse.util.logutils import log_function -from prometheus_client import Counter - -from six import iteritems - # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -74,6 +76,9 @@ class FederationServer(FederationBase): @log_function def on_backfill_request(self, origin, room_id, versions, limit): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + pdus = yield self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -134,6 +139,8 @@ class FederationServer(FederationBase): received_pdus_counter.inc(len(transaction.pdus)) + origin_host, _ = parse_server_name(transaction.origin) + pdus_by_room = {} for p in transaction.pdus: @@ -154,9 +161,21 @@ class FederationServer(FederationBase): # we can process different rooms in parallel (which is useful if they # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. + @defer.inlineCallbacks def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) + try: + yield self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warn( + "Ignoring PDUs for room %s from banned server", room_id, + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + for pdu in pdus_by_room[room_id]: event_id = pdu.event_id try: @@ -168,8 +187,12 @@ class FederationServer(FederationBase): logger.warn("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: + f = failure.Failure() pdu_results[event_id] = {"error": str(e)} - logger.exception("Failed to handle PDU %s", event_id) + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), + ) yield async.concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), @@ -185,8 +208,8 @@ class FederationServer(FederationBase): ) pdu_failures = getattr(transaction, "pdu_failures", []) - for failure in pdu_failures: - logger.info("Got failure %r", failure) + for fail in pdu_failures: + logger.info("Got failure %r", fail) response = { "pdus": pdu_results, @@ -211,6 +234,9 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -234,6 +260,9 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -277,7 +306,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_pdu_request(self, origin, event_id): - pdu = yield self._get_persisted_pdu(origin, event_id) + pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: defer.returnValue( @@ -298,7 +327,9 @@ class FederationServer(FederationBase): defer.returnValue((200, resp)) @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): + def on_make_join_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -306,6 +337,8 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_invite_request(self, origin, content): pdu = event_from_pdu_json(content) + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @@ -314,6 +347,10 @@ class FederationServer(FederationBase): def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", content) pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -325,7 +362,9 @@ class FederationServer(FederationBase): })) @defer.inlineCallbacks - def on_make_leave_request(self, room_id, user_id): + def on_make_leave_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_leave_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -334,6 +373,10 @@ class FederationServer(FederationBase): def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @@ -341,6 +384,9 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) res = { @@ -369,6 +415,9 @@ class FederationServer(FederationBase): Deferred: Results in `dict` with the same format as `content` """ with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + auth_chain = [ event_from_pdu_json(e) for e in content["auth_chain"] @@ -442,6 +491,9 @@ class FederationServer(FederationBase): def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d, min_depth: %d", @@ -470,17 +522,6 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return self.store.get_user_id_for_open_id_token(token, ts_now_ms) - @log_function - def _get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. - - Returns: - Deferred: Results in a `Pdu`. - """ - return self.handler.get_persisted_pdu( - origin, event_id, do_auth=do_auth - ) - def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. @@ -560,7 +601,9 @@ class FederationServer(FederationBase): affected=pdu.event_id, ) - yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) + yield self.handler.on_receive_pdu( + origin, pdu, get_missing=True, sent_to_us_directly=True, + ) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name @@ -588,6 +631,101 @@ class FederationServer(FederationBase): ) defer.returnValue(ret) + @defer.inlineCallbacks + def check_server_matches_acl(self, server_name, room_id): + """Check if the given server is allowed by the server ACLs in the room + + Args: + server_name (str): name of server, *without any port part* + room_id (str): ID of the room to check + + Raises: + AuthError if the server does not match the ACL + """ + state_ids = yield self.store.get_current_state_ids(room_id) + acl_event_id = state_ids.get((EventTypes.ServerACL, "")) + + if not acl_event_id: + return + + acl_event = yield self.store.get_event(acl_event_id) + if server_matches_acl_event(server_name, acl_event): + return + + raise AuthError(code=403, msg="Server is banned from room") + + +def server_matches_acl_event(server_name, acl_event): + """Check if the given server is allowed by the ACL event + + Args: + server_name (str): name of server, without any port part + acl_event (EventBase): m.room.server_acl event + + Returns: + bool: True if this server is allowed by the ACLs + """ + logger.debug("Checking %s against acl %s", server_name, acl_event.content) + + # first of all, check if literal IPs are blocked, and if so, whether the + # server name is a literal IP + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warn("Ignorning non-bool allow_ip_literals flag") + allow_ip_literals = True + if not allow_ip_literals: + # check for ipv6 literals. These start with '['. + if server_name[0] == '[': + return False + + # check for ipv4 literals. We can just lift the routine from twisted. + if isIPAddress(server_name): + return False + + # next, check the deny list + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warn("Ignorning non-list deny ACL %s", deny) + deny = [] + for e in deny: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched deny rule %s", server_name, e) + return False + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warn("Ignorning non-list allow ACL %s", allow) + allow = [] + for e in allow: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched allow rule %s", server_name, e) + return True + + # everything else should be rejected. + # logger.info("%s fell through", server_name) + return False + + +def _acl_entry_matches(server_name, acl_entry): + if not isinstance(acl_entry, six.string_types): + logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + return False + regex = _glob_to_regex(acl_entry) + return regex.match(server_name) + + +def _glob_to_regex(glob): + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + return re.compile(res + "\\Z", re.IGNORECASE) + class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 84dc606673..9146215c21 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -19,13 +19,12 @@ package. These actions are mostly only used by the :py:mod:`.replication` module. """ +import logging + from twisted.internet import defer from synapse.util.logutils import log_function -import logging - - logger = logging.getLogger(__name__) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 9f1142b5a9..5157c3860d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -29,18 +29,18 @@ dead worker doesn't cause the queues to grow limitlessly. Events are replicated via a separate events stream. """ -from .units import Edu +import logging +from collections import namedtuple -from synapse.storage.presence import UserPresenceState -from synapse.util.metrics import Measure -from synapse.metrics import LaterGauge +from six import iteritems, itervalues -from blist import sorteddict -from collections import namedtuple +from sortedcontainers import SortedDict -import logging +from synapse.metrics import LaterGauge +from synapse.storage.presence import UserPresenceState +from synapse.util.metrics import Measure -from six import itervalues, iteritems +from .units import Edu logger = logging.getLogger(__name__) @@ -55,19 +55,19 @@ class FederationRemoteSendQueue(object): self.is_mine_id = hs.is_mine_id self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = sorteddict() # Stream position -> user_id + self.presence_changed = SortedDict() # Stream position -> user_id self.keyed_edu = {} # (destination, key) -> EDU - self.keyed_edu_changed = sorteddict() # stream position -> (destination, key) + self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) - self.edus = sorteddict() # stream position -> Edu + self.edus = SortedDict() # stream position -> Edu - self.failures = sorteddict() # stream position -> (destination, Failure) + self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = sorteddict() # stream position -> destination + self.device_messages = SortedDict() # stream position -> destination self.pos = 1 - self.pos_time = sorteddict() + self.pos_time = SortedDict() # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -98,7 +98,7 @@ class FederationRemoteSendQueue(object): now = self.clock.time_msec() keys = self.pos_time.keys() - time = keys.bisect_left(now - FIVE_MINUTES_AGO) + time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO) if not keys[:time]: return @@ -113,7 +113,7 @@ class FederationRemoteSendQueue(object): with Measure(self.clock, "send_queue._clear"): # Delete things out of presence maps keys = self.presence_changed.keys() - i = keys.bisect_left(position_to_delete) + i = self.presence_changed.bisect_left(position_to_delete) for key in keys[:i]: del self.presence_changed[key] @@ -131,7 +131,7 @@ class FederationRemoteSendQueue(object): # Delete things out of keyed edus keys = self.keyed_edu_changed.keys() - i = keys.bisect_left(position_to_delete) + i = self.keyed_edu_changed.bisect_left(position_to_delete) for key in keys[:i]: del self.keyed_edu_changed[key] @@ -145,19 +145,19 @@ class FederationRemoteSendQueue(object): # Delete things out of edu map keys = self.edus.keys() - i = keys.bisect_left(position_to_delete) + i = self.edus.bisect_left(position_to_delete) for key in keys[:i]: del self.edus[key] # Delete things out of failure map keys = self.failures.keys() - i = keys.bisect_left(position_to_delete) + i = self.failures.bisect_left(position_to_delete) for key in keys[:i]: del self.failures[key] # Delete things out of device map keys = self.device_messages.keys() - i = keys.bisect_left(position_to_delete) + i = self.device_messages.bisect_left(position_to_delete) for key in keys[:i]: del self.device_messages[key] @@ -250,13 +250,12 @@ class FederationRemoteSendQueue(object): self._clear_queue_before_pos(federation_ack) # Fetch changed presence - keys = self.presence_changed.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 + i = self.presence_changed.bisect_right(from_token) + j = self.presence_changed.bisect_right(to_token) + 1 dest_user_ids = [ (pos, user_id) - for pos in keys[i:j] - for user_id in self.presence_changed[pos] + for pos, user_id_list in self.presence_changed.items()[i:j] + for user_id in user_id_list ] for (key, user_id) in dest_user_ids: @@ -265,13 +264,12 @@ class FederationRemoteSendQueue(object): ))) # Fetch changes keyed edus - keys = self.keyed_edu_changed.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 + i = self.keyed_edu_changed.bisect_right(from_token) + j = self.keyed_edu_changed.bisect_right(to_token) + 1 # We purposefully clobber based on the key here, python dict comprehensions # always use the last value, so this will correctly point to the last # stream position. - keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]} + keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} for ((destination, edu_key), pos) in iteritems(keyed_edus): rows.append((pos, KeyedEduRow( @@ -280,19 +278,17 @@ class FederationRemoteSendQueue(object): ))) # Fetch changed edus - keys = self.edus.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - edus = ((k, self.edus[k]) for k in keys[i:j]) + i = self.edus.bisect_right(from_token) + j = self.edus.bisect_right(to_token) + 1 + edus = self.edus.items()[i:j] for (pos, edu) in edus: rows.append((pos, EduRow(edu))) # Fetch changed failures - keys = self.failures.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - failures = ((k, self.failures[k]) for k in keys[i:j]) + i = self.failures.bisect_right(from_token) + j = self.failures.bisect_right(to_token) + 1 + failures = self.failures.items()[i:j] for (pos, (destination, failure)) in failures: rows.append((pos, FailureRow( @@ -301,10 +297,9 @@ class FederationRemoteSendQueue(object): ))) # Fetch changed device messages - keys = self.device_messages.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - device_messages = {self.device_messages[k]: k for k in keys[i:j]} + i = self.device_messages.bisect_right(from_token) + j = self.device_messages.bisect_right(to_token) + 1 + device_messages = {v: k for k, v in self.device_messages.items()[i:j]} for (destination, pos) in iteritems(device_messages): rows.append((pos, DeviceRow( diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f0aeb5a0d3..6996d6b695 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -13,37 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime +import logging -from twisted.internet import defer +from six import itervalues -from .persistence import TransactionActions -from .units import Transaction, Edu +from prometheus_client import Counter + +from twisted.internet import defer -from synapse.api.errors import HttpResponseException, FederationDeniedError -from synapse.util import logcontext, PreserveLoggingContext -from synapse.util.async import run_on_reactor -from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter -from synapse.util.metrics import measure_func -from synapse.handlers.presence import format_user_presence_state, get_interested_remotes import synapse.metrics -from synapse.metrics import LaterGauge +from synapse.api.errors import FederationDeniedError, HttpResponseException +from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.metrics import ( + LaterGauge, + events_processed_counter, sent_edus_counter, sent_transactions_counter, - events_processed_counter, ) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import logcontext +from synapse.util.metrics import measure_func +from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter -from prometheus_client import Counter - -from six import itervalues - -import logging - +from .persistence import TransactionActions +from .units import Edu, Transaction logger = logging.getLogger(__name__) -sent_pdus_destination_dist = Counter( - "synapse_federation_transaction_queue_sent_pdu_destinations", "" +sent_pdus_destination_dist_count = Counter( + "synapse_federation_client_sent_pdu_destinations:count", "" +) +sent_pdus_destination_dist_total = Counter( + "synapse_federation_client_sent_pdu_destinations:total", "" ) @@ -165,10 +166,11 @@ class TransactionQueue(object): if self._is_processing: return - # fire off a processing loop in the background. It's likely it will - # outlast the current request, so run it in the sentinel logcontext. - with PreserveLoggingContext(): - self._process_event_queue_loop() + # fire off a processing loop in the background + run_as_background_process( + "process_event_queue_for_federation", + self._process_event_queue_loop, + ) @defer.inlineCallbacks def _process_event_queue_loop(self): @@ -280,7 +282,8 @@ class TransactionQueue(object): if not destinations: return - sent_pdus_destination_dist.inc(len(destinations)) + sent_pdus_destination_dist_total.inc(len(destinations)) + sent_pdus_destination_dist_count.inc() for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( @@ -431,14 +434,11 @@ class TransactionQueue(object): logger.debug("TX [%s] Starting transaction loop", destination) - # Drop the logcontext before starting the transaction. It doesn't - # really make sense to log all the outbound transactions against - # whatever path led us to this point: that's pretty arbitrary really. - # - # (this also means we can fire off _perform_transaction without - # yielding) - with logcontext.PreserveLoggingContext(): - self._transaction_transmission_loop(destination) + run_as_background_process( + "federation_transaction_transmission_loop", + self._transaction_transmission_loop, + destination, + ) @defer.inlineCallbacks def _transaction_transmission_loop(self, destination): @@ -451,9 +451,6 @@ class TransactionQueue(object): # hence why we throw the result away. yield get_retry_limiter(destination, self.clock, self.store) - # XXX: what's this for? - yield run_on_reactor() - pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6db8efa6dd..4529d454af 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -14,16 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import urllib + from twisted.internet import defer -from synapse.api.constants import Membership +from synapse.api.constants import Membership from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function -import logging -import urllib - - logger = logging.getLogger(__name__) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 19d09f5422..8574898f0c 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -14,25 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging +import re + from twisted.internet import defer +import synapse +from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError, FederationDeniedError +from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( - parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_boolean_from_args, + parse_integer_from_args, + parse_json_object_from_request, + parse_string_from_args, ) +from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.logcontext import run_in_background from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string -from synapse.util.logcontext import run_in_background -from synapse.types import ThirdPartyInstanceID, get_domain_from_id - -import functools -import logging -import re -import synapse - logger = logging.getLogger(__name__) @@ -99,26 +101,6 @@ class Authenticator(object): origin = None - def parse_auth_header(header_str): - try: - params = auth.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith("\""): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return (origin, key, sig) - except Exception: - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: @@ -127,8 +109,8 @@ class Authenticator(object): ) for auth in auth_headers: - if auth.startswith("X-Matrix"): - (origin, key, sig) = parse_auth_header(auth) + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig @@ -165,6 +147,48 @@ class Authenticator(object): logger.exception("Error resetting retry timings on %s", origin) +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode('utf-8') + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith(b"\""): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + + # ensure that the origin is a valid server name + parse_and_validate_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warn( + "Error parsing auth header '%s': %s", + header_bytes.decode('ascii', 'replace'), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + ) + + class BaseFederationServlet(object): REQUIRE_AUTH = True @@ -362,7 +386,9 @@ class FederationMakeJoinServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_join_request(context, user_id) + content = yield self.handler.on_make_join_request( + origin, context, user_id, + ) defer.returnValue((200, content)) @@ -371,15 +397,17 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request(context, user_id) + content = yield self.handler.on_make_leave_request( + origin, context, user_id, + ) defer.returnValue((200, content)) class FederationSendLeaveServlet(BaseFederationServlet): - PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)" + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks - def on_PUT(self, origin, content, query, room_id, txid): + def on_PUT(self, origin, content, query, room_id, event_id): content = yield self.handler.on_send_leave_request(origin, content) defer.returnValue((200, content)) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 01c5b8fe17..bb1b3b13f7 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,10 +17,9 @@ server protocol. """ -from synapse.util.jsonobject import JsonEncodedObject - import logging +from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) |