summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py58
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/transaction_queue.py50
-rw-r--r--synapse/handlers/presence.py20
-rw-r--r--synapse/handlers/receipts.py1
-rw-r--r--synapse/handlers/typing.py1
-rw-r--r--synapse/rest/client/v1/login.py7
-rw-r--r--synapse/rest/client/v1/logout.py10
-rw-r--r--synapse/rest/client/v1/register.py12
-rw-r--r--synapse/rest/client/v1/transactions.py4
-rw-r--r--synapse/rest/client/v2_alpha/register.py6
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py4
-rw-r--r--synapse/storage/event_push_actions.py4
-rw-r--r--synapse/storage/schema/delta/35/event_push_actions_index.sql18
14 files changed, 150 insertions, 53 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index dcda40863f..98a50f0948 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -583,12 +583,15 @@ class Auth(object):
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request.args)
+            user_id = yield self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
                 defer.returnValue(synapse.types.create_requester(user_id))
 
-            access_token = request.args["access_token"][0]
+            access_token = get_access_token_from_request(
+                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+            )
+
             user_info = yield self.get_user_by_access_token(access_token, rights)
             user = user_info["user"]
             token_id = user_info["token_id"]
@@ -629,17 +632,19 @@ class Auth(object):
             )
 
     @defer.inlineCallbacks
-    def _get_appservice_user_id(self, request_args):
+    def _get_appservice_user_id(self, request):
         app_service = yield self.store.get_app_service_by_token(
-            request_args["access_token"][0]
+            get_access_token_from_request(
+                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+            )
         )
         if app_service is None:
             defer.returnValue(None)
 
-        if "user_id" not in request_args:
+        if "user_id" not in request.args:
             defer.returnValue(app_service.sender)
 
-        user_id = request_args["user_id"][0]
+        user_id = request.args["user_id"][0]
         if app_service.sender == user_id:
             defer.returnValue(app_service.sender)
 
@@ -833,7 +838,9 @@ class Auth(object):
     @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
-            token = request.args["access_token"][0]
+            token = get_access_token_from_request(
+                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+            )
             service = yield self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))
@@ -1142,3 +1149,40 @@ class Auth(object):
                 "This server requires you to be a moderator in the room to"
                 " edit its room list entry"
             )
+
+
+def has_access_token(request):
+    """Checks if the request has an access_token.
+
+    Returns:
+        bool: False if no access_token was given, True otherwise.
+    """
+    query_params = request.args.get("access_token")
+    return bool(query_params)
+
+
+def get_access_token_from_request(request, token_not_found_http_status=401):
+    """Extracts the access_token from the request.
+
+    Args:
+        request: The http request.
+        token_not_found_http_status(int): The HTTP status code to set in the
+            AuthError if the token isn't found. This is used in some of the
+            legacy APIs to change the status code to 403 from the default of
+            401 since some of the old clients depended on auth errors returning
+            403.
+    Returns:
+        str: The access_token
+    Raises:
+        AuthError: If there isn't an access_token in the request.
+    """
+    query_params = request.args.get("access_token")
+    # Try to get the access_token from the query params.
+    if not query_params:
+        raise AuthError(
+            token_not_found_http_status,
+            "Missing access token.",
+            errcode=Codes.MISSING_TOKEN
+        )
+
+    return query_params[0]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 78719eed25..3395c9e41e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -122,8 +122,12 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
 
+    def send_presence(self, destination, states):
+        if destination != self.server_name:
+            self._transaction_queue.enqueue_presence(destination, states)
+
     @log_function
-    def send_edu(self, destination, edu_type, content):
+    def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -134,7 +138,7 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
 
         # TODO, add errback, etc.
-        self._transaction_queue.enqueue_edu(edu)
+        self._transaction_queue.enqueue_edu(edu, key=key)
         return defer.succeed(None)
 
     @log_function
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 1ac569b305..f8ca93e4c3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
 import logging
@@ -69,13 +70,21 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
+        # Presence needs to be separate as we send single aggragate EDUs
+        self.pending_presence_by_dest = presence = {}
+        self.pending_edus_keyed_by_dest = edus_keyed = {}
+
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: sum(map(len, edus.values())),
+            lambda: (
+                sum(map(len, edus.values()))
+                + sum(map(len, presence.values()))
+                + sum(map(len, edus_keyed.values()))
+            ),
         )
 
         # destination -> list of tuple(failure, deferred)
@@ -130,13 +139,27 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_edu(self, edu):
+    def enqueue_presence(self, destination, states):
+        self.pending_presence_by_dest.setdefault(destination, {}).update({
+            state.user_id: state for state in states
+        })
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
+    def enqueue_edu(self, edu, key=None):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+        if key:
+            self.pending_edus_keyed_by_dest.setdefault(
+                destination, {}
+            )[(edu.edu_type, key)] = edu
+        else:
+            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
         preserve_context_over_fn(
             self._attempt_new_transaction, destination
@@ -190,8 +213,13 @@ class TransactionQueue(object):
             while True:
                     pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                     pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
                     pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
+                    pending_edus.extend(
+                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                    )
+
                     limiter = yield get_retry_limiter(
                         destination,
                         self.clock,
@@ -203,6 +231,22 @@ class TransactionQueue(object):
                     )
 
                     pending_edus.extend(device_message_edus)
+                    if pending_presence:
+                        pending_edus.append(
+                            Edu(
+                                origin=self.server_name,
+                                destination=destination,
+                                edu_type="m.presence",
+                                content={
+                                    "push": [
+                                        format_user_presence_state(
+                                            presence, self.clock.time_msec()
+                                        )
+                                        for presence in pending_presence.values()
+                                    ]
+                                },
+                            )
+                        )
 
                     if pending_pdus:
                         logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 16dbddee03..a949e39bda 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -625,18 +625,8 @@ class PresenceHandler(object):
         Args:
             hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
         """
-        now = self.clock.time_msec()
         for host, states in hosts_to_states.items():
-            self.federation.send_edu(
-                destination=host,
-                edu_type="m.presence",
-                content={
-                    "push": [
-                        _format_user_presence_state(state, now)
-                        for state in states
-                    ]
-                }
-            )
+            self.federation.send_presence(host, states)
 
     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -723,13 +713,13 @@ class PresenceHandler(object):
             defer.returnValue([
                 {
                     "type": "m.presence",
-                    "content": _format_user_presence_state(state, now),
+                    "content": format_user_presence_state(state, now),
                 }
                 for state in updates
             ])
         else:
             defer.returnValue([
-                _format_user_presence_state(state, now) for state in updates
+                format_user_presence_state(state, now) for state in updates
             ])
 
     @defer.inlineCallbacks
@@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
     return False
 
 
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
     """
@@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
         defer.returnValue(([
             {
                 "type": "m.presence",
-                "content": _format_user_presence_state(s, now),
+                "content": format_user_presence_state(s, now),
             }
             for s in updates.values()
             if include_offline or s.state != PresenceState.OFFLINE
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
                             }
                         },
                     },
+                    key=(room_id, receipt_type, user_id),
                 )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b687957dd..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
                         "user_id": user_id,
                         "typing": typing,
                     },
+                    key=(room_id, user_id),
                 ))
 
         yield preserve_context_over_deferred(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6c0eec8fb3..345018a8fc 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
         service_param = urllib.urlencode({
             "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
         })
-        request.redirect("%s?%s" % (self.cas_server_url, service_param))
+        request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
         finish_request(request)
 
 
@@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
 
     def parse_cas_response(self, cas_response_body):
         user = None
-        attributes = None
+        attributes = {}
         try:
             root = ET.fromstring(cas_response_body)
             if not root.tag.endswith("serviceResponse"):
@@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
                 if child.tag.endswith("user"):
                     user = child.text
                 if child.tag.endswith("attributes"):
-                    attributes = {}
                     for attribute in child:
                         # ElementTree library expands the namespace in
                         # attribute tags to the full URL of the namespace.
@@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
                         attributes[tag] = attribute.text
             if user is None:
                 raise Exception("CAS response does not contain user")
-            if attributes is None:
-                raise Exception("CAS response does not contain attributes")
         except Exception:
             logger.error("Error parsing CAS response", exc_info=1)
             raise LoginError(401, "Invalid CAS response",
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 9bff02ee4e..1358d0acab 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, Codes
+from synapse.api.auth import get_access_token_from_request
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        try:
-            access_token = request.args["access_token"][0]
-        except KeyError:
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
-                errcode=Codes.MISSING_TOKEN
-            )
+        access_token = get_access_token_from_request(request)
         yield self.store.delete_access_token(access_token)
         defer.returnValue((200, {}))
 
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 71d58c8e8d..3046da7aec 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.api.constants import LoginType
+from synapse.api.auth import get_access_token_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
 from synapse.http.servlet import parse_json_object_from_request
@@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def _do_app_service(self, request, register_json, session):
-        if "access_token" not in request.args:
-            raise SynapseError(400, "Expected application service token.")
+        as_token = get_access_token_from_request(request)
+
         if "user" not in register_json:
             raise SynapseError(400, "Expected 'user' key.")
 
-        as_token = request.args["access_token"][0]
         user_localpart = register_json["user"].encode("utf-8")
 
         handler = self.handlers.registration_handler
@@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         user_json = parse_json_object_from_request(request)
 
-        if "access_token" not in request.args:
-            raise SynapseError(400, "Expected application service token.")
-
+        access_token = get_access_token_from_request(request)
         app_service = yield self.store.get_app_service_by_token(
-            request.args["access_token"][0]
+            access_token
         )
         if not app_service:
             raise SynapseError(403, "Invalid application service token.")
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index bdccf464a5..2f2c9d0881 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -17,6 +17,8 @@
 to ensure idempotency when performing PUTs using the REST API."""
 import logging
 
+from synapse.api.auth import get_access_token_from_request
+
 logger = logging.getLogger(__name__)
 
 
@@ -90,6 +92,6 @@ class HttpTransactionStore(object):
         return response
 
     def _get_key(self, request):
-        token = request.args["access_token"][0]
+        token = get_access_token_from_request(request)
         path_without_txn_id = request.path.rsplit("/", 1)[0]
         return path_without_txn_id + "/" + token
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2121bd75ea..68d18a9b82 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -15,6 +15,7 @@
 
 from twisted.internet import defer
 
+from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
             desired_username = body['username']
 
         appservice = None
-        if 'access_token' in request.args:
+        if has_access_token(request):
             appservice = yield self.auth.get_appservice_by_req(request)
 
         # fork off as soon as possible for ASes and shared secret auth which
@@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
             desired_username = body.get("user", desired_username)
+            access_token = get_access_token_from_request(request)
 
             if isinstance(desired_username, basestring):
                 result = yield self._do_appservice_registration(
-                    desired_username, request.args["access_token"][0], body
+                    desired_username, access_token, body
                 )
             defer.returnValue((200, result))  # we throw for non 200 responses
             return
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index dca615927a..31f94bc6e9 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -80,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
         yield self.auth.get_user_by_req(request)
 
         fields = request.args
-        del fields["access_token"]
+        fields.pop("access_token", None)
 
         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.USER, protocol, fields
@@ -104,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
         yield self.auth.get_user_by_req(request)
 
         fields = request.args
-        del fields["access_token"]
+        fields.pop("access_token", None)
 
         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.LOCATION, protocol, fields
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 10e9305f7b..a87d90741a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -353,12 +353,14 @@ class EventPushActionsStore(SQLBaseStore):
                     before_clause += " "
                 before_clause += "AND epa.highlight = 1"
 
+            # NB. This assumes event_ids are globally unique since
+            # it makes the query easier to index
             sql = (
                 "SELECT epa.event_id, epa.room_id,"
                 " epa.stream_ordering, epa.topological_ordering,"
                 " epa.actions, epa.profile_tag, e.received_ts"
                 " FROM event_push_actions epa, events e"
-                " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
+                " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
                 " ORDER BY epa.stream_ordering DESC"
                 " LIMIT ?"
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql
new file mode 100644
index 0000000000..4fc32c351a
--- /dev/null
+++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ CREATE INDEX event_push_actions_user_id_highlight_stream_ordering on event_push_actions(
+     user_id, highlight, stream_ordering
+ );