summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4496.misc1
-rw-r--r--changelog.d/4506.misc1
-rw-r--r--synapse/federation/federation_client.py64
-rw-r--r--synapse/federation/transport/client.py39
-rw-r--r--tests/test_utils/__init__.py18
-rw-r--r--tests/test_utils/logging_setup.py54
-rw-r--r--tests/unittest.py37
7 files changed, 170 insertions, 44 deletions
diff --git a/changelog.d/4496.misc b/changelog.d/4496.misc
new file mode 100644
index 0000000000..43f8963614
--- /dev/null
+++ b/changelog.d/4496.misc
@@ -0,0 +1 @@
+Add infrastructure to support different event formats
diff --git a/changelog.d/4506.misc b/changelog.d/4506.misc
new file mode 100644
index 0000000000..ea0e7d9580
--- /dev/null
+++ b/changelog.d/4506.misc
@@ -0,0 +1 @@
+Make it possible to set the log level for tests via an environment variable
\ No newline at end of file
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 71809893c5..cacb1c8aaf 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -751,18 +751,9 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
-        time_now = self._clock.time_msec()
-        try:
-            code, content = yield self.transport_layer.send_invite(
-                destination=destination,
-                room_id=room_id,
-                event_id=event_id,
-                content=pdu.get_pdu_json(time_now),
-            )
-        except HttpResponseException as e:
-            if e.code == 403:
-                raise e.to_synapse_error()
-            raise
+        room_version = yield self.store.get_room_version(room_id)
+
+        content = yield self._do_send_invite(destination, pdu, room_version)
 
         pdu_dict = content["event"]
 
@@ -780,6 +771,55 @@ class FederationClient(FederationBase):
 
         defer.returnValue(pdu)
 
+    @defer.inlineCallbacks
+    def _do_send_invite(self, destination, pdu, room_version):
+        """Actually sends the invite, first trying v2 API and falling back to
+        v1 API if necessary.
+
+        Args:
+            destination (str): Target server
+            pdu (FrozenEvent)
+            room_version (str)
+
+        Returns:
+            dict: The event as a dict as returned by the remote server
+        """
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_invite_v2(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content={
+                    "event": pdu.get_pdu_json(time_now),
+                    "room_version": room_version,
+                    "invite_room_state": pdu.unsigned.get("invite_room_state", []),
+                },
+            )
+            defer.returnValue(content)
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                if room_version in (RoomVersions.V1, RoomVersions.V2):
+                    pass  # We'll fall through
+                else:
+                    raise Exception("Remote server is too old")
+            elif e.code == 403:
+                raise e.to_synapse_error()
+            else:
+                raise
+
+        # Didn't work, try v1 API.
+        # Note the v1 API returns a tuple of `(200, content)`
+
+        _, content = yield self.transport_layer.send_invite_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+        defer.returnValue(content)
+
     def send_leave(self, destinations, pdu):
         """Sends a leave event to one of a list of homeservers.
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 260178c47b..8e2be218e2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
 from synapse.util.logutils import log_function
 
 logger = logging.getLogger(__name__)
@@ -289,7 +289,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_invite(self, destination, room_id, event_id, content):
+    def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -303,6 +303,20 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def send_invite_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/invite/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
     def get_public_rooms(self, remote_server, limit, since_token,
                          search_filter=None, include_all_networks=False,
                          third_party_instance_id=None):
@@ -958,3 +972,24 @@ def _create_v1_path(path, *args):
         FEDERATION_V1_PREFIX
         + path % tuple(urllib.parse.quote(arg, "") for arg in args)
     )
+
+
+def _create_v2_path(path, *args):
+    """Creates a path against V2 federation API from the path template and
+    args. Ensures that all args are url encoded.
+
+    Example:
+
+        _create_v2_path("/event/%s/", event_id)
+
+    Args:
+        path (str): String template for the path
+        args: ([str]): Args to insert into path. Each arg will be url encoded
+
+    Returns:
+        str
+    """
+    return (
+        FEDERATION_V2_PREFIX
+        + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    )
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
new file mode 100644
index 0000000000..a7310cf12a
--- /dev/null
+++ b/tests/test_utils/__init__.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for running the unit tests
+"""
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
new file mode 100644
index 0000000000..d0bc8e2112
--- /dev/null
+++ b/tests/test_utils/logging_setup.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+
+import twisted.logger
+
+from synapse.util.logcontext import LoggingContextFilter
+
+
+class ToTwistedHandler(logging.Handler):
+    """logging handler which sends the logs to the twisted log"""
+    tx_log = twisted.logger.Logger()
+
+    def emit(self, record):
+        log_entry = self.format(record)
+        log_level = record.levelname.lower().replace('warning', 'warn')
+        self.tx_log.emit(
+            twisted.logger.LogLevel.levelWithName(log_level),
+            log_entry.replace("{", r"(").replace("}", r")"),
+        )
+
+
+def setup_logging():
+    """Configure the python logging appropriately for the tests.
+
+    (Logs will end up in _trial_temp.)
+    """
+    root_logger = logging.getLogger()
+
+    log_format = (
+        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
+    )
+
+    handler = ToTwistedHandler()
+    formatter = logging.Formatter(log_format)
+    handler.setFormatter(formatter)
+    handler.addFilter(LoggingContextFilter(request=""))
+    root_logger.addHandler(handler)
+
+    log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
+    root_logger.setLevel(log_level)
diff --git a/tests/unittest.py b/tests/unittest.py
index cda549c783..fac254ff10 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -31,38 +31,14 @@ from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContext, LoggingContextFilter
+from synapse.util.logcontext import LoggingContext
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
 setupdb()
-
-# Set up putting Synapse's logs into Trial's.
-rootLogger = logging.getLogger()
-
-log_format = (
-    "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
-)
-
-
-class ToTwistedHandler(logging.Handler):
-    tx_log = twisted.logger.Logger()
-
-    def emit(self, record):
-        log_entry = self.format(record)
-        log_level = record.levelname.lower().replace('warning', 'warn')
-        self.tx_log.emit(
-            twisted.logger.LogLevel.levelWithName(log_level),
-            log_entry.replace("{", r"(").replace("}", r")"),
-        )
-
-
-handler = ToTwistedHandler()
-formatter = logging.Formatter(log_format)
-handler.setFormatter(formatter)
-handler.addFilter(LoggingContextFilter(request=""))
-rootLogger.addHandler(handler)
+setup_logging()
 
 
 def around(target):
@@ -96,7 +72,7 @@ class TestCase(unittest.TestCase):
 
         method = getattr(self, methodName)
 
-        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
+        level = getattr(method, "loglevel", getattr(self, "loglevel", None))
 
         @around(self)
         def setUp(orig):
@@ -114,7 +90,7 @@ class TestCase(unittest.TestCase):
                 )
 
             old_level = logging.getLogger().level
-            if old_level != level:
+            if level is not None and old_level != level:
 
                 @around(self)
                 def tearDown(orig):
@@ -122,7 +98,8 @@ class TestCase(unittest.TestCase):
                     logging.getLogger().setLevel(old_level)
                     return ret
 
-            logging.getLogger().setLevel(level)
+                logging.getLogger().setLevel(level)
+
             return orig()
 
         @around(self)