summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-08-08 11:24:53 +0100
committerErik Johnston <erik@matrix.org>2016-08-08 11:24:53 +0100
commit34101427417f58418a9587adfeca3d41898b5e24 (patch)
treef1cb2f39d3838b1dcf709c3f3921dbe82afdd3a5
parentUpdate changelog (diff)
parentMerge pull request #992 from matrix-org/erikj/psutil_conditional (diff)
downloadsynapse-34101427417f58418a9587adfeca3d41898b5e24.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.17.0
-rw-r--r--scripts-dev/federation_client.py1
-rw-r--r--synapse/app/federation_reader.py2
-rwxr-xr-xsynapse/app/homeserver.py2
-rw-r--r--synapse/app/pusher.py2
-rw-r--r--synapse/app/synchrotron.py2
-rw-r--r--synapse/federation/federation_client.py35
-rw-r--r--synapse/federation/transport/server.py113
-rw-r--r--synapse/metrics/__init__.py13
-rw-r--r--synapse/metrics/metric.py5
-rw-r--r--synapse/python_dependencies.py4
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py3
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/util/versionstring.py8
13 files changed, 124 insertions, 68 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 59c3dce3d7..d1ab42d3af 100644
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -143,6 +143,7 @@ def main():
     )
 
     json.dump(result, sys.stdout)
+    print ""
 
 if __name__ == "__main__":
     main()
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 58d425f9ac..7355499ae2 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -165,7 +165,7 @@ def start(config_options):
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
         config=config,
-        version_string=get_version_string("Synapse", synapse),
+        version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
     )
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index fe68ceb07c..40e6f65236 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -285,7 +285,7 @@ def setup(config_options):
     # check any extra requirements we have now we have a config
     check_requirements(config)
 
-    version_string = get_version_string("Synapse", synapse)
+    version_string = "Synapse/" + get_version_string(synapse)
 
     logger.info("Server hostname: %s", config.server_name)
     logger.info("Server version: %s", version_string)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 4f1d18ab5f..c8dde0fcb8 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -273,7 +273,7 @@ def start(config_options):
         config.server_name,
         db_config=config.database_config,
         config=config,
-        version_string=get_version_string("Synapse", synapse),
+        version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
     )
 
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 8cf5bbbb6d..215ccfd522 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -424,7 +424,7 @@ def start(config_options):
         config.server_name,
         db_config=config.database_config,
         config=config,
-        version_string=get_version_string("Synapse", synapse),
+        version_string="Synapse/" + get_version_string(synapse),
         database_engine=database_engine,
         application_service_handler=SynchrotronApplicationService(),
     )
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 65778fd4ee..da95c2ad6d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -236,9 +236,9 @@ class FederationClient(FederationBase):
         # TODO: Rate limit the number of times we try and get the same event.
 
         if self._get_pdu_cache:
-            e = self._get_pdu_cache.get(event_id)
-            if e:
-                defer.returnValue(e)
+            ev = self._get_pdu_cache.get(event_id)
+            if ev:
+                defer.returnValue(ev)
 
         pdu = None
         for destination in destinations:
@@ -269,7 +269,7 @@ class FederationClient(FederationBase):
 
                         break
 
-            except SynapseError:
+            except SynapseError as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -336,8 +336,10 @@ class FederationClient(FederationBase):
                 ev.event_id: ev for ev in fetched_events
             }
 
-            pdus = [event_map[e_id] for e_id in state_event_ids]
-            auth_chain = [event_map[e_id] for e_id in auth_event_ids]
+            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+            auth_chain = [
+                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+            ]
 
             auth_chain.sort(key=lambda e: e.depth)
 
@@ -523,14 +525,19 @@ class FederationClient(FederationBase):
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.warn(
+                        "Failed to make_%s via %s: %s",
+                        membership, destination, e.message
+                    )
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
-                raise
 
         raise RuntimeError("Failed to send to any server.")
 
@@ -602,8 +609,14 @@ class FederationClient(FederationBase):
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException:
-                raise
+            except CodeMessageException as e:
+                if not 500 <= e.code < 600:
+                    raise
+                else:
+                    logger.exception(
+                        "Failed to send_join via %s: %s",
+                        destination, e.message
+                    )
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 0bc6e0801d..37c0d4fbc4 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,13 +18,14 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request, parse_string
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string
 
 import functools
 import logging
-import simplejson as json
 import re
+import synapse
 
 
 logger = logging.getLogger(__name__)
@@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
         )
 
 
+class AuthenticationError(SynapseError):
+    """There was a problem authenticating the request"""
+    pass
+
+
+class NoAuthenticationError(AuthenticationError):
+    """The request had no authentication information"""
+    pass
+
+
 class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
@@ -67,7 +78,7 @@ class Authenticator(object):
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def authenticate_request(self, request):
+    def authenticate_request(self, request, content):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -75,17 +86,10 @@ class Authenticator(object):
             "signatures": {},
         }
 
-        content = None
-        origin = None
+        if content is not None:
+            json_request["content"] = content
 
-        if request.method in ["PUT", "POST"]:
-            # TODO: Handle other method types? other content types?
-            try:
-                content_bytes = request.content.read()
-                content = json.loads(content_bytes)
-                json_request["content"] = content
-            except:
-                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+        origin = None
 
         def parse_auth_header(header_str):
             try:
@@ -103,14 +107,14 @@ class Authenticator(object):
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
             except:
-                raise SynapseError(
+                raise AuthenticationError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )
 
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
 
         if not auth_headers:
-            raise SynapseError(
+            raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
 
@@ -121,7 +125,7 @@ class Authenticator(object):
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
         if not json_request["signatures"]:
-            raise SynapseError(
+            raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
 
@@ -130,10 +134,12 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
 
-        defer.returnValue((origin, content))
+        defer.returnValue(origin)
 
 
 class BaseFederationServlet(object):
+    REQUIRE_AUTH = True
+
     def __init__(self, handler, authenticator, ratelimiter, server_name,
                  room_list_handler):
         self.handler = handler
@@ -141,29 +147,46 @@ class BaseFederationServlet(object):
         self.ratelimiter = ratelimiter
         self.room_list_handler = room_list_handler
 
-    def _wrap(self, code):
+    def _wrap(self, func):
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter
 
         @defer.inlineCallbacks
-        @functools.wraps(code)
-        def new_code(request, *args, **kwargs):
+        @functools.wraps(func)
+        def new_func(request, *args, **kwargs):
+            content = None
+            if request.method in ["PUT", "POST"]:
+                # TODO: Handle other method types? other content types?
+                content = parse_json_object_from_request(request)
+
             try:
-                (origin, content) = yield authenticator.authenticate_request(request)
+                origin = yield authenticator.authenticate_request(request, content)
+            except NoAuthenticationError:
+                origin = None
+                if self.REQUIRE_AUTH:
+                    logger.exception("authenticate_request failed")
+                    raise
+            except:
+                logger.exception("authenticate_request failed")
+                raise
+
+            if origin:
                 with ratelimiter.ratelimit(origin) as d:
                     yield d
-                    response = yield code(
+                    response = yield func(
                         origin, content, request.args, *args, **kwargs
                     )
-            except:
-                logger.exception("authenticate_request failed")
-                raise
+            else:
+                response = yield func(
+                    origin, content, request.args, *args, **kwargs
+                )
+
             defer.returnValue(response)
 
         # Extra logic that functools.wraps() doesn't finish
-        new_code.__self__ = code.__self__
+        new_func.__self__ = func.__self__
 
-        return new_code
+        return new_func
 
     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -429,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
 class On3pidBindServlet(BaseFederationServlet):
     PATH = "/3pid/onbind"
 
+    REQUIRE_AUTH = False
+
     @defer.inlineCallbacks
-    def on_POST(self, request):
-        content = parse_json_object_from_request(request)
+    def on_POST(self, origin, content, query):
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -453,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
                 raise last_exception
         defer.returnValue((200, {}))
 
-    # Avoid doing remote HS authorization checks which are done by default by
-    # BaseFederationServlet.
-    def _wrap(self, code):
-        return code
-
 
 class OpenIdUserInfo(BaseFederationServlet):
     """
@@ -478,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
 
     PATH = "/openid/userinfo"
 
+    REQUIRE_AUTH = False
+
     @defer.inlineCallbacks
-    def on_GET(self, request):
-        token = parse_string(request, "access_token")
+    def on_GET(self, origin, content, query):
+        token = query.get("access_token", [None])[0]
         if token is None:
             defer.returnValue((401, {
                 "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -497,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
 
         defer.returnValue((200, {"sub": user_id}))
 
-    # Avoid doing remote HS authorization checks which are done by default by
-    # BaseFederationServlet.
-    def _wrap(self, code):
-        return code
-
 
 class PublicRoomList(BaseFederationServlet):
     """
@@ -542,6 +558,20 @@ class PublicRoomList(BaseFederationServlet):
         defer.returnValue((200, data))
 
 
+class FederationVersionServlet(BaseFederationServlet):
+    PATH = "/version"
+
+    REQUIRE_AUTH = False
+
+    def on_GET(self, origin, content, query):
+        return defer.succeed((200, {
+            "server": {
+                "name": "Synapse",
+                "version": get_version_string(synapse)
+            },
+        }))
+
+
 SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
@@ -565,6 +595,7 @@ SERVLET_CLASSES = (
     On3pidBindServlet,
     OpenIdUserInfo,
     PublicRoomList,
+    FederationVersionServlet,
 )
 
 
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index cce3dba47c..76d5998d75 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -68,9 +68,18 @@ class Metrics(object):
 
 
 def register_memory_metrics(hs):
-    metric = MemoryUsageMetric(hs)
+    try:
+        import psutil
+        process = psutil.Process()
+        process.memory_info().rss
+    except (ImportError, AttributeError):
+        logger.warn(
+            "psutil is not installed or incorrect version."
+            " Disabling memory metrics."
+        )
+        return
+    metric = MemoryUsageMetric(hs, psutil)
     all_metrics.append(metric)
-    return metric
 
 
 def get_metrics_for(pkg_name):
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 7becbe0491..e81af29895 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -16,8 +16,6 @@
 
 from itertools import chain
 
-import psutil
-
 
 # TODO(paul): I can't believe Python doesn't have one of these
 def map_concat(func, items):
@@ -167,9 +165,10 @@ class MemoryUsageMetric(object):
     UPDATE_HZ = 2  # number of times to get memory per second
     WINDOW_SIZE_SEC = 30  # the size of the window in seconds
 
-    def __init__(self, hs):
+    def __init__(self, hs, psutil):
         clock = hs.get_clock()
         self.memory_snapshots = []
+
         self.process = psutil.Process()
 
         clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 799d35da5e..86e3d89154 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -36,7 +36,6 @@ REQUIREMENTS = {
     "blist": ["blist"],
     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
-    "psutil>=2.0.0": ["psutil>=2.0.0"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
@@ -52,6 +51,9 @@ CONDITIONAL_REQUIREMENTS = {
     "ldap": {
         "ldap3>=1.0": ["ldap3>=1.0"],
     },
+    "psutil": {
+        "psutil>=2.0.0": ["psutil>=2.0.0"],
+    },
 }
 
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 4060593f7f..bdd0e60c5b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -345,7 +345,8 @@ class PreviewUrlResource(Resource):
                 # lines)
                 text_nodes = (
                     re.sub(r'\s+', '\n', el.text).strip()
-                    for el in cloned_tree.iter() if el.text
+                    for el in cloned_tree.iter()
+                    if el.text and isinstance(el.tag, basestring)  # Removes comments
                 )
                 og['og:description'] = summarize_paragraphs(text_nodes)
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e4dbaa3547..d2feee8dbb 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -350,7 +350,7 @@ class EventsStore(SQLBaseStore):
         )
 
         if not events and not allow_none:
-            raise RuntimeError("Could not find event %s" % (event_id,))
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
 
         defer.returnValue(events[0] if events else None)
 
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4f156cb3b..52086df465 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -21,7 +21,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def get_version_string(name, module):
+def get_version_string(module):
     try:
         null = open(os.devnull, 'w')
         cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -74,11 +74,11 @@ def get_version_string(name, module):
             )
 
             return (
-                "%s/%s (%s)" % (
-                    name, module.__version__, git_version,
+                "%s (%s)" % (
+                    module.__version__, git_version,
                 )
             ).encode("ascii")
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
-    return ("%s/%s" % (name, module.__version__,)).encode("ascii")
+    return module.__version__.encode("ascii")