diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index dcda40863f..98a50f0948 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -583,12 +583,15 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
- user_id = yield self._get_appservice_user_id(request.args)
+ user_id = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id))
- access_token = request.args["access_token"][0]
+ access_token = get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
+
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
@@ -629,17 +632,19 @@ class Auth(object):
)
@defer.inlineCallbacks
- def _get_appservice_user_id(self, request_args):
+ def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token(
- request_args["access_token"][0]
+ get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
)
if app_service is None:
defer.returnValue(None)
- if "user_id" not in request_args:
+ if "user_id" not in request.args:
defer.returnValue(app_service.sender)
- user_id = request_args["user_id"][0]
+ user_id = request.args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
@@ -833,7 +838,9 @@ class Auth(object):
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
try:
- token = request.args["access_token"][0]
+ token = get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
service = yield self.store.get_app_service_by_token(token)
if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
@@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to"
" edit its room list entry"
)
+
+
+def has_access_token(request):
+ """Checks if the request has an access_token.
+
+ Returns:
+ bool: False if no access_token was given, True otherwise.
+ """
+ query_params = request.args.get("access_token")
+ return bool(query_params)
+
+
+def get_access_token_from_request(request, token_not_found_http_status=401):
+ """Extracts the access_token from the request.
+
+ Args:
+ request: The http request.
+ token_not_found_http_status(int): The HTTP status code to set in the
+ AuthError if the token isn't found. This is used in some of the
+ legacy APIs to change the status code to 403 from the default of
+ 401 since some of the old clients depended on auth errors returning
+ 403.
+ Returns:
+ str: The access_token
+ Raises:
+ AuthError: If there isn't an access_token in the request.
+ """
+ query_params = request.args.get("access_token")
+ # Try to get the access_token from the query params.
+ if not query_params:
+ raise AuthError(
+ token_not_found_http_status,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
+
+ return query_params[0]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 78719eed25..3395c9e41e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -122,8 +122,12 @@ class FederationClient(FederationBase):
pdu.event_id
)
+ def send_presence(self, destination, states):
+ if destination != self.server_name:
+ self._transaction_queue.enqueue_presence(destination, states)
+
@log_function
- def send_edu(self, destination, edu_type, content):
+ def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
origin=self.server_name,
destination=destination,
@@ -134,7 +138,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc()
# TODO, add errback, etc.
- self._transaction_queue.enqueue_edu(edu)
+ self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None)
@log_function
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 1ac569b305..f8ca93e4c3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
+from synapse.handlers.presence import format_user_presence_state
import synapse.metrics
import logging
@@ -69,13 +70,21 @@ class TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = edus = {}
+ # Presence needs to be separate as we send single aggragate EDUs
+ self.pending_presence_by_dest = presence = {}
+ self.pending_edus_keyed_by_dest = edus_keyed = {}
+
metrics.register_callback(
"pending_pdus",
lambda: sum(map(len, pdus.values())),
)
metrics.register_callback(
"pending_edus",
- lambda: sum(map(len, edus.values())),
+ lambda: (
+ sum(map(len, edus.values()))
+ + sum(map(len, presence.values()))
+ + sum(map(len, edus_keyed.values()))
+ ),
)
# destination -> list of tuple(failure, deferred)
@@ -130,13 +139,27 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_edu(self, edu):
+ def enqueue_presence(self, destination, states):
+ self.pending_presence_by_dest.setdefault(destination, {}).update({
+ state.user_id: state for state in states
+ })
+
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
+
+ def enqueue_edu(self, edu, key=None):
destination = edu.destination
if not self.can_send_to(destination):
return
- self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+ if key:
+ self.pending_edus_keyed_by_dest.setdefault(
+ destination, {}
+ )[(edu.edu_type, key)] = edu
+ else:
+ self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn(
self._attempt_new_transaction, destination
@@ -190,8 +213,13 @@ class TransactionQueue(object):
while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ pending_edus.extend(
+ self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+ )
+
limiter = yield get_retry_limiter(
destination,
self.clock,
@@ -203,6 +231,22 @@ class TransactionQueue(object):
)
pending_edus.extend(device_message_edus)
+ if pending_presence:
+ pending_edus.append(
+ Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ format_user_presence_state(
+ presence, self.clock.time_msec()
+ )
+ for presence in pending_presence.values()
+ ]
+ },
+ )
+ )
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 16dbddee03..a949e39bda 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -625,18 +625,8 @@ class PresenceHandler(object):
Args:
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
"""
- now = self.clock.time_msec()
for host, states in hosts_to_states.items():
- self.federation.send_edu(
- destination=host,
- edu_type="m.presence",
- content={
- "push": [
- _format_user_presence_state(state, now)
- for state in states
- ]
- }
- )
+ self.federation.send_presence(host, states)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
@@ -723,13 +713,13 @@ class PresenceHandler(object):
defer.returnValue([
{
"type": "m.presence",
- "content": _format_user_presence_state(state, now),
+ "content": format_user_presence_state(state, now),
}
for state in updates
])
else:
defer.returnValue([
- _format_user_presence_state(state, now) for state in updates
+ format_user_presence_state(state, now) for state in updates
])
@defer.inlineCallbacks
@@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
return False
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
"""Convert UserPresenceState to a format that can be sent down to clients
and to other servers.
"""
@@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
defer.returnValue(([
{
"type": "m.presence",
- "content": _format_user_presence_state(s, now),
+ "content": format_user_presence_state(s, now),
}
for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
}
},
},
+ key=(room_id, receipt_type, user_id),
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b687957dd..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
"user_id": user_id,
"typing": typing,
},
+ key=(room_id, user_id),
))
yield preserve_context_over_deferred(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6c0eec8fb3..345018a8fc 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
- request.redirect("%s?%s" % (self.cas_server_url, service_param))
+ request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)
@@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
def parse_cas_response(self, cas_response_body):
user = None
- attributes = None
+ attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
@@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
- attributes = {}
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
@@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
- if attributes is None:
- raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 9bff02ee4e..1358d0acab 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from synapse.api.errors import AuthError, Codes
+from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- try:
- access_token = request.args["access_token"][0]
- except KeyError:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 71d58c8e8d..3046da7aec 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
+from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
@@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- if "access_token" not in request.args:
- raise SynapseError(400, "Expected application service token.")
+ as_token = get_access_token_from_request(request)
+
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
- as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- if "access_token" not in request.args:
- raise SynapseError(400, "Expected application service token.")
-
+ access_token = get_access_token_from_request(request)
app_service = yield self.store.get_app_service_by_token(
- request.args["access_token"][0]
+ access_token
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index bdccf464a5..2f2c9d0881 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from synapse.api.auth import get_access_token_from_request
+
logger = logging.getLogger(__name__)
@@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response
def _get_key(self, request):
- token = request.args["access_token"][0]
+ token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2121bd75ea..68d18a9b82 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
- if 'access_token' in request.args:
+ if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
+ access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration(
- desired_username, request.args["access_token"][0], body
+ desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index dca615927a..31f94bc6e9 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -80,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
- del fields["access_token"]
+ fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
@@ -104,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
- del fields["access_token"]
+ fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 10e9305f7b..a87d90741a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -353,12 +353,14 @@ class EventPushActionsStore(SQLBaseStore):
before_clause += " "
before_clause += "AND epa.highlight = 1"
+ # NB. This assumes event_ids are globally unique since
+ # it makes the query easier to index
sql = (
"SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering,"
" epa.actions, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e"
- " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
+ " WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?"
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql
new file mode 100644
index 0000000000..4fc32c351a
--- /dev/null
+++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ CREATE INDEX event_push_actions_user_id_highlight_stream_ordering on event_push_actions(
+ user_id, highlight, stream_ordering
+ );
|