summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/config/api.py2
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/http/federation/matrix_federation_agent.py90
-rw-r--r--synapse/storage/state.py38
6 files changed, 115 insertions, 42 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index fedfb92b3e..f47c33a074 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -73,6 +73,7 @@ class EventTypes(object):
     RoomHistoryVisibility = "m.room.history_visibility"
     CanonicalAlias = "m.room.canonical_alias"
     RoomAvatar = "m.room.avatar"
+    RoomEncryption = "m.room.encryption"
     GuestAccess = "m.room.guest_access"
 
     # These are used for validation
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 403d96ba76..9f25bbc5cb 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -24,6 +24,7 @@ class ApiConfig(Config):
             EventTypes.JoinRules,
             EventTypes.CanonicalAlias,
             EventTypes.RoomAvatar,
+            EventTypes.RoomEncryption,
             EventTypes.Name,
         ])
 
@@ -36,5 +37,6 @@ class ApiConfig(Config):
             - "{JoinRules}"
             - "{CanonicalAlias}"
             - "{RoomAvatar}"
+            - "{RoomEncryption}"
             - "{Name}"
         """.format(**vars(EventTypes))
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index aeadc9c564..3da86d4ba6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -148,6 +148,22 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
+        # Reject if PDU count > 50 and EDU count > 100
+        if (len(transaction.pdus) > 50
+                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+
+            logger.info(
+                "Transaction PDU or EDU count too large. Returning 400",
+            )
+
+            response = {}
+            yield self.transaction_actions.set_response(
+                origin,
+                transaction,
+                400, response
+            )
+            defer.returnValue((400, response))
+
         received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 13ba9291b0..5e40e9ea46 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -263,6 +263,16 @@ class RoomCreationHandler(BaseHandler):
             }
         }
 
+        # Check if old room was non-federatable
+
+        # Get old room's create event
+        old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
+
+        # Check if the create event specified a non-federatable room
+        if not old_room_create_event.content.get("m.federate", True):
+            # If so, mark the new room as non-federatable as well
+            creation_content["m.federate"] = False
+
         initial_state = dict()
 
         # Replicate relevant room events
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 26649e70be..384d8a37a2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -23,14 +23,17 @@ from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IStreamClientEndpoint
 from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 
 from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util import Clock
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure
 
 # period to cache .well-known results for by default
 WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
@@ -44,7 +47,6 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
 # cap for .well-known cache period
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 
-
 logger = logging.getLogger(__name__)
 well_known_cache = TTLCache('well-known')
 
@@ -78,6 +80,8 @@ class MatrixFederationAgent(object):
         _well_known_cache=well_known_cache,
     ):
         self._reactor = reactor
+        self._clock = Clock(reactor)
+
         self._tls_client_options_factory = tls_client_options_factory
         if _srv_resolver is None:
             _srv_resolver = SrvResolver()
@@ -98,6 +102,10 @@ class MatrixFederationAgent(object):
         )
         self._well_known_agent = _well_known_agent
 
+        # our cache of .well-known lookup results, mapping from server name
+        # to delegated name. The values can be:
+        #   `bytes`:     a valid server-name
+        #   `None`:      there is no (valid) .well-known here
         self._well_known_cache = _well_known_cache
 
     @defer.inlineCallbacks
@@ -152,12 +160,9 @@ class MatrixFederationAgent(object):
         class EndpointFactory(object):
             @staticmethod
             def endpointForURI(_uri):
-                logger.info(
-                    "Connecting to %s:%i",
-                    res.target_host.decode("ascii"),
-                    res.target_port,
+                ep = LoggingHostnameEndpoint(
+                    self._reactor, res.target_host, res.target_port,
                 )
-                ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
                 if tls_options is not None:
                     ep = wrapClientTLS(tls_options, ep)
                 return ep
@@ -210,11 +215,7 @@ class MatrixFederationAgent(object):
                 target_port=parsed_uri.port,
             ))
 
-        # try a SRV lookup
-        service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
-        server_list = yield self._srv_resolver.resolve_service(service_name)
-
-        if not server_list and lookup_well_known:
+        if lookup_well_known:
             # try a .well-known lookup
             well_known_server = yield self._get_well_known(parsed_uri.host)
 
@@ -250,6 +251,10 @@ class MatrixFederationAgent(object):
                 res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
                 defer.returnValue(res)
 
+        # try a SRV lookup
+        service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+        server_list = yield self._srv_resolver.resolve_service(service_name)
+
         if not server_list:
             target_host = parsed_uri.host
             port = 8448
@@ -283,14 +288,32 @@ class MatrixFederationAgent(object):
                 None if there was no .well-known file.
         """
         try:
-            cached = self._well_known_cache[server_name]
-            defer.returnValue(cached)
+            result = self._well_known_cache[server_name]
         except KeyError:
-            pass
+            # TODO: should we linearise so that we don't end up doing two .well-known
+            # requests for the same server in parallel?
+            with Measure(self._clock, "get_well_known"):
+                result, cache_period = yield self._do_get_well_known(server_name)
+
+            if cache_period > 0:
+                self._well_known_cache.set(server_name, result, cache_period)
 
-        # TODO: should we linearise so that we don't end up doing two .well-known requests
-        # for the same server in parallel?
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _do_get_well_known(self, server_name):
+        """Actually fetch and parse a .well-known, without checking the cache
 
+        Args:
+            server_name (bytes): name of the server, from the requested url
+
+        Returns:
+            Deferred[Tuple[bytes|None|object],int]:
+                result, cache period, where result is one of:
+                 - the new server name from the .well-known (as a `bytes`)
+                 - None if there was no .well-known file.
+                 - INVALID_WELL_KNOWN if the .well-known was invalid
+        """
         uri = b"https://%s/.well-known/matrix/server" % (server_name, )
         uri_str = uri.decode("ascii")
         logger.info("Fetching %s", uri_str)
@@ -301,18 +324,7 @@ class MatrixFederationAgent(object):
             body = yield make_deferred_yieldable(readBody(response))
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code, ))
-        except Exception as e:
-            logger.info("Error fetching %s: %s", uri_str, e)
-
-            # add some randomness to the TTL to avoid a stampeding herd every hour
-            # after startup
-            cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
-            cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
 
-            self._well_known_cache.set(server_name, None, cache_period)
-            defer.returnValue(None)
-
-        try:
             parsed_body = json.loads(body.decode('utf-8'))
             logger.info("Response from .well-known: %s", parsed_body)
             if not isinstance(parsed_body, dict):
@@ -320,7 +332,13 @@ class MatrixFederationAgent(object):
             if "m.server" not in parsed_body:
                 raise Exception("Missing key 'm.server'")
         except Exception as e:
-            raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,))
+            logger.info("Error fetching %s: %s", uri_str, e)
+
+            # add some randomness to the TTL to avoid a stampeding herd every hour
+            # after startup
+            cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+            cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+            defer.returnValue((None, cache_period))
 
         result = parsed_body["m.server"].encode("ascii")
 
@@ -336,10 +354,20 @@ class MatrixFederationAgent(object):
         else:
             cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
 
-        if cache_period > 0:
-            self._well_known_cache.set(server_name, result, cache_period)
+        defer.returnValue((result, cache_period))
 
-        defer.returnValue(result)
+
+@implementer(IStreamClientEndpoint)
+class LoggingHostnameEndpoint(object):
+    """A wrapper for HostnameEndpint which logs when it connects"""
+    def __init__(self, reactor, host, port, *args, **kwargs):
+        self.host = host
+        self.port = port
+        self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+
+    def connect(self, protocol_factory):
+        logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
+        return self.ep.connect(protocol_factory)
 
 
 def _cache_period_from_headers(headers, time_now=time.time):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c3ab7db7ae..d14a7b2538 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -428,13 +428,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         # for now we do this by looking at the create event. We may want to cache this
         # more intelligently in future.
-        state_ids = yield self.get_current_state_ids(room_id)
-        create_id = state_ids.get((EventTypes.Create, ""))
-
-        if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
 
-        create_event = yield self.get_event(create_id)
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
         defer.returnValue(create_event.content.get("room_version", "1"))
 
     @defer.inlineCallbacks
@@ -447,19 +443,39 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         Returns:
             Deferred[unicode|None]: predecessor room id
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+
+        # Return predecessor if present
+        defer.returnValue(create_event.content.get("predecessor", None))
+
+    @defer.inlineCallbacks
+    def get_create_event_for_room(self, room_id):
+        """Get the create state event for a room.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[EventBase]: The room creation event.
+
+        Raises:
+            NotFoundError if the room is unknown
         """
         state_ids = yield self.get_current_state_ids(room_id)
         create_id = state_ids.get((EventTypes.Create, ""))
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            defer.returnValue(None)
+            raise NotFoundError("Unknown room %s" % (room_id))
 
-        # Retrieve the room's create event
+        # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
-
-        # Return predecessor if present
-        defer.returnValue(create_event.content.get("predecessor", None))
+        defer.returnValue(create_event)
 
     @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):