From cd6bcdaf87f6c68e0c95b789c8fcb144a0d64b1a Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 10:37:24 +0100
Subject: Better testing framework for homeserver-using things (#3446)

---
 tests/server.py      | 181 +++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/test_server.py | 128 ++++++++++++++++++++++++++++++++++++
 2 files changed, 309 insertions(+)
 create mode 100644 tests/server.py
 create mode 100644 tests/test_server.py

(limited to 'tests')

diff --git a/tests/server.py b/tests/server.py
new file mode 100644
index 0000000000..73069dff52
--- /dev/null
+++ b/tests/server.py
@@ -0,0 +1,181 @@
+from io import BytesIO
+
+import attr
+import json
+from six import text_type
+
+from twisted.python.failure import Failure
+from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.http.site import SynapseRequest
+from twisted.internet import threads
+from tests.utils import setup_test_homeserver as _sth
+
+
+@attr.s
+class FakeChannel(object):
+    """
+    A fake Twisted Web Channel (the part that interfaces with the
+    wire).
+    """
+
+    result = attr.ib(factory=dict)
+
+    @property
+    def json_body(self):
+        if not self.result:
+            raise Exception("No result yet.")
+        return json.loads(self.result["body"])
+
+    def writeHeaders(self, version, code, reason, headers):
+        self.result["version"] = version
+        self.result["code"] = code
+        self.result["reason"] = reason
+        self.result["headers"] = headers
+
+    def write(self, content):
+        if "body" not in self.result:
+            self.result["body"] = b""
+
+        self.result["body"] += content
+
+    def requestDone(self, _self):
+        self.result["done"] = True
+
+    def getPeer(self):
+        return None
+
+    def getHost(self):
+        return None
+
+    @property
+    def transport(self):
+        return self
+
+
+class FakeSite:
+    """
+    A fake Twisted Web Site, with mocks of the extra things that
+    Synapse adds.
+    """
+
+    server_version_string = b"1"
+    site_tag = "test"
+
+    @property
+    def access_logger(self):
+        class FakeLogger:
+            def info(self, *args, **kwargs):
+                pass
+
+        return FakeLogger()
+
+
+def make_request(method, path, content=b""):
+    """
+    Make a web request using the given method and path, feed it the
+    content, and return the Request and the Channel underneath.
+    """
+
+    if isinstance(content, text_type):
+        content = content.encode('utf8')
+
+    site = FakeSite()
+    channel = FakeChannel()
+
+    req = SynapseRequest(site, channel)
+    req.process = lambda: b""
+    req.content = BytesIO(content)
+    req.requestReceived(method, path, b"1.1")
+
+    return req, channel
+
+
+def wait_until_result(clock, channel, timeout=100):
+    """
+    Wait until the channel has a result.
+    """
+    clock.run()
+    x = 0
+
+    while not channel.result:
+        x += 1
+
+        if x > timeout:
+            raise Exception("Timed out waiting for request to finish.")
+
+        clock.advance(0.1)
+
+
+class ThreadedMemoryReactorClock(MemoryReactorClock):
+    """
+    A MemoryReactorClock that supports callFromThread.
+    """
+    def callFromThread(self, callback, *args, **kwargs):
+        """
+        Make the callback fire in the next reactor iteration.
+        """
+        d = Deferred()
+        d.addCallback(lambda x: callback(*args, **kwargs))
+        self.callLater(0, d.callback, True)
+        return d
+
+
+def setup_test_homeserver(*args, **kwargs):
+    """
+    Set up a synchronous test server, driven by the reactor used by
+    the homeserver.
+    """
+    d = _sth(*args, **kwargs).result
+
+    # Make the thread pool synchronous.
+    clock = d.get_clock()
+    pool = d.get_db_pool()
+
+    def runWithConnection(func, *args, **kwargs):
+        return threads.deferToThreadPool(
+            pool._reactor,
+            pool.threadpool,
+            pool._runWithConnection,
+            func,
+            *args,
+            **kwargs
+        )
+
+    def runInteraction(interaction, *args, **kwargs):
+        return threads.deferToThreadPool(
+            pool._reactor,
+            pool.threadpool,
+            pool._runInteraction,
+            interaction,
+            *args,
+            **kwargs
+        )
+
+    pool.runWithConnection = runWithConnection
+    pool.runInteraction = runInteraction
+
+    class ThreadPool:
+        """
+        Threadless thread pool.
+        """
+        def start(self):
+            pass
+
+        def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
+            def _(res):
+                if isinstance(res, Failure):
+                    onResult(False, res)
+                else:
+                    onResult(True, res)
+
+            d = Deferred()
+            d.addCallback(lambda x: function(*args, **kwargs))
+            d.addBoth(_)
+            clock._reactor.callLater(0, d.callback, True)
+            return d
+
+    clock.threadpool = ThreadPool()
+    pool.threadpool = ThreadPool()
+    return d
diff --git a/tests/test_server.py b/tests/test_server.py
new file mode 100644
index 0000000000..8ad822c43b
--- /dev/null
+++ b/tests/test_server.py
@@ -0,0 +1,128 @@
+import json
+import re
+
+from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.util import Clock
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import JsonResource
+from tests import unittest
+from tests.server import make_request, setup_test_homeserver
+
+
+class JsonResourceTests(unittest.TestCase):
+    def setUp(self):
+        self.reactor = MemoryReactorClock()
+        self.hs_clock = Clock(self.reactor)
+        self.homeserver = setup_test_homeserver(
+            http_client=None, clock=self.hs_clock, reactor=self.reactor
+        )
+
+    def test_handler_for_request(self):
+        """
+        JsonResource.handler_for_request gives correctly decoded URL args to
+        the callback, while Twisted will give the raw bytes of URL query
+        arguments.
+        """
+        got_kwargs = {}
+
+        def _callback(request, **kwargs):
+            got_kwargs.update(kwargs)
+            return (200, kwargs)
+
+        res = JsonResource(self.homeserver)
+        res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
+
+        request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
+        request.render(res)
+
+        self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
+        self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
+
+    def test_callback_direct_exception(self):
+        """
+        If the web callback raises an uncaught exception, it will be translated
+        into a 500.
+        """
+
+        def _callback(request, **kwargs):
+            raise Exception("boo")
+
+        res = JsonResource(self.homeserver)
+        res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+        request, channel = make_request(b"GET", b"/foo")
+        request.render(res)
+
+        self.assertEqual(channel.result["code"], b'500')
+
+    def test_callback_indirect_exception(self):
+        """
+        If the web callback raises an uncaught exception in a Deferred, it will
+        be translated into a 500.
+        """
+
+        def _throw(*args):
+            raise Exception("boo")
+
+        def _callback(request, **kwargs):
+            d = Deferred()
+            d.addCallback(_throw)
+            self.reactor.callLater(1, d.callback, True)
+            return d
+
+        res = JsonResource(self.homeserver)
+        res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+        request, channel = make_request(b"GET", b"/foo")
+        request.render(res)
+
+        # No error has been raised yet
+        self.assertTrue("code" not in channel.result)
+
+        # Advance time, now there's an error
+        self.reactor.advance(1)
+        self.assertEqual(channel.result["code"], b'500')
+
+    def test_callback_synapseerror(self):
+        """
+        If the web callback raises a SynapseError, it returns the appropriate
+        status code and message set in it.
+        """
+
+        def _callback(request, **kwargs):
+            raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
+
+        res = JsonResource(self.homeserver)
+        res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+        request, channel = make_request(b"GET", b"/foo")
+        request.render(res)
+
+        self.assertEqual(channel.result["code"], b'403')
+        reply_body = json.loads(channel.result["body"])
+        self.assertEqual(reply_body["error"], "Forbidden!!one!")
+        self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
+
+    def test_no_handler(self):
+        """
+        If there is no handler to process the request, Synapse will return 400.
+        """
+
+        def _callback(request, **kwargs):
+            """
+            Not ever actually called!
+            """
+            self.fail("shouldn't ever get here")
+
+        res = JsonResource(self.homeserver)
+        res.register_paths("GET", [re.compile("^/foo$")], _callback)
+
+        request, channel = make_request(b"GET", b"/foobar")
+        request.render(res)
+
+        self.assertEqual(channel.result["code"], b'400')
+        reply_body = json.loads(channel.result["body"])
+        self.assertEqual(reply_body["error"], "Unrecognized request")
+        self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")
-- 
cgit 1.5.1


From 77078d6c8ef11e2401406edc1ca340e0d7779267 Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 11:27:32 +0100
Subject: handle federation not telling us about prev_events

---
 synapse/federation/federation_server.py |   4 +-
 synapse/handlers/federation.py          |  87 ++++++++----
 tests/test_federation.py                | 235 ++++++++++++++++++++++++++++++++
 tests/unittest.py                       |   2 +-
 4 files changed, 301 insertions(+), 27 deletions(-)
 create mode 100644 tests/test_federation.py

(limited to 'tests')

diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d4dd967c60..4096093527 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -549,7 +549,9 @@ class FederationServer(FederationBase):
                 affected=pdu.event_id,
             )
 
-        yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+        yield self.handler.on_receive_pdu(
+            origin, pdu, get_missing=True, sent_to_us_directly=True,
+        )
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b6f8d4cf82..250a5509d8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -44,6 +44,7 @@ from synapse.util.frozenutils import unfreeze
 from synapse.crypto.event_signing import (
     compute_event_signature, add_hashes_and_signatures,
 )
+from synapse.state import resolve_events_with_factory
 from synapse.types import UserID, get_domain_from_id
 
 from synapse.events.utils import prune_event
@@ -89,7 +90,9 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def on_receive_pdu(self, origin, pdu, get_missing=True):
+    def on_receive_pdu(
+            self, origin, pdu, get_missing=True, sent_to_us_directly=False,
+    ):
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -163,7 +166,7 @@ class FederationHandler(BaseHandler):
                     "Ignoring PDU %s for room %s from %s as we've left the room!",
                     pdu.event_id, pdu.room_id, origin,
                 )
-                return
+                defer.returnValue(None)
 
         state = None
 
@@ -225,26 +228,54 @@ class FederationHandler(BaseHandler):
                         list(prevs - seen)[:5],
                     )
 
-            if prevs - seen:
-                logger.info(
-                    "Still missing %d events for room %r: %r...",
-                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+            if sent_to_us_directly and prevs - seen:
+                # If they have sent it to us directly, and the server
+                # isn't telling us about the auth events that it's
+                # made a message referencing, we explode
+                raise FederationError(
+                    "ERROR",
+                    403,
+                    ("Your server isn't divulging details about prev_events "
+                     "referenced in this event."),
+                    affected=pdu.event_id,
                 )
-                fetch_state = True
+            elif prevs - seen:
+                # If we're walking back up the chain to fetch it, then
+                # try and find the states. If we can't get the states,
+                # discard it.
+                state_groups = []
+                auth_chains = set()
+                try:
+                    # Get the ones we know about
+                    ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
+                    state_groups.append(ours)
 
-        if fetch_state:
-            # We need to get the state at this event, since we haven't
-            # processed all the prev events.
-            logger.debug(
-                "_handle_new_pdu getting state for %s",
-                pdu.room_id
-            )
-            try:
-                state, auth_chain = yield self.replication_layer.get_state_for_room(
-                    origin, pdu.room_id, pdu.event_id,
-                )
-            except Exception:
-                logger.exception("Failed to get state for event: %s", pdu.event_id)
+                    for p in prevs - seen:
+                        state, auth_chain = yield self.replication_layer.get_state_for_room(
+                            origin, pdu.room_id, p
+                        )
+                        auth_chains.update(auth_chain)
+                        state_group = {
+                            (x.type, x.state_key): x.event_id for x in state
+                        }
+                        state_groups.append(state_group)
+
+                    def fetch(ev_ids):
+                        return self.store.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+
+                    state = yield resolve_events_with_factory(state_groups, {pdu.event_id: pdu}, fetch)
+
+                    state = yield self.store.get_events(state.values())
+                    state = state.values()
+                except Exception:
+                    raise FederationError(
+                        "ERROR",
+                        403,
+                        "We can't get valid state history.",
+                        affected=pdu.event_id,
+                    )
 
         yield self._process_received_pdu(
             origin,
@@ -322,11 +353,17 @@ class FederationHandler(BaseHandler):
 
         for e in missing_events:
             logger.info("Handling found event %s", e.event_id)
-            yield self.on_receive_pdu(
-                origin,
-                e,
-                get_missing=False
-            )
+            try:
+                yield self.on_receive_pdu(
+                    origin,
+                    e,
+                    get_missing=False
+                )
+            except FederationError as e:
+                if e.code == 403:
+                    logger.warn("Event %s failed history check.")
+                else:
+                    raise
 
     @log_function
     @defer.inlineCallbacks
diff --git a/tests/test_federation.py b/tests/test_federation.py
new file mode 100644
index 0000000000..12f4633cd5
--- /dev/null
+++ b/tests/test_federation.py
@@ -0,0 +1,235 @@
+
+from twisted.internet.defer import Deferred, succeed, maybeDeferred
+
+from synapse.util import Clock
+from synapse.events import FrozenEvent
+from synapse.types import Requester, UserID
+from synapse.replication.slave.storage.events import SlavedEventStore
+
+from tests import unittest
+from tests.server import make_request, setup_test_homeserver, ThreadedMemoryReactorClock
+
+from mock import Mock
+
+from synapse.api.errors import CodeMessageException, HttpResponseException
+
+
+class MessageAcceptTests(unittest.TestCase):
+    def setUp(self):
+
+        self.http_client = Mock()
+        self.reactor = ThreadedMemoryReactorClock()
+        self.hs_clock = Clock(self.reactor)
+        self.homeserver = setup_test_homeserver(
+            http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor
+        )
+
+        user_id = UserID("us", "test")
+        our_user = Requester(user_id, None, False, None, None)
+        room_creator = self.homeserver.get_room_creation_handler()
+        room = room_creator.create_room(
+            our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+        )
+        self.reactor.advance(0.1)
+        self.room_id = self.successResultOf(room)["room_id"]
+
+        # Figure out what the most recent event is
+        most_recent = self.successResultOf(
+            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        )[0]
+
+        join_event = FrozenEvent(
+            {
+                "room_id": self.room_id,
+                "sender": "@baduser:test.serv",
+                "state_key": "@baduser:test.serv",
+                "event_id": "$join:test.serv",
+                "depth": 1000,
+                "origin_server_ts": 1,
+                "type": "m.room.member",
+                "origin": "test.servx",
+                "content": {"membership": "join"},
+                "auth_events": [],
+                "prev_state": [(most_recent, {})],
+                "prev_events": [(most_recent, {})],
+            }
+        )
+
+        self.handler = self.homeserver.get_handlers().federation_handler
+        self.handler.do_auth = lambda *a, **b: succeed(True)
+        self.client = self.homeserver.get_federation_client()
+        self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
+            pdus
+        )
+
+        # Send the join, it should return None (which is not an error)
+        d = self.handler.on_receive_pdu(
+            "test.serv", join_event, sent_to_us_directly=True
+        )
+        self.reactor.advance(1)
+        self.assertEqual(self.successResultOf(d), None)
+
+        # Make sure we actually joined the room
+        self.assertEqual(
+            self.successResultOf(
+                self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+            )[0],
+            "$join:test.serv",
+        )
+
+    def test_cant_hide_direct_ancestors(self):
+        """
+        If you send a message, you must be able to provide the direct
+        prev_events that said event references.
+        """
+
+        def post_json(destination, path, data, headers=None, timeout=0):
+            # If it asks us for new missing events, give them NOTHING
+            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+                return {"events": []}
+
+        self.http_client.post_json = post_json
+
+        # Figure out what the most recent event is
+        most_recent = self.successResultOf(
+            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        )[0]
+
+        # Now lie about an event
+        lying_event = FrozenEvent(
+            {
+                "room_id": self.room_id,
+                "sender": "@baduser:test.serv",
+                "event_id": "one:test.serv",
+                "depth": 1000,
+                "origin_server_ts": 1,
+                "type": "m.room.message",
+                "origin": "test.serv",
+                "content": "hewwo?",
+                "auth_events": [],
+                "prev_events": [("two:test.serv", {}), (most_recent, {})],
+            }
+        )
+
+        d = self.handler.on_receive_pdu(
+            "test.serv", lying_event, sent_to_us_directly=True
+        )
+
+        # Step the reactor, so the database fetches come back
+        self.reactor.advance(1)
+
+        # on_receive_pdu should throw an error
+        failure = self.failureResultOf(d)
+        self.assertEqual(
+            failure.value.args[0],
+            (
+                "ERROR 403: Your server isn't divulging details about prev_events "
+                "referenced in this event."
+            ),
+        )
+
+        # Make sure the invalid event isn't there
+        extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+    @unittest.DEBUG
+    def test_cant_hide_past_history(self):
+        """
+        If you send a message, you must be able to provide the direct
+        prev_events that said event references.
+        """
+
+        def post_json(destination, path, data, headers=None, timeout=0):
+            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+                return {
+                    "events": [
+                        {
+                            "room_id": self.room_id,
+                            "sender": "@baduser:test.serv",
+                            "event_id": "three:test.serv",
+                            "depth": 1000,
+                            "origin_server_ts": 1,
+                            "type": "m.room.message",
+                            "origin": "test.serv",
+                            "content": "hewwo?",
+                            "auth_events": [],
+                            "prev_events": [("four:test.serv", {})],
+                        }
+                    ]
+                }
+
+        self.http_client.post_json = post_json
+
+        def get_json(destination, path, args, headers=None):
+            print(destination, path, args)
+            if path.startswith("/_matrix/federation/v1/state_ids/"):
+                d = self.successResultOf(
+                    self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
+                )
+
+                return succeed(
+                    {
+                        "pdu_ids": [
+                            y
+                            for x, y in d.items()
+                            if x == ("m.room.member", "@us:test")
+                        ],
+                        "auth_chain_ids": d.values(),
+                    }
+                )
+
+        self.http_client.get_json = get_json
+
+        # Figure out what the most recent event is
+        most_recent = self.successResultOf(
+            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        )[0]
+
+        # Make a good event
+        good_event = FrozenEvent(
+            {
+                "room_id": self.room_id,
+                "sender": "@baduser:test.serv",
+                "event_id": "one:test.serv",
+                "depth": 1000,
+                "origin_server_ts": 1,
+                "type": "m.room.message",
+                "origin": "test.serv",
+                "content": "hewwo?",
+                "auth_events": [],
+                "prev_events": [(most_recent, {})],
+            }
+        )
+
+        d = self.handler.on_receive_pdu(
+            "test.serv", good_event, sent_to_us_directly=True
+        )
+        self.reactor.advance(1)
+        self.assertEqual(self.successResultOf(d), None)
+
+        bad_event = FrozenEvent(
+            {
+                "room_id": self.room_id,
+                "sender": "@baduser:test.serv",
+                "event_id": "two:test.serv",
+                "depth": 1000,
+                "origin_server_ts": 1,
+                "type": "m.room.message",
+                "origin": "test.serv",
+                "content": "hewwo?",
+                "auth_events": [],
+                "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
+            }
+        )
+
+        d = self.handler.on_receive_pdu(
+            "test.serv", bad_event, sent_to_us_directly=True
+        )
+        self.reactor.advance(1)
+
+        extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
+
+        state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
+        self.reactor.advance(1)
+        self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())
diff --git a/tests/unittest.py b/tests/unittest.py
index 184fe880f3..de24b1d2d4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -35,7 +35,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace('warning', 'warn')
-        self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry)
+        self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")"))
 
 
 handler = ToTwistedHandler()
-- 
cgit 1.5.1


From 8d62baa48cb222c3010007fdd6e48673f5cd0519 Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 11:31:48 +0100
Subject: cleanups

---
 tests/test_federation.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

(limited to 'tests')

diff --git a/tests/test_federation.py b/tests/test_federation.py
index 12f4633cd5..bc8b3af9b3 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -35,7 +35,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         join_event = FrozenEvent(
@@ -72,7 +72,7 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+                maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
             )[0],
             "$join:test.serv",
         )
@@ -92,7 +92,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Now lie about an event
@@ -129,7 +129,7 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
 
     @unittest.DEBUG
@@ -161,7 +161,6 @@ class MessageAcceptTests(unittest.TestCase):
         self.http_client.post_json = post_json
 
         def get_json(destination, path, args, headers=None):
-            print(destination, path, args)
             if path.startswith("/_matrix/federation/v1/state_ids/"):
                 d = self.successResultOf(
                     self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
@@ -182,7 +181,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Make a good event
@@ -227,7 +226,7 @@ class MessageAcceptTests(unittest.TestCase):
         )
         self.reactor.advance(1)
 
-        extrem = self.homeserver.datastore.get_latest_event_ids_in_room(self.room_id)
+        extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
 
         state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
-- 
cgit 1.5.1


From 35cc3e8b143f69abadbd41f82e463fbcd3528346 Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 11:32:09 +0100
Subject: stylistic cleanup

---
 tests/test_federation.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

(limited to 'tests')

diff --git a/tests/test_federation.py b/tests/test_federation.py
index bc8b3af9b3..95fa73723c 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -35,7 +35,9 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+            maybeDeferred(
+                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+            )
         )[0]
 
         join_event = FrozenEvent(
@@ -72,7 +74,9 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+                maybeDeferred(
+                    self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+                )
             )[0],
             "$join:test.serv",
         )
@@ -92,7 +96,9 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+            maybeDeferred(
+                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+            )
         )[0]
 
         # Now lie about an event
@@ -129,7 +135,9 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+        extrem = maybeDeferred(
+            self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+        )
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
 
     @unittest.DEBUG
@@ -181,7 +189,9 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+            maybeDeferred(
+                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+            )
         )[0]
 
         # Make a good event
@@ -226,7 +236,9 @@ class MessageAcceptTests(unittest.TestCase):
         )
         self.reactor.advance(1)
 
-        extrem = maybeDeferred(self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id)
+        extrem = maybeDeferred(
+            self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+        )
         self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
 
         state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
-- 
cgit 1.5.1


From 94f09618e54e8ae0a30611f0da463d275768ab74 Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 11:38:03 +0100
Subject: cleanups

---
 tests/unittest.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'tests')

diff --git a/tests/unittest.py b/tests/unittest.py
index de24b1d2d4..b25f2db5d5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace('warning', 'warn')
-        self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")"))
+        self.tx_log.emit(
+            twisted.logger.LogLevel.levelWithName(log_level),
+            log_entry.replace("{", r"(").replace("}", r")"),
+        )
 
 
 handler = ToTwistedHandler()
-- 
cgit 1.5.1


From f03a5d1a1759221bd94d75604f3e4e787cd4133e Mon Sep 17 00:00:00 2001
From: Amber Brown <hawkowl@atleastfornow.net>
Date: Wed, 27 Jun 2018 11:38:14 +0100
Subject: pep8

---
 synapse/handlers/federation.py | 10 ++++------
 tests/test_federation.py       |  7 ++-----
 2 files changed, 6 insertions(+), 11 deletions(-)

(limited to 'tests')

diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e9c5d1026a..c4d96749a3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -169,11 +169,8 @@ class FederationHandler(BaseHandler):
                 defer.returnValue(None)
 
         state = None
-
         auth_chain = []
 
-        fetch_state = False
-
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
@@ -252,9 +249,10 @@ class FederationHandler(BaseHandler):
                     # Ask the remote server for the states we don't
                     # know about
                     for p in prevs - seen:
-                        state, got_auth_chain = yield self.replication_layer.get_state_for_room(
-                            origin, pdu.room_id, p
-                        )
+                        state, got_auth_chain = (
+                            yield self.replication_layer.get_state_for_room(
+                                origin, pdu.room_id, p
+                        ))
                         auth_chains.update(got_auth_chain)
                         state_group = {(x.type, x.state_key): x.event_id for x in state}
                         state_groups.append(state_group)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 95fa73723c..fc80a69369 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,18 +1,15 @@
 
-from twisted.internet.defer import Deferred, succeed, maybeDeferred
+from twisted.internet.defer import succeed, maybeDeferred
 
 from synapse.util import Clock
 from synapse.events import FrozenEvent
 from synapse.types import Requester, UserID
-from synapse.replication.slave.storage.events import SlavedEventStore
 
 from tests import unittest
-from tests.server import make_request, setup_test_homeserver, ThreadedMemoryReactorClock
+from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock
 
 from mock import Mock
 
-from synapse.api.errors import CodeMessageException, HttpResponseException
-
 
 class MessageAcceptTests(unittest.TestCase):
     def setUp(self):
-- 
cgit 1.5.1


From e72234f6bda33d89dcca07751e34c62b88215e9d Mon Sep 17 00:00:00 2001
From: Matthew Hodgson <matthew@matrix.org>
Date: Thu, 28 Jun 2018 20:56:07 +0100
Subject: fix tests

---
 synapse/config/appservice.py |  1 +
 tests/api/test_auth.py       | 18 +++++++++++++++---
 2 files changed, 16 insertions(+), 3 deletions(-)

(limited to 'tests')

diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 89c07f202f..0c27bb2fa7 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -157,6 +157,7 @@ def _load_appservice(hostname, as_info, config_filename):
             config_filename,
         )
 
+    ip_range_whitelist = None
     if as_info.get('ip_range_whitelist'):
         ip_range_whitelist = IPSet(
             as_info.get('ip_range_whitelist')
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4575dd9834..48bd411e49 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -86,11 +86,15 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token(self):
-        app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user,
+            ip_range_whitelist=None,
+        )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
@@ -119,12 +123,16 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
         masquerading_user_id = "@doppelganger:matrix.org"
-        app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user,
+            ip_range_whitelist=None,
+        )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -133,12 +141,16 @@ class AuthTestCase(unittest.TestCase):
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
         masquerading_user_id = "@doppelganger:matrix.org"
-        app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user,
+            ip_range_whitelist=None,
+        )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-- 
cgit 1.5.1


From f82cf3c7dfdcdbcf076dde1835796f2b274077c5 Mon Sep 17 00:00:00 2001
From: Matthew Hodgson <matthew@matrix.org>
Date: Thu, 28 Jun 2018 21:14:16 +0100
Subject: add test

---
 tests/api/test_auth.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

(limited to 'tests')

diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 48bd411e49..aec3b62897 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -100,6 +100,39 @@ class AuthTestCase(unittest.TestCase):
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
 
+    @defer.inlineCallbacks
+    def test_get_user_by_req_appservice_valid_token_good_ip(self):
+        from netaddr import IPSet
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user,
+            ip_range_whitelist=IPSet(["192.168/16"]),
+        )
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        self.store.get_user_by_access_token = Mock(return_value=None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "192.168.10.10"
+        request.args["access_token"] = [self.test_token]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = yield self.auth.get_user_by_req(request)
+        self.assertEquals(requester.user.to_string(), self.test_user)
+
+    def test_get_user_by_req_appservice_valid_token_bad_ip(self):
+        from netaddr import IPSet
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user,
+            ip_range_whitelist=IPSet(["192.168/16"]),
+        )
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        self.store.get_user_by_access_token = Mock(return_value=None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "131.111.8.42"
+        request.args["access_token"] = [self.test_token]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        d = self.auth.get_user_by_req(request)
+        self.failureResultOf(d, AuthError)
+
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
         self.store.get_user_by_access_token = Mock(return_value=None)
-- 
cgit 1.5.1


From 508196e08a834496daa1bfc5f561e69a430e270c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Tue, 3 Jul 2018 14:36:14 +0100
Subject: Reject invalid server names (#3480)

Make sure that server_names used in auth headers are sane, and reject them with
a sensible error code, before they disappear off into the depths of the system.
---
 changelog.d/3480.feature               |  1 +
 synapse/federation/transport/server.py | 66 ++++++++++++++++++++++------------
 synapse/http/endpoint.py               | 34 ++++++++++++++++--
 tests/http/__init__.py                 |  0
 tests/http/test_endpoint.py            | 46 ++++++++++++++++++++++++
 5 files changed, 122 insertions(+), 25 deletions(-)
 create mode 100644 changelog.d/3480.feature
 create mode 100644 tests/http/__init__.py
 create mode 100644 tests/http/test_endpoint.py

(limited to 'tests')

diff --git a/changelog.d/3480.feature b/changelog.d/3480.feature
new file mode 100644
index 0000000000..a21580943d
--- /dev/null
+++ b/changelog.d/3480.feature
@@ -0,0 +1 @@
+Reject invalid server names in federation requests
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..1180d4b69d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -99,26 +100,6 @@ class Authenticator(object):
 
         origin = None
 
-        def parse_auth_header(header_str):
-            try:
-                params = auth.split(" ")[1].split(",")
-                param_dict = dict(kv.split("=") for kv in params)
-
-                def strip_quotes(value):
-                    if value.startswith("\""):
-                        return value[1:-1]
-                    else:
-                        return value
-
-                origin = strip_quotes(param_dict["origin"])
-                key = strip_quotes(param_dict["key"])
-                sig = strip_quotes(param_dict["sig"])
-                return (origin, key, sig)
-            except Exception:
-                raise AuthenticationError(
-                    400, "Malformed Authorization header", Codes.UNAUTHORIZED
-                )
-
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
 
         if not auth_headers:
@@ -127,8 +108,8 @@ class Authenticator(object):
             )
 
         for auth in auth_headers:
-            if auth.startswith("X-Matrix"):
-                (origin, key, sig) = parse_auth_header(auth)
+            if auth.startswith(b"X-Matrix"):
+                (origin, key, sig) = _parse_auth_header(auth)
                 json_request["origin"] = origin
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
@@ -165,6 +146,47 @@ class Authenticator(object):
             logger.exception("Error resetting retry timings on %s", origin)
 
 
+def _parse_auth_header(header_bytes):
+    """Parse an X-Matrix auth header
+
+    Args:
+        header_bytes (bytes): header value
+
+    Returns:
+        Tuple[str, str, str]: origin, key id, signature.
+
+    Raises:
+        AuthenticationError if the header could not be parsed
+    """
+    try:
+        header_str = header_bytes.decode('utf-8')
+        params = header_str.split(" ")[1].split(",")
+        param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            if value.startswith(b"\""):
+                return value[1:-1]
+            else:
+                return value
+
+        origin = strip_quotes(param_dict["origin"])
+        # ensure that the origin is a valid server name
+        parse_server_name(origin)
+
+        key = strip_quotes(param_dict["key"])
+        sig = strip_quotes(param_dict["sig"])
+        return origin, key, sig
+    except Exception as e:
+        logger.warn(
+            "Error parsing auth header '%s': %s",
+            header_bytes.decode('ascii', 'replace'),
+            e,
+        )
+        raise AuthenticationError(
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+        )
+
+
 class BaseFederationServlet(object):
     REQUIRE_AUTH = True
 
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 80da870584..5a9cbb3324 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -38,6 +38,36 @@ _Server = collections.namedtuple(
 )
 
 
+def parse_server_name(server_name):
+    """Split a server name into host/port parts.
+
+    Does some basic sanity checking of the
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == ']':
+            # ipv6 literal, hopefully
+            if server_name[0] != '[':
+                raise Exception()
+
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
 def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.
@@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         timeout (int): connection timeout in seconds
     """
 
-    domain_port = destination.split(":")
-    domain = domain_port[0]
-    port = int(domain_port[1]) if domain_port[1:] else None
+    domain, port = parse_server_name(destination)
 
     endpoint_kw_args = {}
 
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
new file mode 100644
index 0000000000..cd74825c85
--- /dev/null
+++ b/tests/http/test_endpoint.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.http.endpoint import parse_server_name
+from tests import unittest
+
+
+class ServerNameTestCase(unittest.TestCase):
+    def test_parse_server_name(self):
+        test_data = {
+            'localhost': ('localhost', None),
+            'my-example.com:1234': ('my-example.com', 1234),
+            '1.2.3.4': ('1.2.3.4', None),
+            '[0abc:1def::1234]': ('[0abc:1def::1234]', None),
+            '1.2.3.4:1': ('1.2.3.4', 1),
+            '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
+        }
+
+        for i, o in test_data.items():
+            self.assertEqual(parse_server_name(i), o)
+
+    def test_parse_bad_server_names(self):
+        test_data = [
+            "",  # empty
+            "localhost:http",  # non-numeric port
+            "1234]",  # smells like ipv6 literal but isn't
+        ]
+        for i in test_data:
+            try:
+                parse_server_name(i)
+                self.fail(
+                    "Expected parse_server_name(\"%s\") to throw" % i,
+                )
+            except ValueError:
+                pass
-- 
cgit 1.5.1


From ea555d56331ad01edc9871ec7bf879df7d24dc7d Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Wed, 4 Jul 2018 09:35:40 +0100
Subject: Reinstate lost run_on_reactor in unit test

a61738b removed a call to run_on_reactor from a unit test, but that call was
doing something useful, in making the function in question asynchronous.

Reinstate the call and add a check that we are testing what we wanted to be
testing.
---
 changelog.d/3385.misc                 |  1 +
 tests/util/caches/test_descriptors.py | 17 +++++++++++++++--
 2 files changed, 16 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/3385.misc

(limited to 'tests')

diff --git a/changelog.d/3385.misc b/changelog.d/3385.misc
new file mode 100644
index 0000000000..92a91a1ca5
--- /dev/null
+++ b/changelog.d/3385.misc
@@ -0,0 +1 @@
+Reinstate lost run_on_reactor in unit tests
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 24754591df..a94d566c96 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -19,13 +19,19 @@ import logging
 import mock
 from synapse.api.errors import SynapseError
 from synapse.util import logcontext
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from synapse.util.caches import descriptors
 from tests import unittest
 
 logger = logging.getLogger(__name__)
 
 
+def run_on_reactor():
+    d = defer.Deferred()
+    reactor.callLater(0, d.callback, 0)
+    return logcontext.make_deferred_yieldable(d)
+
+
 class CacheTestCase(unittest.TestCase):
     def test_invalidate_all(self):
         cache = descriptors.Cache("testcache")
@@ -194,6 +200,8 @@ class DescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 @defer.inlineCallbacks
                 def inner_fn():
+                    # we want this to behave like an asynchronous function
+                    yield run_on_reactor()
                     raise SynapseError(400, "blah")
 
                 return inner_fn()
@@ -203,7 +211,12 @@ class DescriptorTestCase(unittest.TestCase):
             with logcontext.LoggingContext() as c1:
                 c1.name = "c1"
                 try:
-                    yield obj.fn(1)
+                    d = obj.fn(1)
+                    self.assertEqual(
+                        logcontext.LoggingContext.current_context(),
+                        logcontext.LoggingContext.sentinel,
+                    )
+                    yield d
                     self.fail("No exception thrown")
                 except SynapseError:
                     pass
-- 
cgit 1.5.1


From 546bc9e28b3d7758c732df8e120639d58d455164 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Wed, 4 Jul 2018 18:15:03 +0100
Subject: More server_name validation

We need to do a bit more validation when we get a server name, but don't want
to be re-doing it all over the shop, so factor out a separate
parse_and_validate_server_name, and do the extra validation.

Also, use it to verify the server name in the config file.
---
 changelog.d/3483.feature               |  1 +
 synapse/config/server.py               | 11 ++++++--
 synapse/federation/transport/server.py |  5 ++--
 synapse/http/endpoint.py               | 47 ++++++++++++++++++++++++++++++----
 tests/http/test_endpoint.py            | 17 +++++++++---
 5 files changed, 68 insertions(+), 13 deletions(-)
 create mode 100644 changelog.d/3483.feature

(limited to 'tests')

diff --git a/changelog.d/3483.feature b/changelog.d/3483.feature
new file mode 100644
index 0000000000..afa2fbbcba
--- /dev/null
+++ b/changelog.d/3483.feature
@@ -0,0 +1 @@
+Reject invalid server names in homeserver.yaml
\ No newline at end of file
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 968ecd9ea0..71fd51e4bc 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -16,6 +16,7 @@
 
 import logging
 
+from synapse.http.endpoint import parse_and_validate_server_name
 from ._base import Config, ConfigError
 
 logger = logging.Logger(__name__)
@@ -25,6 +26,12 @@ class ServerConfig(Config):
 
     def read_config(self, config):
         self.server_name = config["server_name"]
+
+        try:
+            parse_and_validate_server_name(self.server_name)
+        except ValueError as e:
+            raise ConfigError(str(e))
+
         self.pid_file = self.abspath(config.get("pid_file"))
         self.web_client = config["web_client"]
         self.web_client_location = config.get("web_client_location", None)
@@ -162,8 +169,8 @@ class ServerConfig(Config):
             })
 
     def default_config(self, server_name, **kwargs):
-        if ":" in server_name:
-            bind_port = int(server_name.split(":")[1])
+        _, bind_port = parse_and_validate_server_name(server_name)
+        if bind_port is not None:
             unsecure_port = bind_port - 400
         else:
             bind_port = 8448
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 1180d4b69d..e1fdcc89dc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError, FederationDeniedError
-from synapse.http.endpoint import parse_server_name
+from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -170,8 +170,9 @@ def _parse_auth_header(header_bytes):
                 return value
 
         origin = strip_quotes(param_dict["origin"])
+
         # ensure that the origin is a valid server name
-        parse_server_name(origin)
+        parse_and_validate_server_name(origin)
 
         key = strip_quotes(param_dict["key"])
         sig = strip_quotes(param_dict["sig"])
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 5a9cbb3324..1b1123b292 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
+
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer
 from twisted.internet.error import ConnectError
@@ -41,8 +43,6 @@ _Server = collections.namedtuple(
 def parse_server_name(server_name):
     """Split a server name into host/port parts.
 
-    Does some basic sanity checking of the
-
     Args:
         server_name (str): server name to parse
 
@@ -55,9 +55,6 @@ def parse_server_name(server_name):
     try:
         if server_name[-1] == ']':
             # ipv6 literal, hopefully
-            if server_name[0] != '[':
-                raise Exception()
-
             return server_name, None
 
         domain_port = server_name.rsplit(":", 1)
@@ -68,6 +65,46 @@ def parse_server_name(server_name):
         raise ValueError("Invalid server name '%s'" % server_name)
 
 
+VALID_HOST_REGEX = re.compile(
+    "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == '[':
+        if host[-1] != ']':
+            raise ValueError("Mismatched [...] in server name '%s'" % (
+                server_name,
+            ))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' contains invalid characters" % (
+            server_name,
+        ))
+
+    return host, port
+
+
 def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index cd74825c85..b8a48d20a4 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -12,7 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.http.endpoint import parse_server_name
+from synapse.http.endpoint import (
+    parse_server_name,
+    parse_and_validate_server_name,
+)
 from tests import unittest
 
 
@@ -30,17 +33,23 @@ class ServerNameTestCase(unittest.TestCase):
         for i, o in test_data.items():
             self.assertEqual(parse_server_name(i), o)
 
-    def test_parse_bad_server_names(self):
+    def test_validate_bad_server_names(self):
         test_data = [
             "",  # empty
             "localhost:http",  # non-numeric port
             "1234]",  # smells like ipv6 literal but isn't
+            "[1234",
+            "underscore_.com",
+            "percent%65.com",
+            "1234:5678:80",   # too many colons
         ]
         for i in test_data:
             try:
-                parse_server_name(i)
+                parse_and_validate_server_name(i)
                 self.fail(
-                    "Expected parse_server_name(\"%s\") to throw" % i,
+                    "Expected parse_and_validate_server_name('%s') to throw" % (
+                        i,
+                    ),
                 )
             except ValueError:
                 pass
-- 
cgit 1.5.1


From 3cf3e08a97f4617763ce10da4f127c0e21d7ff1d Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Wed, 4 Jul 2018 15:31:00 +0100
Subject: Implementation of server_acls

... as described at
https://docs.google.com/document/d/1EttUVzjc2DWe2ciw4XPtNpUpIl9lWXGEsy2ewDS7rtw.
---
 synapse/api/constants.py                   |   2 +
 synapse/federation/federation_server.py    | 150 ++++++++++++++++++++++++++++-
 synapse/federation/transport/server.py     |   8 +-
 tests/federation/__init__.py               |   0
 tests/federation/test_federation_server.py |  57 +++++++++++
 5 files changed, 213 insertions(+), 4 deletions(-)
 create mode 100644 tests/federation/__init__.py
 create mode 100644 tests/federation/test_federation_server.py

(limited to 'tests')

diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 5baba43966..4df930c8d1 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -76,6 +76,8 @@ class EventTypes(object):
     Topic = "m.room.topic"
     Name = "m.room.name"
 
+    ServerACL = "m.room.server_acl"
+
 
 class RejectedReason(object):
     AUTH_ERROR = "auth_error"
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index fe51ba6806..591d0026bf 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
 
 from canonicaljson import json
+import six
 from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
 
+from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.federation.federation_base import (
@@ -27,6 +31,7 @@ from synapse.federation.federation_base import (
 
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
 from synapse.types import get_domain_from_id
 from synapse.util import async
 from synapse.util.caches.response_cache import ResponseCache
@@ -74,6 +79,9 @@ class FederationServer(FederationBase):
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             pdus = yield self.handler.on_backfill_request(
                 origin, room_id, versions, limit
             )
@@ -134,6 +142,8 @@ class FederationServer(FederationBase):
 
         received_pdus_counter.inc(len(transaction.pdus))
 
+        origin_host, _ = parse_server_name(transaction.origin)
+
         pdus_by_room = {}
 
         for p in transaction.pdus:
@@ -154,9 +164,21 @@ class FederationServer(FederationBase):
         # we can process different rooms in parallel (which is useful if they
         # require callouts to other servers to fetch missing events), but
         # impose a limit to avoid going too crazy with ram/cpu.
+
         @defer.inlineCallbacks
         def process_pdus_for_room(room_id):
             logger.debug("Processing PDUs for %s", room_id)
+            try:
+                yield self.check_server_matches_acl(origin_host, room_id)
+            except AuthError as e:
+                logger.warn(
+                    "Ignoring PDUs for room %s from banned server", room_id,
+                )
+                for pdu in pdus_by_room[room_id]:
+                    event_id = pdu.event_id
+                    pdu_results[event_id] = e.error_dict()
+                return
+
             for pdu in pdus_by_room[room_id]:
                 event_id = pdu.event_id
                 try:
@@ -211,6 +233,9 @@ class FederationServer(FederationBase):
         if not event_id:
             raise NotImplementedError("Specify an event")
 
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
+
         in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -234,6 +259,9 @@ class FederationServer(FederationBase):
         if not event_id:
             raise NotImplementedError("Specify an event")
 
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
+
         in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -298,7 +326,9 @@ class FederationServer(FederationBase):
         defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
-    def on_make_join_request(self, room_id, user_id):
+    def on_make_join_request(self, origin, room_id, user_id):
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
         pdu = yield self.handler.on_make_join_request(room_id, user_id)
         time_now = self._clock.time_msec()
         defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -306,6 +336,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
         pdu = event_from_pdu_json(content)
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -314,6 +346,10 @@ class FederationServer(FederationBase):
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
         pdu = event_from_pdu_json(content)
+
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -325,7 +361,9 @@ class FederationServer(FederationBase):
         }))
 
     @defer.inlineCallbacks
-    def on_make_leave_request(self, room_id, user_id):
+    def on_make_leave_request(self, origin, room_id, user_id):
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
         pdu = yield self.handler.on_make_leave_request(room_id, user_id)
         time_now = self._clock.time_msec()
         defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -334,6 +372,10 @@ class FederationServer(FederationBase):
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
         pdu = event_from_pdu_json(content)
+
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -341,6 +383,9 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             time_now = self._clock.time_msec()
             auth_pdus = yield self.handler.on_event_auth(event_id)
             res = {
@@ -369,6 +414,9 @@ class FederationServer(FederationBase):
             Deferred: Results in `dict` with the same format as `content`
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             auth_chain = [
                 event_from_pdu_json(e)
                 for e in content["auth_chain"]
@@ -442,6 +490,9 @@ class FederationServer(FederationBase):
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             logger.info(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d, min_depth: %d",
@@ -579,6 +630,101 @@ class FederationServer(FederationBase):
         )
         defer.returnValue(ret)
 
+    @defer.inlineCallbacks
+    def check_server_matches_acl(self, server_name, room_id):
+        """Check if the given server is allowed by the server ACLs in the room
+
+        Args:
+            server_name (str): name of server, *without any port part*
+            room_id (str): ID of the room to check
+
+        Raises:
+            AuthError if the server does not match the ACL
+        """
+        state_ids = yield self.store.get_current_state_ids(room_id)
+        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+        if not acl_event_id:
+            return
+
+        acl_event = yield self.store.get_event(acl_event_id)
+        if server_matches_acl_event(server_name, acl_event):
+            return
+
+        raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+    """Check if the given server is allowed by the ACL event
+
+    Args:
+        server_name (str): name of server, without any port part
+        acl_event (EventBase): m.room.server_acl event
+
+    Returns:
+        bool: True if this server is allowed by the ACLs
+    """
+    logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+    # first of all, check if literal IPs are blocked, and if so, whether the
+    # server name is a literal IP
+    allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+    if not isinstance(allow_ip_literals, bool):
+        logger.warn("Ignorning non-bool allow_ip_literals flag")
+        allow_ip_literals = True
+    if not allow_ip_literals:
+        # check for ipv6 literals. These start with '['.
+        if server_name[0] == '[':
+            return False
+
+        # check for ipv4 literals. We can just lift the routine from twisted.
+        if isIPAddress(server_name):
+            return False
+
+    # next,  check the deny list
+    deny = acl_event.content.get("deny", [])
+    if not isinstance(deny, (list, tuple)):
+        logger.warn("Ignorning non-list deny ACL %s", deny)
+        deny = []
+    for e in deny:
+        if _acl_entry_matches(server_name, e):
+            # logger.info("%s matched deny rule %s", server_name, e)
+            return False
+
+    # then the allow list.
+    allow = acl_event.content.get("allow", [])
+    if not isinstance(allow, (list, tuple)):
+        logger.warn("Ignorning non-list allow ACL %s", allow)
+        allow = []
+    for e in allow:
+        if _acl_entry_matches(server_name, e):
+            # logger.info("%s matched allow rule %s", server_name, e)
+            return True
+
+    # everything else should be rejected.
+    # logger.info("%s fell through", server_name)
+    return False
+
+
+def _acl_entry_matches(server_name, acl_entry):
+    if not isinstance(acl_entry, six.string_types):
+        logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+        return False
+    regex = _glob_to_regex(acl_entry)
+    return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+    res = ''
+    for c in glob:
+        if c == '*':
+            res = res + '.*'
+        elif c == '?':
+            res = res + '.'
+        else:
+            res = res + re.escape(c)
+    return re.compile(res + "\\Z", re.IGNORECASE)
+
 
 class FederationHandlerRegistry(object):
     """Allows classes to register themselves as handlers for a given EDU or
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index e1fdcc89dc..c6d98d35cb 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -385,7 +385,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_join_request(context, user_id)
+        content = yield self.handler.on_make_join_request(
+            origin, context, user_id,
+        )
         defer.returnValue((200, content))
 
 
@@ -394,7 +396,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_leave_request(context, user_id)
+        content = yield self.handler.on_make_leave_request(
+            origin, context, user_id,
+        )
         defer.returnValue((200, content))
 
 
diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
new file mode 100644
index 0000000000..4e8dc8fea0
--- /dev/null
+++ b/tests/federation/test_federation_server.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.events import FrozenEvent
+from synapse.federation.federation_server import server_matches_acl_event
+from tests import unittest
+
+
+@unittest.DEBUG
+class ServerACLsTestCase(unittest.TestCase):
+    def test_blacklisted_server(self):
+        e = _create_acl_event({
+            "allow": ["*"],
+            "deny": ["evil.com"],
+        })
+        logging.info("ACL event: %s", e.content)
+
+        self.assertFalse(server_matches_acl_event("evil.com", e))
+        self.assertFalse(server_matches_acl_event("EVIL.COM", e))
+
+        self.assertTrue(server_matches_acl_event("evil.com.au", e))
+        self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
+
+    def test_block_ip_literals(self):
+        e = _create_acl_event({
+            "allow_ip_literals": False,
+            "allow": ["*"],
+        })
+        logging.info("ACL event: %s", e.content)
+
+        self.assertFalse(server_matches_acl_event("1.2.3.4", e))
+        self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
+        self.assertFalse(server_matches_acl_event("[1:2::]", e))
+        self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+
+
+def _create_acl_event(content):
+    return FrozenEvent({
+        "room_id": "!a:b",
+        "event_id": "$a:b",
+        "type": "m.room.server_acls",
+        "sender": "@a:b",
+        "content": content
+    })
-- 
cgit 1.5.1