diff --git a/CHANGES.rst b/CHANGES.rst
index 49673ccce4..23be6c8efa 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,47 @@
+Changes in synapse v0.17.3 (2016-09-09)
+=======================================
+
+This release fixes a major bug that stopped servers from handling rooms with
+over 1000 members.
+
+
+Changes in synapse v0.17.2 (2016-09-08)
+=======================================
+
+This release contains security bug fixes. Please upgrade.
+
+
+No changes since v0.17.2-rc1
+
+
+Changes in synapse v0.17.2-rc1 (2016-09-05)
+===========================================
+
+Features:
+
+* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
+ #1062, #1066)
+
+
+Changes:
+
+* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
+ #1068)
+* Don't notify for online to online presence transitions. (PR #1054)
+* Occasionally persist unpersisted presence updates (PR #1055)
+* Allow application services to have an optional 'url' (PR #1056)
+* Clean up old sent transactions from DB (PR #1059)
+
+
+Bug fixes:
+
+* Fix None check in backfill (PR #1043)
+* Fix membership changes to be idempotent (PR #1067)
+* Fix bug in get_pdu where it would sometimes return events with incorrect
+ signature
+
+
+
Changes in synapse v0.17.1 (2016-08-24)
=======================================
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 43bf78f885..b778cd65c9 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.17.1"
+__version__ = "0.17.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index dcda40863f..98a50f0948 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -583,12 +583,15 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
- user_id = yield self._get_appservice_user_id(request.args)
+ user_id = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id))
- access_token = request.args["access_token"][0]
+ access_token = get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
+
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
@@ -629,17 +632,19 @@ class Auth(object):
)
@defer.inlineCallbacks
- def _get_appservice_user_id(self, request_args):
+ def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token(
- request_args["access_token"][0]
+ get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
)
if app_service is None:
defer.returnValue(None)
- if "user_id" not in request_args:
+ if "user_id" not in request.args:
defer.returnValue(app_service.sender)
- user_id = request_args["user_id"][0]
+ user_id = request.args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
@@ -833,7 +838,9 @@ class Auth(object):
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
try:
- token = request.args["access_token"][0]
+ token = get_access_token_from_request(
+ request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+ )
service = yield self.store.get_app_service_by_token(token)
if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
@@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to"
" edit its room list entry"
)
+
+
+def has_access_token(request):
+ """Checks if the request has an access_token.
+
+ Returns:
+ bool: False if no access_token was given, True otherwise.
+ """
+ query_params = request.args.get("access_token")
+ return bool(query_params)
+
+
+def get_access_token_from_request(request, token_not_found_http_status=401):
+ """Extracts the access_token from the request.
+
+ Args:
+ request: The http request.
+ token_not_found_http_status(int): The HTTP status code to set in the
+ AuthError if the token isn't found. This is used in some of the
+ legacy APIs to change the status code to 403 from the default of
+ 401 since some of the old clients depended on auth errors returning
+ 403.
+ Returns:
+ str: The access_token
+ Raises:
+ AuthError: If there isn't an access_token in the request.
+ """
+ query_params = request.args.get("access_token")
+ # Try to get the access_token from the query params.
+ if not query_params:
+ raise AuthError(
+ token_not_found_http_status,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
+
+ return query_params[0]
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index cc4af23962..b0eb0c6d9d 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
+def _is_valid_3pe_metadata(info):
+ if "instances" not in info:
+ return False
+ if not isinstance(info["instances"], list):
+ return False
+ return True
+
+
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
@@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.quote(protocol)
)
try:
- defer.returnValue((yield self.get_json(uri, {})))
+ info = yield self.get_json(uri, {})
+
+ if not _is_valid_3pe_metadata(info):
+ logger.warning("query_3pe_protocol to %s did not return a"
+ " valid result", uri)
+ defer.returnValue(None)
+
+ defer.returnValue(info)
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
- defer.returnValue({})
+ defer.returnValue(None)
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 627acc6a4f..78719eed25 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -138,6 +138,12 @@ class FederationClient(FederationBase):
return defer.succeed(None)
@log_function
+ def send_device_messages(self, destination):
+ """Sends the device messages in the local database to the remote
+ destination"""
+ self._transaction_queue.enqueue_device_messages(destination)
+
+ @log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5621655098..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
- logger.exception("Failed to handle edu %r", edu_type, e)
+ logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..1ac569b305 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
from twisted.internet import defer
from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
@@ -81,6 +81,8 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
+ self.last_device_stream_id_by_dest = {}
+
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
@@ -155,179 +157,240 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
+ def enqueue_device_messages(self, destination):
+ if destination == self.server_name or destination == "localhost":
+ return
+
+ if not self.can_send_to(destination):
+ return
+
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
+
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
- yield run_on_reactor()
- while True:
- # list of (pending_pdu, deferred, order)
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- logger.debug(
- "TX [%s] Transaction already in progress",
- destination
- )
- return
+ # list of (pending_pdu, deferred, order)
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour of these
+ # requests
+ logger.debug(
+ "TX [%s] Transaction already in progress",
+ destination
+ )
+ return
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ try:
+ self.pending_transactions[destination] = 1
- if pending_pdus:
- logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
+ yield run_on_reactor()
- if not pending_pdus and not pending_edus and not pending_failures:
- logger.debug("TX [%s] Nothing to send", destination)
- return
+ while True:
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- yield self._send_new_transaction(
- destination, pending_pdus, pending_edus, pending_failures
+ limiter = yield get_retry_limiter(
+ destination,
+ self.clock,
+ self.store,
+ )
+
+ device_message_edus, device_stream_id = (
+ yield self._get_new_device_messages(destination)
+ )
+
+ pending_edus.extend(device_message_edus)
+
+ if pending_pdus:
+ logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
+
+ if not pending_pdus and not pending_edus and not pending_failures:
+ logger.debug("TX [%s] Nothing to send", destination)
+ self.last_device_stream_id_by_dest[destination] = (
+ device_stream_id
+ )
+ return
+
+ success = yield self._send_new_transaction(
+ destination, pending_pdus, pending_edus, pending_failures,
+ device_stream_id,
+ should_delete_from_device_stream=bool(device_message_edus),
+ limiter=limiter,
+ )
+ if not success:
+ break
+ except NotRetryingDestination:
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
+
+ @defer.inlineCallbacks
+ def _get_new_device_messages(self, destination):
+ last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+ to_device_stream_id = self.store.get_to_device_stream_token()
+ contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+ destination, last_device_stream_id, to_device_stream_id
+ )
+ edus = [
+ Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type="m.direct_to_device",
+ content=content,
)
+ for content in contents
+ ]
+ defer.returnValue((edus, stream_id))
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures):
+ pending_failures, device_stream_id,
+ should_delete_from_device_stream, limiter):
- # Sort based on the order field
- pending_pdus.sort(key=lambda t: t[1])
- pdus = [x[0] for x in pending_pdus]
- edus = pending_edus
- failures = [x.get_dict() for x in pending_failures]
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[1])
+ pdus = [x[0] for x in pending_pdus]
+ edus = pending_edus
+ failures = [x.get_dict() for x in pending_failures]
- try:
- self.pending_transactions[destination] = 1
+ success = True
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ try:
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- txn_id = str(self._next_txn_id)
+ txn_id = str(self._next_txn_id)
- limiter = yield get_retry_limiter(
- destination,
- self.clock,
- self.store,
- )
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pdus),
+ len(edus),
+ len(failures)
+ )
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
+ logger.debug("TX [%s] Persisting transaction...", destination)
- logger.debug("TX [%s] Persisting transaction...", destination)
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
- transaction = Transaction.create_new(
- origin_server_ts=int(self.clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
- self._next_txn_id += 1
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pdus),
+ len(edus),
+ len(failures),
+ )
- yield self.transaction_actions.prepare_to_send(transaction)
+ with limiter:
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
+ )
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
- logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures),
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
)
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self.clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
- )
-
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Marked as delivered", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
+ if code != 200:
+ for p in pdus:
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
+ success = False
+ else:
+ # Remove the acknowledged device messages from the database
+ if should_delete_from_device_stream:
+ yield self.store.delete_device_msgs_for_remote(
+ destination, device_stream_id
+ )
+ self.last_device_stream_id_by_dest[destination] = device_stream_id
+ except RuntimeError as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
- if code != 200:
- for p in pdus:
- logger.info(
- "Failed to send event %s to %s", p.event_id, destination
- )
- except NotRetryingDestination:
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ success = False
+
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
+ except Exception as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
- for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ success = False
- for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
+ defer.returnValue(success)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b440280b74..88fa0bb2e4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_3pe_protocols(self):
+ def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services()
protocols = {}
+
+ # Collect up all the individual protocol responses out of the ASes
for s in services:
for p in s.protocols:
- protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+ if only_protocol is not None and p != only_protocol:
+ continue
+
+ if p not in protocols:
+ protocols[p] = []
+
+ info = yield self.appservice_api.get_3pe_protocol(s, p)
+
+ if info is not None:
+ protocols[p].append(info)
+
+ def _merge_instances(infos):
+ if not infos:
+ return {}
+
+ # Merge the 'instances' lists of multiple results, but just take
+ # the other fields from the first as they ought to be identical
+ # copy the result so as not to corrupt the cached one
+ combined = dict(infos[0])
+ combined["instances"] = list(combined["instances"])
+
+ for info in infos[1:]:
+ combined["instances"].extend(info["instances"])
+
+ return combined
+
+ for p in protocols.keys():
+ protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
new file mode 100644
index 0000000000..c5368e5df2
--- /dev/null
+++ b/synapse/handlers/devicemessage.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.types import get_domain_from_id
+from synapse.util.stringutils import random_string
+
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceMessageHandler(object):
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.federation = hs.get_replication_layer()
+
+ self.federation.register_edu_handler(
+ "m.direct_to_device", self.on_direct_to_device_edu
+ )
+
+ @defer.inlineCallbacks
+ def on_direct_to_device_edu(self, origin, content):
+ local_messages = {}
+ sender_user_id = content["sender"]
+ if origin != get_domain_from_id(sender_user_id):
+ logger.warn(
+ "Dropping device message from %r with spoofed sender %r",
+ origin, sender_user_id
+ )
+ message_type = content["type"]
+ message_id = content["message_id"]
+ for user_id, by_device in content["messages"].items():
+ messages_by_device = {
+ device_id: {
+ "content": message_content,
+ "type": message_type,
+ "sender": sender_user_id,
+ }
+ for device_id, message_content in by_device.items()
+ }
+ if messages_by_device:
+ local_messages[user_id] = messages_by_device
+
+ stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+ origin, message_id, local_messages
+ )
+
+ self.notifier.on_new_event(
+ "to_device_key", stream_id, users=local_messages.keys()
+ )
+
+ @defer.inlineCallbacks
+ def send_device_message(self, sender_user_id, message_type, messages):
+
+ local_messages = {}
+ remote_messages = {}
+ for user_id, by_device in messages.items():
+ if self.is_mine_id(user_id):
+ messages_by_device = {
+ device_id: {
+ "content": message_content,
+ "type": message_type,
+ "sender": sender_user_id,
+ }
+ for device_id, message_content in by_device.items()
+ }
+ if messages_by_device:
+ local_messages[user_id] = messages_by_device
+ else:
+ destination = get_domain_from_id(user_id)
+ remote_messages.setdefault(destination, {})[user_id] = by_device
+
+ message_id = random_string(16)
+
+ remote_edu_contents = {}
+ for destination, messages in remote_messages.items():
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ }
+
+ stream_id = yield self.store.add_messages_to_device_inbox(
+ local_messages, remote_edu_contents
+ )
+
+ self.notifier.on_new_event(
+ "to_device_key", stream_id, users=local_messages.keys()
+ )
+
+ for destination in remote_messages.keys():
+ # Enqueue a new federation transaction to send the new
+ # device messages to each remote destination.
+ self.federation.send_device_messages(destination)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da9f0da69e..16dbddee03 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -265,6 +265,12 @@ class PresenceHandler(object):
to_notify = {} # Changes we want to notify everyone about
to_federation_ping = {} # These need sending keep-alives
+ # Only bother handling the last presence change for each user
+ new_states_dict = {}
+ for new_state in new_states:
+ new_states_dict[new_state.user_id] = new_state
+ new_state = new_states_dict.values()
+
for new_state in new_states:
user_id = new_state.user_id
@@ -651,6 +657,13 @@ class PresenceHandler(object):
)
continue
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got presence update from %r with bad 'user_id': %r",
+ origin, user_id,
+ )
+ continue
+
presence_state = push.get("presence", None)
if not presence_state:
logger.info(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index bf6b1c1535..8758af4ca1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -444,6 +444,16 @@ class RoomListHandler(BaseHandler):
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
+ def get_remote_public_room_list(self, server_name):
+ res = yield self.hs.get_replication_layer().get_public_rooms(
+ [server_name]
+ )
+
+ if server_name not in res:
+ raise SynapseError(404, "Server not found")
+ defer.returnValue(res[server_name])
+
+ @defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0b530b9034..3b687957dd 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -199,7 +199,14 @@ class TypingHandler(object):
user_id = content["user_id"]
# Check that the string is a valid user id
- UserID.from_string(user_id)
+ user = UserID.from_string(user_id)
+
+ if user.domain != origin:
+ logger.info(
+ "Got typing update from %r with bad 'user_id': %r",
+ origin, user_id,
+ )
+ return
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 1ed9034bcb..857bc9795c 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -181,7 +181,7 @@ class ReplicationResource(Resource):
def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token()
- logger.info("Replicating up to %r", current_token)
+ logger.debug("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams)
@@ -195,7 +195,7 @@ class ReplicationResource(Resource):
yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
- logger.info("Replicated %d rows", writer.total)
+ logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 64d8eb2af1..3bfd5e8213 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,13 +16,18 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_inbox", "stream_id",
+ db_conn, "device_max_stream_id", "stream_id",
+ )
+ self._device_inbox_stream_cache = StreamChangeCache(
+ "DeviceInboxStreamChangeCache",
+ self._device_inbox_id_gen.get_current_token()
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
@@ -38,5 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
stream = result.get("to_device")
if stream:
self._device_inbox_id_gen.advance(int(stream["position"]))
+ for row in stream["rows"]:
+ stream_id = row[0]
+ user_id = row[1]
+ self._device_inbox_stream_cache.entity_has_changed(
+ user_id, stream_id
+ )
return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 9bff02ee4e..1358d0acab 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from synapse.api.errors import AuthError, Codes
+from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- try:
- access_token = request.args["access_token"][0]
- except KeyError:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 71d58c8e8d..3046da7aec 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
+from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
@@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- if "access_token" not in request.args:
- raise SynapseError(400, "Expected application service token.")
+ as_token = get_access_token_from_request(request)
+
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
- as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- if "access_token" not in request.args:
- raise SynapseError(400, "Expected application service token.")
-
+ access_token = get_access_token_from_request(request)
app_service = yield self.store.get_app_service_by_token(
- request.args["access_token"][0]
+ access_token
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0d81757010..22d6a7d31e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -22,8 +22,8 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
-from synapse.events.utils import serialize_event
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.events.utils import serialize_event, format_event_for_client_v2
+from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging
import urllib
@@ -120,6 +120,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ format = parse_string(request, "format", default="content",
+ allowed_values=["content", "event"])
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
@@ -134,7 +136,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
- defer.returnValue((200, data.get_dict()["content"]))
+
+ if format == "event":
+ event = format_event_for_client_v2(data.get_dict())
+ defer.returnValue((200, event))
+ elif format == "content":
+ defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
@@ -295,15 +302,26 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
+ server = parse_string(request, "server", default=None)
+
try:
yield self.auth.get_user_by_req(request)
- except AuthError:
- # This endpoint isn't authed, but its useful to know who's hitting
- # it if they *do* supply an access token
- pass
+ except AuthError as e:
+ # We allow people to not be authed if they're just looking at our
+ # room list, but require auth when we proxy the request.
+ # In both cases we call the auth function, as that has the side
+ # effect of logging who issued this request if an access token was
+ # provided.
+ if server:
+ raise e
+ else:
+ pass
handler = self.hs.get_room_list_handler()
- data = yield handler.get_aggregated_public_room_list()
+ if server:
+ data = yield handler.get_remote_public_room_list(server)
+ else:
+ data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data))
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index bdccf464a5..2f2c9d0881 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from synapse.api.auth import get_access_token_from_request
+
logger = logging.getLogger(__name__)
@@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response
def _get_key(self, request):
- token = request.args["access_token"][0]
+ token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2121bd75ea..68d18a9b82 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
- if 'access_token' in request.args:
+ if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
+ access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration(
- desired_username, request.args["access_token"][0], body
+ desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 9c10a99acf..5975164b37 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -16,10 +16,11 @@
import logging
from twisted.internet import defer
-from synapse.http.servlet import parse_json_object_from_request
from synapse.http import servlet
+from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
self.txns = HttpTransactionStore()
+ self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id):
@@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request)
- # TODO: Prod the notifier to wake up sync streams.
- # TODO: Implement replication for the messages.
- # TODO: Send the messages to remote servers if needed.
-
- local_messages = {}
- for user_id, by_device in content["messages"].items():
- if self.is_mine_id(user_id):
- messages_by_device = {
- device_id: {
- "content": message_content,
- "type": message_type,
- "sender": requester.user.to_string(),
- }
- for device_id, message_content in by_device.items()
- }
- if messages_by_device:
- local_messages[user_id] = messages_by_device
-
- stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
-
- self.notifier.on_new_event(
- "to_device_key", stream_id, users=local_messages.keys()
+ sender_user_id = requester.user.to_string()
+
+ yield self.device_message_handler.send_device_message(
+ sender_user_id, message_type, content["messages"]
)
response = (200, {})
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 4f6f1a7e17..31f94bc6e9 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
defer.returnValue((200, protocols))
+class ThirdPartyProtocolServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyProtocolServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ protocols = yield self.appservice_handler.get_3pe_protocols(
+ only_protocol=protocol,
+ )
+ if protocol in protocols:
+ defer.returnValue((200, protocols[protocol]))
+ else:
+ defer.returnValue((404, {"error": "Unknown protocol"}))
+
+
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=())
@@ -57,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
- del fields["access_token"]
+ fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
@@ -81,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
- del fields["access_token"]
+ fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
@@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
+ ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
index af3246504b..f516f08167 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -35,6 +35,7 @@ from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler
+from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
@@ -100,6 +101,7 @@ class HomeServer(object):
'application_service_api',
'application_service_scheduler',
'application_service_handler',
+ 'device_message_handler',
'notifier',
'distributor',
'client_resource',
@@ -205,6 +207,9 @@ class HomeServer(object):
def build_device_handler(self):
return DeviceHandler(self)
+ def build_device_message_handler(self):
+ return DeviceMessageHandler(self)
+
def build_e2e_keys_handler(self):
return E2eKeysHandler(self)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6c32773f25..a61e83d5de 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -111,7 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_inbox", "stream_id"
+ db_conn, "device_max_stream_id", "stream_id"
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
@@ -182,6 +182,30 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=push_rules_prefill,
)
+ max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+ device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
+ db_conn, "device_inbox",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id
+ )
+ self._device_inbox_stream_cache = StreamChangeCache(
+ "DeviceInboxStreamChangeCache", min_device_inbox_id,
+ prefilled_cache=device_inbox_prefill,
+ )
+ # The federation outbox and the local device inbox uses the same
+ # stream_id generator.
+ device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
+ db_conn, "device_federation_outbox",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ )
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
+ prefilled_cache=device_outbox_prefill,
+ )
+
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 30d0e4c5dc..003f5ba203 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -133,10 +133,12 @@ class BackgroundUpdateStore(SQLBaseStore):
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
- retcols=("update_name",),
+ retcols=("update_name", "depends_on"),
)
+ in_flight = set(update["update_name"] for update in updates)
for update in updates:
- self._background_update_queue.append(update['update_name'])
+ if update["depends_on"] not in in_flight:
+ self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
# no work left to do
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 68116b0394..b729b7106e 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -27,63 +27,170 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore):
@defer.inlineCallbacks
- def add_messages_to_device_inbox(self, messages_by_user_then_device):
- """
+ def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
+ remote_messages_by_destination):
+ """Used to send messages from this server.
+
Args:
- messages_by_user_and_device(dict):
+ sender_user_id(str): The ID of the user sending these messages.
+ local_messages_by_user_and_device(dict):
Dictionary of user_id to device_id to message.
+ remote_messages_by_destination(dict):
+ Dictionary of destination server_name to the EDU JSON to send.
Returns:
A deferred stream_id that resolves when the messages have been
inserted.
"""
- def select_devices_txn(txn, user_id, devices):
- if not devices:
- return []
- sql = (
- "SELECT user_id, device_id FROM devices"
- " WHERE user_id = ? AND device_id IN ("
- + ",".join("?" * len(devices))
- + ")"
+ def add_messages_txn(txn, now_ms, stream_id):
+ # Add the local messages directly to the local inbox.
+ self._add_messages_to_local_device_inbox_txn(
+ txn, stream_id, local_messages_by_user_then_device
)
- # TODO: Maybe this needs to be done in batches if there are
- # too many local devices for a given user.
- args = [user_id] + devices
- txn.execute(sql, args)
- return [tuple(row) for row in txn.fetchall()]
-
- def add_messages_to_device_inbox_txn(txn, stream_id):
- local_users_and_devices = set()
- for user_id, messages_by_device in messages_by_user_then_device.items():
- local_users_and_devices.update(
- select_devices_txn(txn, user_id, messages_by_device.keys())
- )
+ # Add the remote messages to the federation outbox.
+ # We'll send them to a remote server when we next send a
+ # federation transaction to that destination.
sql = (
- "INSERT INTO device_inbox"
- " (user_id, device_id, stream_id, message_json)"
+ "INSERT INTO device_federation_outbox"
+ " (destination, stream_id, queued_ts, messages_json)"
" VALUES (?,?,?,?)"
)
rows = []
- for user_id, messages_by_device in messages_by_user_then_device.items():
- for device_id, message in messages_by_device.items():
- message_json = ujson.dumps(message)
- # Only insert into the local inbox if the device exists on
- # this server
- if (user_id, device_id) in local_users_and_devices:
- rows.append((user_id, device_id, stream_id, message_json))
-
+ for destination, edu in remote_messages_by_destination.items():
+ edu_json = ujson.dumps(edu)
+ rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
+ now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_to_device_inbox",
- add_messages_to_device_inbox_txn,
- stream_id
+ add_messages_txn,
+ now_ms,
+ stream_id,
)
+ for user_id in local_messages_by_user_then_device.keys():
+ self._device_inbox_stream_cache.entity_has_changed(
+ user_id, stream_id
+ )
+ for destination in remote_messages_by_destination.keys():
+ self._device_federation_outbox_stream_cache.entity_has_changed(
+ destination, stream_id
+ )
defer.returnValue(self._device_inbox_id_gen.get_current_token())
+ @defer.inlineCallbacks
+ def add_messages_from_remote_to_device_inbox(
+ self, origin, message_id, local_messages_by_user_then_device
+ ):
+ def add_messages_txn(txn, now_ms, stream_id):
+ # Check if we've already inserted a matching message_id for that
+ # origin. This can happen if the origin doesn't receive our
+ # acknowledgement from the first time we received the message.
+ already_inserted = self._simple_select_one_txn(
+ txn, table="device_federation_inbox",
+ keyvalues={"origin": origin, "message_id": message_id},
+ retcols=("message_id",),
+ allow_none=True,
+ )
+ if already_inserted is not None:
+ return
+
+ # Add an entry for this message_id so that we know we've processed
+ # it.
+ self._simple_insert_txn(
+ txn, table="device_federation_inbox",
+ values={
+ "origin": origin,
+ "message_id": message_id,
+ "received_ts": now_ms,
+ },
+ )
+
+ # Add the messages to the approriate local device inboxes so that
+ # they'll be sent to the devices when they next sync.
+ self._add_messages_to_local_device_inbox_txn(
+ txn, stream_id, local_messages_by_user_then_device
+ )
+
+ with self._device_inbox_id_gen.get_next() as stream_id:
+ now_ms = self.clock.time_msec()
+ yield self.runInteraction(
+ "add_messages_from_remote_to_device_inbox",
+ add_messages_txn,
+ now_ms,
+ stream_id,
+ )
+ for user_id in local_messages_by_user_then_device.keys():
+ self._device_inbox_stream_cache.entity_has_changed(
+ user_id, stream_id
+ )
+
+ def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
+ messages_by_user_then_device):
+ sql = (
+ "UPDATE device_max_stream_id"
+ " SET stream_id = ?"
+ " WHERE stream_id < ?"
+ )
+ txn.execute(sql, (stream_id, stream_id))
+
+ local_by_user_then_device = {}
+ for user_id, messages_by_device in messages_by_user_then_device.items():
+ messages_json_for_user = {}
+ devices = messages_by_device.keys()
+ if len(devices) == 1 and devices[0] == "*":
+ # Handle wildcard device_ids.
+ sql = (
+ "SELECT device_id FROM devices"
+ " WHERE user_id = ?"
+ )
+ txn.execute(sql, (user_id,))
+ message_json = ujson.dumps(messages_by_device["*"])
+ for row in txn.fetchall():
+ # Add the message for all devices for this user on this
+ # server.
+ device = row[0]
+ messages_json_for_user[device] = message_json
+ else:
+ if not devices:
+ continue
+ sql = (
+ "SELECT device_id FROM devices"
+ " WHERE user_id = ? AND device_id IN ("
+ + ",".join("?" * len(devices))
+ + ")"
+ )
+ # TODO: Maybe this needs to be done in batches if there are
+ # too many local devices for a given user.
+ txn.execute(sql, [user_id] + devices)
+ for row in txn.fetchall():
+ # Only insert into the local inbox if the device exists on
+ # this server
+ device = row[0]
+ message_json = ujson.dumps(messages_by_device[device])
+ messages_json_for_user[device] = message_json
+
+ if messages_json_for_user:
+ local_by_user_then_device[user_id] = messages_json_for_user
+
+ if not local_by_user_then_device:
+ return
+
+ sql = (
+ "INSERT INTO device_inbox"
+ " (user_id, device_id, stream_id, message_json)"
+ " VALUES (?,?,?,?)"
+ )
+ rows = []
+ for user_id, messages_by_device in local_by_user_then_device.items():
+ for device_id, message_json in messages_by_device.items():
+ rows.append((user_id, device_id, stream_id, message_json))
+
+ txn.executemany(sql, rows)
+
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
@@ -97,6 +204,12 @@ class DeviceInboxStore(SQLBaseStore):
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_stream_id
+ )
+ if not has_changed:
+ return defer.succeed(([], current_stream_id))
+
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
@@ -182,3 +295,71 @@ class DeviceInboxStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
+
+ def get_new_device_msgs_for_remote(
+ self, destination, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ destination(str): The name of the remote server.
+ last_stream_id(int): The last position of the device message stream
+ that the server sent up to.
+ current_stream_id(int): The current position of the device
+ message stream.
+ Returns:
+ Deferred ([dict], int): List of messages for the device and where
+ in the stream the messages got to.
+ """
+
+ has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
+ destination, last_stream_id
+ )
+ if not has_changed or last_stream_id == current_stream_id:
+ return defer.succeed(([], current_stream_id))
+
+ def get_new_messages_for_remote_destination_txn(txn):
+ sql = (
+ "SELECT stream_id, messages_json FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ destination, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn.fetchall():
+ stream_pos = row[0]
+ messages.append(ujson.loads(row[1]))
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_device_msgs_for_remote",
+ get_new_messages_for_remote_destination_txn,
+ )
+
+ def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+ """Used to delete messages when the remote destination acknowledges
+ their receipt.
+
+ Args:
+ destination(str): The destination server_name
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves when the messages have been deleted.
+ """
+ def delete_messages_for_remote_destination_txn(txn):
+ sql = (
+ "DELETE FROM device_federation_outbox"
+ " WHERE destination = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (destination, up_to_stream_id))
+
+ return self.runInteraction(
+ "delete_device_msgs_for_remote",
+ delete_messages_for_remote_destination_txn
+ )
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a67c886f9a..a87d90741a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -343,7 +343,7 @@ class EventPushActionsStore(SQLBaseStore):
def f(txn):
before_clause = ""
if before:
- before_clause = "AND stream_ordering < ?"
+ before_clause = "AND epa.stream_ordering < ?"
args = [user_id, before, limit]
else:
args = [user_id, limit]
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 6ab10db328..866d64e679 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -402,7 +402,7 @@ class RoomMemberStore(SQLBaseStore):
keyvalues={
"membership": Membership.JOIN,
},
- batch_size=1000,
+ batch_size=500,
desc="_get_joined_users_from_context",
)
diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/schema/delta/35/add_state_index.sql
new file mode 100644
index 0000000000..0fce26345b
--- /dev/null
+++ b/synapse/storage/schema/delta/35/add_state_index.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');
diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/schema/delta/35/device_outbox.sql
new file mode 100644
index 0000000000..17e6c43105
--- /dev/null
+++ b/synapse/storage/schema/delta/35/device_outbox.sql
@@ -0,0 +1,39 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP TABLE IF EXISTS device_federation_outbox;
+CREATE TABLE device_federation_outbox (
+ destination TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
+ queued_ts BIGINT NOT NULL,
+ messages_json TEXT NOT NULL
+);
+
+
+DROP INDEX IF EXISTS device_federation_outbox_destination_id;
+CREATE INDEX device_federation_outbox_destination_id
+ ON device_federation_outbox(destination, stream_id);
+
+
+DROP TABLE IF EXISTS device_federation_inbox;
+CREATE TABLE device_federation_inbox (
+ origin TEXT NOT NULL,
+ message_id TEXT NOT NULL,
+ received_ts BIGINT NOT NULL
+);
+
+DROP INDEX IF EXISTS device_federation_inbox_sender_id;
+CREATE INDEX device_federation_inbox_sender_id
+ ON device_federation_inbox(origin, message_id);
diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/schema/delta/35/device_stream_id.sql
new file mode 100644
index 0000000000..7ab7d942e2
--- /dev/null
+++ b/synapse/storage/schema/delta/35/device_stream_id.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_max_stream_id (
+ stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_max_stream_id (stream_id)
+ SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index fef87834ca..0cff0a0cda 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -48,6 +48,7 @@ class StateStore(SQLBaseStore):
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
def __init__(self, hs):
super(StateStore, self).__init__(hs)
@@ -55,6 +56,10 @@ class StateStore(SQLBaseStore):
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
+ self.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME,
+ self._background_index_state,
+ )
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
@@ -793,3 +798,31 @@ class StateStore(SQLBaseStore):
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
+
+ @defer.inlineCallbacks
+ def _background_index_state(self, progress, batch_size):
+ def reindex_txn(txn):
+ if isinstance(self.database_engine, PostgresEngine):
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute(
+ "DROP INDEX IF EXISTS state_groups_state_id"
+ )
+ else:
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute(
+ "DROP INDEX IF EXISTS state_groups_state_id"
+ )
+
+ yield self.runInteraction(
+ self.STATE_GROUP_INDEX_UPDATE_NAME, reindex_txn
+ )
+
+ yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+ defer.returnValue(1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b2957eef9f..ea1f0f7c33 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -121,6 +121,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth.check_joined_room = check_joined_room
+ self.datastore.get_to_device_stream_token = lambda: 0
+ self.datastore.get_new_device_msgs_for_remote = (
+ lambda *args, **kargs: ([], 0)
+ )
+ self.datastore.delete_device_msgs_for_remote = (
+ lambda *args, **kargs: None
+ )
+
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
|