diff --git a/synapse/__init__.py b/synapse/__init__.py
index da8ef90a77..f31cb9a3cb 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.19.1"
+__version__ = "0.28.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab1b..f17fda6315 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -39,6 +40,10 @@ AuthEventTypes = (
GUEST_DEVICE_ID = "guest_device"
+class _InvalidMacaroonException(Exception):
+ pass
+
+
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
@@ -51,6 +56,9 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+ self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+ register_cache("token_cache", self.token_cache)
+
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
@@ -144,17 +152,8 @@ class Auth(object):
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-
- logger.debug("calling resolve_state_groups from check_host_in_room")
- entry = yield self.state.resolve_state_groups(
- room_id, latest_event_ids
- )
-
- ret = yield self.store.is_host_joined(
- room_id, host, entry.state_group, entry.state
- )
- defer.returnValue(ret)
+ latest_event_ids = yield self.store.is_host_joined(room_id, host)
+ defer.returnValue(latest_event_ids)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
@@ -205,13 +204,12 @@ class Auth(object):
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- "User-Agent",
- default=[""]
+ b"User-Agent",
+ default=[b""]
)[0]
if user and access_token and ip_addr:
- preserve_context_over_fn(
- self.store.insert_client_ip,
- user=user,
+ self.store.insert_client_ip(
+ user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@@ -272,13 +270,17 @@ class Auth(object):
rights (str): The operation being performed; the access token must
allow this.
Returns:
- dict : dict that includes the user and the ID of their access token.
+ Deferred[dict]: dict that includes:
+ `user` (UserID)
+ `is_guest` (bool)
+ `token_id` (int|None): access token id. May be None if guest
+ `device_id` (str|None): device corresponding to access token
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
try:
- macaroon = pymacaroons.Macaroon.deserialize(token)
- except Exception: # deserialize can throw more-or-less anything
+ user_id, guest = self._parse_and_validate_macaroon(token, rights)
+ except _InvalidMacaroonException:
# doesn't look like a macaroon: treat it as an opaque token which
# must be in the database.
# TODO: it would be nice to get rid of this, but apparently some
@@ -287,19 +289,8 @@ class Auth(object):
defer.returnValue(r)
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
- self.validate_macaroon(
- macaroon, rights, self.hs.config.expire_access_token,
- user_id=user_id,
- )
-
- guest = False
- for caveat in macaroon.caveats:
- if caveat.caveat_id == "guest = true":
- guest = True
-
if guest:
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
@@ -371,6 +362,55 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
+ def _parse_and_validate_macaroon(self, token, rights="access"):
+ """Takes a macaroon and tries to parse and validate it. This is cached
+ if and only if rights == access and there isn't an expiry.
+
+ On invalid macaroon raises _InvalidMacaroonException
+
+ Returns:
+ (user_id, is_guest)
+ """
+ if rights == "access":
+ cached = self.token_cache.get(token, None)
+ if cached:
+ return cached
+
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ except Exception: # deserialize can throw more-or-less anything
+ # doesn't look like a macaroon: treat it as an opaque token which
+ # must be in the database.
+ # TODO: it would be nice to get rid of this, but apparently some
+ # people use access tokens which aren't macaroons
+ raise _InvalidMacaroonException()
+
+ try:
+ user_id = self.get_user_id_from_macaroon(macaroon)
+
+ has_expiry = False
+ guest = False
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith("time "):
+ has_expiry = True
+ elif caveat.caveat_id == "guest = true":
+ guest = True
+
+ self.validate_macaroon(
+ macaroon, rights, self.hs.config.expire_access_token,
+ user_id=user_id,
+ )
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
+ if not has_expiry and rights == "access":
+ self.token_cache[token] = (user_id, guest)
+
+ return user_id, guest
+
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
@@ -483,6 +523,14 @@ class Auth(object):
)
def is_server_admin(self, user):
+ """ Check if the given user is a local server admin.
+
+ Args:
+ user (str): mxid of user to check
+
+ Returns:
+ bool: True if the user is an admin
+ """
return self.store.is_server_admin(user)
@defer.inlineCallbacks
@@ -624,7 +672,7 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -644,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
AuthError: If there isn't an access_token in the request.
"""
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
- query_params = request.args.get("access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
if auth_headers:
        # Try to get the access_token from an "Authorization: Bearer"
# header
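
The new `_parse_and_validate_macaroon` above only caches its `(user_id, is_guest)` result when `rights == "access"` and the macaroon carries no `time ` caveat, so tokens with an expiry are always re-validated. Below is a minimal standalone sketch of that caching rule, using `collections.OrderedDict` in place of Synapse's `LruCache` and a plain list of caveat strings instead of a deserialized pymacaroons object; the cache size and caveat prefixes mirror the diff, while `TokenCache`, `parse_and_validate` and the placeholder user id are purely illustrative.

```python
from collections import OrderedDict

CACHE_SIZE = 10000  # stands in for CACHE_SIZE_FACTOR * 10000 in the diff


class TokenCache(object):
    """Insertion-ordered cache standing in for synapse's LruCache."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._cache = OrderedDict()

    def get(self, key, default=None):
        return self._cache.get(key, default)

    def __setitem__(self, key, value):
        self._cache.pop(key, None)
        self._cache[key] = value
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)  # evict the oldest entry


token_cache = TokenCache(CACHE_SIZE)


def parse_and_validate(token, caveats, rights="access"):
    """Apply the same caching rule as _parse_and_validate_macaroon."""
    if rights == "access":
        cached = token_cache.get(token, None)
        if cached:
            return cached

    # A real implementation would deserialize the macaroon and call
    # validate_macaroon() here; this sketch only inspects caveat strings.
    has_expiry = any(c.startswith("time ") for c in caveats)
    guest = any(c == "guest = true" for c in caveats)
    user_id = "@alice:example.com"  # placeholder for get_user_id_from_macaroon

    # Only cache results which can never go stale: plain access tokens
    # without an expiry caveat.
    if not has_expiry and rights == "access":
        token_cache[token] = (user_id, guest)
    return user_id, guest
```

Calling `parse_and_validate(token, ["guest = true"])` twice serves the second call from the cache, while a token whose caveats include `time < ...` is never cached and is re-checked on every request.
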
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index ca23c9c460..5baba43966 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,9 @@
"""Contains constants from the specification."""
+# the "depth" field on events is limited to 2**63 - 1
+MAX_DEPTH = 2**63 - 1
+
class Membership(object):
@@ -44,6 +48,7 @@ class JoinRules(object):
class LoginType(object):
PASSWORD = u"m.login.password"
EMAIL_IDENTITY = u"m.login.email.identity"
+ MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 921c457738..a9ff5576f3 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,9 @@
import logging
+import simplejson as json
+from six import iteritems
+
logger = logging.getLogger(__name__)
@@ -45,32 +48,52 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+ THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError):
- """An exception with integer code and message string attributes."""
+ """An exception with integer code and message string attributes.
+ Attributes:
+ code (int): HTTP error code
+ msg (str): string describing the error
+ """
def __init__(self, code, msg):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
self.msg = msg
- self.response_code_message = None
def error_dict(self):
return cs_error(self.msg)
+class MatrixCodeMessageException(CodeMessageException):
+    """An error from a general matrix endpoint, e.g. from a proxied Matrix API call.
+
+ Attributes:
+ errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ """
+ def __init__(self, code, msg, errcode=Codes.UNKNOWN):
+ super(MatrixCodeMessageException, self).__init__(code, msg)
+ self.errcode = errcode
+
+
class SynapseError(CodeMessageException):
- """A base error which can be caught for all synapse events."""
+ """A base exception type for matrix errors which have an errcode and error
+ message (as well as an HTTP status code).
+
+ Attributes:
+ errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ """
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
Args:
code (int): The integer error code (an HTTP response code)
msg (str): The human-readable error message.
- err (str): The error code e.g 'M_FORBIDDEN'
+ errcode (str): The matrix error code e.g 'M_FORBIDDEN'
"""
super(SynapseError, self).__init__(code, msg)
self.errcode = errcode
@@ -81,12 +104,87 @@ class SynapseError(CodeMessageException):
self.errcode,
)
+ @classmethod
+ def from_http_response_exception(cls, err):
+ """Make a SynapseError based on an HTTPResponseException
+
+ This is useful when a proxied request has failed, and we need to
+ decide how to map the failure onto a matrix error to send back to the
+ client.
+
+ An attempt is made to parse the body of the http response as a matrix
+ error. If that succeeds, the errcode and error message from the body
+ are used as the errcode and error message in the new synapse error.
+
+ Otherwise, the errcode is set to M_UNKNOWN, and the error message is
+ set to the reason code from the HTTP response.
+
+ Args:
+ err (HttpResponseException):
+
+ Returns:
+ SynapseError:
+ """
+ # try to parse the body as json, to get better errcode/msg, but
+ # default to M_UNKNOWN with the HTTP status as the error text
+ try:
+ j = json.loads(err.response)
+ except ValueError:
+ j = {}
+ errcode = j.get('errcode', Codes.UNKNOWN)
+ errmsg = j.get('error', err.msg)
+
+ res = SynapseError(err.code, errmsg, errcode)
+ return res
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
+class FederationDeniedError(SynapseError):
+ """An error raised when the server tries to federate with a server which
+ is not on its federation whitelist.
+
+ Attributes:
+ destination (str): The destination which has been denied
+ """
+
+ def __init__(self, destination):
+        """Raised by federation client or server to indicate that we are
+        deliberately not attempting to contact a given server because it is
+        not on our federation whitelist.
+
+ Args:
+ destination (str): the domain in question
+ """
+
+ self.destination = destination
+
+ super(FederationDeniedError, self).__init__(
+ code=403,
+ msg="Federation denied with %s." % (self.destination,),
+ errcode=Codes.FORBIDDEN,
+ )
+
+
+class InteractiveAuthIncompleteError(Exception):
+ """An error raised when UI auth is not yet complete
+
+ (This indicates we should return a 401 with 'result' as the body)
+
+ Attributes:
+ result (dict): the server response to the request, which should be
+ passed back to the client
+ """
+ def __init__(self, result):
+ super(InteractiveAuthIncompleteError, self).__init__(
+ "Interactive auth not yet complete",
+ )
+ self.result = result
+
+
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
@@ -106,13 +204,11 @@ class UnrecognizedRequestError(SynapseError):
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
- def __init__(self, *args, **kwargs):
- if "errcode" not in kwargs:
- kwargs["errcode"] = Codes.NOT_FOUND
+ def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
super(NotFoundError, self).__init__(
404,
- "Not found",
- **kwargs
+ msg,
+ errcode=errcode
)
@@ -173,7 +269,6 @@ class LimitExceededError(SynapseError):
errcode=Codes.LIMIT_EXCEEDED):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
- self.response_code_message = "Too Many Requests"
def error_dict(self):
return cs_error(
@@ -203,7 +298,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
A dict representing the error response JSON.
"""
err = {"error": msg, "errcode": code}
- for key, value in kwargs.iteritems():
+ for key, value in iteritems(kwargs):
err[key] = value
return err
@@ -243,6 +338,19 @@ class FederationError(RuntimeError):
class HttpResponseException(CodeMessageException):
+ """
+ Represents an HTTP-level failure of an outbound request
+
+ Attributes:
+ response (str): body of response
+ """
def __init__(self, code, msg, response):
- self.response = response
+ """
+
+ Args:
+ code (int): HTTP status code
+ msg (str): reason phrase from HTTP response status line
+ response (str): body of response
+ """
super(HttpResponseException, self).__init__(code, msg)
+ self.response = response
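
`SynapseError.from_http_response_exception` above decides how a failed proxied request is surfaced: it tries to parse the response body as a Matrix error and falls back to `M_UNKNOWN` with the HTTP reason phrase. Here is a self-contained sketch of just that fallback, with a hypothetical `DummyHttpResponseException` standing in for `HttpResponseException`.

```python
import json

M_UNKNOWN = "M_UNKNOWN"  # Codes.UNKNOWN in synapse.api.errors


class DummyHttpResponseException(Exception):
    """Hypothetical stand-in for HttpResponseException (code, msg, response)."""

    def __init__(self, code, msg, response):
        super(DummyHttpResponseException, self).__init__("%d: %s" % (code, msg))
        self.code = code
        self.msg = msg
        self.response = response


def to_matrix_error(err):
    """Mirror the body-parsing fallback used by from_http_response_exception."""
    try:
        j = json.loads(err.response)
    except ValueError:  # body wasn't JSON
        j = {}
    errcode = j.get("errcode", M_UNKNOWN)
    errmsg = j.get("error", err.msg)
    return err.code, errcode, errmsg


# A remote that returned a proper Matrix error body...
print(to_matrix_error(DummyHttpResponseException(
    403, "Forbidden", '{"errcode": "M_FORBIDDEN", "error": "You are banned"}'
)))
# ...and one that returned an HTML error page instead.
print(to_matrix_error(DummyHttpResponseException(
    502, "Bad Gateway", "<html>proxy error</html>"
)))
```
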
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index fb291d7fb9..db43219d24 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -13,11 +13,174 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
+from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID
-
from twisted.internet import defer
-import ujson as json
+import simplejson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "limit": {
+ "type": "number"
+ },
+ "senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "not_senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ # TODO: We don't limit event type values but we probably should...
+ # check types are valid event types
+ "types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "not_types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+}
+
+ROOM_FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "not_rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "ephemeral": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "include_leave": {
+ "type": "boolean"
+ },
+ "state": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "timeline": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "account_data": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "limit": {
+ "type": "number"
+ },
+ "senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "not_senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "not_types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "not_rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "contains_url": {
+ "type": "boolean"
+ }
+ }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "format": "matrix_user_id"
+ }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "format": "matrix_room_id"
+ }
+}
+
+USER_FILTER_SCHEMA = {
+ "$schema": "http://json-schema.org/draft-04/schema#",
+ "description": "schema for a Sync filter",
+ "type": "object",
+ "definitions": {
+ "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+ "user_id_array": USER_ID_ARRAY_SCHEMA,
+ "filter": FILTER_SCHEMA,
+ "room_filter": ROOM_FILTER_SCHEMA,
+ "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+ },
+ "properties": {
+ "presence": {
+ "$ref": "#/definitions/filter"
+ },
+ "account_data": {
+ "$ref": "#/definitions/filter"
+ },
+ "room": {
+ "$ref": "#/definitions/room_filter"
+ },
+ "event_format": {
+ "type": "string",
+ "enum": ["client", "federation"]
+ },
+ "event_fields": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ # Don't allow '\\' in event field filters. This makes matching
+ # events a lot easier as we can then use a negative lookbehind
+ # assertion to split '\.' If we allowed \\ then it would
+ # incorrectly split '\\.' See synapse.events.utils.serialize_event
+ "pattern": "^((?!\\\).)*$"
+ }
+ }
+ },
+ "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+ return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+ return UserID.from_string(user_id_str)
class Filtering(object):
@@ -52,98 +215,11 @@ class Filtering(object):
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
-
- top_level_definitions = [
- "presence", "account_data"
- ]
-
- room_level_definitions = [
- "state", "timeline", "ephemeral", "account_data"
- ]
-
- for key in top_level_definitions:
- if key in user_filter_json:
- self._check_definition(user_filter_json[key])
-
- if "room" in user_filter_json:
- self._check_definition_room_lists(user_filter_json["room"])
- for key in room_level_definitions:
- if key in user_filter_json["room"]:
- self._check_definition(user_filter_json["room"][key])
-
- if "event_fields" in user_filter_json:
- if type(user_filter_json["event_fields"]) != list:
- raise SynapseError(400, "event_fields must be a list of strings")
- for field in user_filter_json["event_fields"]:
- if not isinstance(field, basestring):
- raise SynapseError(400, "Event field must be a string")
- # Don't allow '\\' in event field filters. This makes matching
- # events a lot easier as we can then use a negative lookbehind
- # assertion to split '\.' If we allowed \\ then it would
- # incorrectly split '\\.' See synapse.events.utils.serialize_event
- if r'\\' in field:
- raise SynapseError(
- 400, r'The escape character \ cannot itself be escaped'
- )
-
- def _check_definition_room_lists(self, definition):
- """Check that "rooms" and "not_rooms" are lists of room ids if they
- are present
-
- Args:
- definition(dict): The filter definition
- Raises:
- SynapseError: If there was a problem with this definition.
- """
- # check rooms are valid room IDs
- room_id_keys = ["rooms", "not_rooms"]
- for key in room_id_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for room_id in definition[key]:
- RoomID.from_string(room_id)
-
- def _check_definition(self, definition):
- """Check if the provided definition is valid.
-
- This inspects not only the types but also the values to make sure they
- make sense.
-
- Args:
- definition(dict): The filter definition
- Raises:
- SynapseError: If there was a problem with this definition.
- """
- # NB: Filters are the complete json blobs. "Definitions" are an
- # individual top-level key e.g. public_user_data. Filters are made of
- # many definitions.
- if type(definition) != dict:
- raise SynapseError(
- 400, "Expected JSON object, not %s" % (definition,)
- )
-
- self._check_definition_room_lists(definition)
-
- # check senders are valid user IDs
- user_id_keys = ["senders", "not_senders"]
- for key in user_id_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for user_id in definition[key]:
- UserID.from_string(user_id)
-
- # TODO: We don't limit event type values but we probably should...
- # check types are valid event types
- event_keys = ["types", "not_types"]
- for key in event_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for event_type in definition[key]:
- if not isinstance(event_type, basestring):
- raise SynapseError(400, "Event type should be a string")
+ try:
+ jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+ format_checker=FormatChecker())
+ except jsonschema.ValidationError as e:
+ raise SynapseError(400, e.message)
class FilterCollection(object):
@@ -253,19 +329,35 @@ class Filter(object):
Returns:
bool: True if the event matches
"""
- sender = event.get("sender", None)
- if not sender:
- # Presence events have their 'sender' in content.user_id
- content = event.get("content")
- # account_data has been allowed to have non-dict content, so check type first
- if isinstance(content, dict):
- sender = content.get("user_id")
+ # We usually get the full "events" as dictionaries coming through,
+ # except for presence which actually gets passed around as its own
+ # namedtuple type.
+ if isinstance(event, UserPresenceState):
+ sender = event.user_id
+ room_id = None
+ ev_type = "m.presence"
+ is_url = False
+ else:
+ sender = event.get("sender", None)
+ if not sender:
+ # Presence events had their 'sender' in content.user_id, but are
+ # now handled above. We don't know if anything else uses this
+ # form. TODO: Check this and probably remove it.
+ content = event.get("content")
+ # account_data has been allowed to have non-dict content, so
+ # check type first
+ if isinstance(content, dict):
+ sender = content.get("user_id")
+
+ room_id = event.get("room_id", None)
+ ev_type = event.get("type", None)
+ is_url = "url" in event.get("content", {})
return self.check_fields(
- event.get("room_id", None),
+ room_id,
sender,
- event.get("type", None),
- "url" in event.get("content", {})
+ ev_type,
+ is_url,
)
def check_fields(self, room_id, sender, event_type, contains_url):
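
The schemas above replace the hand-rolled `_check_definition*` checks: invalid filters now surface as a `jsonschema.ValidationError`, which the `Filtering` class converts into a `SynapseError(400, e.message)`. Below is a small self-contained sketch of the same pattern, a schema with a custom `jsonschema` format registered via `FormatChecker.cls_checks`, assuming only the `jsonschema` package; the toy schema, the `matrix_user_id` heuristic and the `ValueError` wrapper are illustrative, not Synapse's.

```python
import jsonschema
from jsonschema import FormatChecker

# Toy schema using a custom "matrix_user_id" format, as the filter schemas do.
SCHEMA = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "senders": {
            "type": "array",
            "items": {"type": "string", "format": "matrix_user_id"},
        },
    },
}


@FormatChecker.cls_checks("matrix_user_id")
def matrix_user_id_validator(user_id_str):
    # The real validator calls UserID.from_string(), which raises on bad input;
    # returning False (or raising) is enough to fail format validation.
    return user_id_str.startswith("@") and ":" in user_id_str


def check_filter(filter_json):
    try:
        jsonschema.validate(filter_json, SCHEMA, format_checker=FormatChecker())
    except jsonschema.ValidationError as e:
        # Synapse wraps this in SynapseError(400, e.message).
        raise ValueError("400: %s" % (e.message,))


check_filter({"senders": ["@alice:example.com"]})   # passes
try:
    check_filter({"senders": ["not-a-user-id"]})
except ValueError as e:
    print(e)
```

Note that format checking only runs because a `format_checker` is passed to `validate`, which is exactly what the new `Filtering` code does.
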
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
new file mode 100644
index 0000000000..e4318cdfc3
--- /dev/null
+++ b/synapse/app/_base.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import sys
+
+try:
+ import affinity
+except Exception:
+ affinity = None
+
+from daemonize import Daemonize
+from synapse.util import PreserveLoggingContext
+from synapse.util.rlimit import change_resource_limit
+from twisted.internet import error, reactor
+
+logger = logging.getLogger(__name__)
+
+
+def start_worker_reactor(appname, config):
+ """ Run the reactor in the main process
+
+ Daemonizes if necessary, and then configures some resources, before starting
+ the reactor. Pulls configuration from the 'worker' settings in 'config'.
+
+ Args:
+ appname (str): application name which will be sent to syslog
+ config (synapse.config.Config): config object
+ """
+
+ logger = logging.getLogger(config.worker_app)
+
+ start_reactor(
+ appname,
+ config.soft_file_limit,
+ config.gc_thresholds,
+ config.worker_pid_file,
+ config.worker_daemonize,
+ config.worker_cpu_affinity,
+ logger,
+ )
+
+
+def start_reactor(
+ appname,
+ soft_file_limit,
+ gc_thresholds,
+ pid_file,
+ daemonize,
+ cpu_affinity,
+ logger,
+):
+ """ Run the reactor in the main process
+
+ Daemonizes if necessary, and then configures some resources, before starting
+ the reactor
+
+ Args:
+ appname (str): application name which will be sent to syslog
+ soft_file_limit (int):
+ gc_thresholds:
+ pid_file (str): name of pid file to write to if daemonize is True
+ daemonize (bool): true to run the reactor in a background process
+ cpu_affinity (int|None): cpu affinity mask
+ logger (logging.Logger): logger instance to pass to Daemonize
+ """
+
+ def run():
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
+ logger.info("Running")
+ if cpu_affinity is not None:
+ if not affinity:
+ quit_with_error(
+ "Missing package 'affinity' required for cpu_affinity\n"
+ "option\n\n"
+ "Install by running:\n\n"
+ " pip install affinity\n\n"
+ )
+ logger.info("Setting CPU affinity to %s" % cpu_affinity)
+ affinity.set_process_affinity_mask(0, cpu_affinity)
+ change_resource_limit(soft_file_limit)
+ if gc_thresholds:
+ gc.set_threshold(*gc_thresholds)
+ reactor.run()
+
+ if daemonize:
+ daemon = Daemonize(
+ app=appname,
+ pid=pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+def quit_with_error(error_string):
+ message_lines = error_string.split("\n")
+ line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
+ sys.stderr.write("*" * line_length + '\n')
+ for line in message_lines:
+ sys.stderr.write(" %s\n" % (line.rstrip(),))
+ sys.stderr.write("*" * line_length + '\n')
+ sys.exit(1)
+
+
+def listen_tcp(bind_addresses, port, factory, backlog=50):
+ """
+ Create a TCP socket for a port and several addresses
+ """
+ for address in bind_addresses:
+ try:
+ reactor.listenTCP(
+ port,
+ factory,
+ backlog,
+ address
+ )
+ except error.CannotListenError as e:
+ check_bind_error(e, address, bind_addresses)
+
+
+def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+ """
+ Create an SSL socket for a port and several addresses
+ """
+ for address in bind_addresses:
+ try:
+ reactor.listenSSL(
+ port,
+ factory,
+ context_factory,
+ backlog,
+ address
+ )
+ except error.CannotListenError as e:
+ check_bind_error(e, address, bind_addresses)
+
+
+def check_bind_error(e, address, bind_addresses):
+ """
+    This method checks an exception which occurred while binding on 0.0.0.0.
+    If :: is also present in the bind addresses, only a warning is logged;
+    otherwise the exception is re-raised.
+
+ Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
+ because :: binds on both IPv4 and IPv6 (as per RFC 3493).
+ When binding on 0.0.0.0 after :: this can safely be ignored.
+
+ Args:
+ e (Exception): Exception that was caught.
+ address (str): Address on which binding was attempted.
+ bind_addresses (list): Addresses on which the service listens.
+ """
+ if address == '0.0.0.0' and '::' in bind_addresses:
+ logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+ else:
+ raise e
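
The worker apps below are all converted to the new `_base.listen_tcp` helper, which binds one site on several addresses and relies on `check_bind_error` to tolerate the expected failure when `0.0.0.0` is bound after `::`. A hedged usage sketch with a trivial Twisted echo factory follows; the port and the dual-wildcard address list are illustrative, and the import assumes this synapse tree is on the Python path.

```python
from twisted.internet import protocol, reactor

from synapse.app._base import listen_tcp  # helper added in this diff


class Echo(protocol.Protocol):
    def dataReceived(self, data):
        self.transport.write(data)


class EchoFactory(protocol.Factory):
    def buildProtocol(self, addr):
        return Echo()


# Bind the same factory on both wildcards. Per check_bind_error's docstring,
# the second bind fails on Linux/macOS because '::' already covers IPv4, and
# the helper downgrades that failure to a warning instead of crashing.
listen_tcp(["::", "0.0.0.0"], 8448, EchoFactory())

reactor.run()
```
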
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 1900930053..58f2c9d68c 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -13,37 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-
-from synapse import events
-
from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.appservice")
@@ -56,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@@ -84,19 +64,18 @@ class AppserviceServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse appservice now listening on port %d", port)
@@ -105,45 +84,42 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return ASReplicationHandler(self)
+
+
+class ASReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(ASReplicationHandler, self).__init__(hs.get_datastore())
+ self.appservice_handler = hs.get_application_service_handler()
+
+ def on_rdata(self, stream_name, token, rows):
+ super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+
+ if stream_name == "events":
+ max_stream_id = self.store.get_room_max_stream_ordering()
+ run_in_background(self._notify_app_services, max_stream_id)
+
@defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
- appservice_handler = self.get_application_service_handler()
-
- @defer.inlineCallbacks
- def replicate(results):
- stream = results.get("events")
- if stream:
- max_stream_id = stream["position"]
- yield appservice_handler.notify_interested_services(max_stream_id)
-
- while True:
- try:
- args = store.stream_positions()
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- replicate(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(30)
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
def start(config_options):
@@ -157,7 +133,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.appservice"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -186,33 +162,13 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
- ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-appservice",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-appservice", config)
if __name__ == '__main__':
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 4d081eccd1..267d34c881 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -13,45 +13,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
+from synapse.crypto import context_factory
from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.client_reader")
@@ -63,26 +56,14 @@ class ClientReaderSlavedStore(
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
+ TransactionStore,
+ SlavedClientIpStore,
BaseSlavedStore,
- ClientIpStore, # After BaseSlavedStore because the constructor is different
):
pass
class ClientReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@@ -107,19 +88,18 @@ class ClientReaderServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse client reader now listening on port %d", port)
@@ -128,36 +108,23 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
+ self.get_tcp_replication().start_replication(self)
- while True:
- try:
- args = store.stream_positions()
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(5)
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
def start(config_options):
@@ -171,7 +138,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.client_reader"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -189,36 +156,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
- ss.replicate()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-client-reader",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-client-reader", config)
if __name__ == '__main__':
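
Every worker datastore, like `ClientReaderSlavedStore` above and `EventCreatorSlavedStore` below, is composed from slaved-store mixins with `BaseSlavedStore` listed last (the old `ClientIpStore,  # After BaseSlavedStore ...` special case disappears now that `SlavedClientIpStore` is an ordinary mixin). The ordering matters because Python resolves cooperative `super()` calls along the MRO from left to right; here is a minimal illustration with hypothetical stand-in classes.

```python
class BaseSlavedStoreStandIn(object):
    """Hypothetical stand-in for BaseSlavedStore: the common base goes last."""
    def __init__(self, db_conn, hs):
        self.stores = ["base"]


class SlavedFooStore(BaseSlavedStoreStandIn):
    def __init__(self, db_conn, hs):
        super(SlavedFooStore, self).__init__(db_conn, hs)
        self.stores.append("foo")


class SlavedBarStore(BaseSlavedStoreStandIn):
    def __init__(self, db_conn, hs):
        super(SlavedBarStore, self).__init__(db_conn, hs)
        self.stores.append("bar")


class WorkerSlavedStore(
    SlavedFooStore,
    SlavedBarStore,
    BaseSlavedStoreStandIn,   # always last, as in ClientReaderSlavedStore
):
    pass


store = WorkerSlavedStore(db_conn=None, hs=None)
print([cls.__name__ for cls in WorkerSlavedStore.__mro__])
print(store.stores)   # ['base', 'bar', 'foo']
```

The MRO printout shows the shared base last, and `store.stores` comes out as `['base', 'bar', 'foo']`: the base initialiser runs once, then each mixin layers on top in reverse declaration order.
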
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
new file mode 100644
index 0000000000..b915d12d53
--- /dev/null
+++ b/synapse/app/event_creator.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.room import (
+ RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet,
+ JoinRoomAliasServlet,
+)
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.event_creator")
+
+
+class EventCreatorSlavedStore(
+ DirectoryStore,
+ TransactionStore,
+ SlavedProfileStore,
+ SlavedAccountDataStore,
+ SlavedPusherStore,
+ SlavedReceiptsStore,
+ SlavedPushRuleStore,
+ SlavedDeviceStore,
+ SlavedClientIpStore,
+ SlavedApplicationServiceStore,
+ SlavedEventStore,
+ SlavedRegistrationStore,
+ RoomStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class EventCreatorServer(HomeServer):
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ RoomSendEventRestServlet(self).register(resource)
+ RoomMembershipRestServlet(self).register(resource)
+ RoomStateEventRestServlet(self).register(resource)
+ JoinRoomAliasServlet(self).register(resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ )
+ )
+
+ logger.info("Synapse event creator now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse event creator", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.event_creator"
+
+ assert config.worker_replication_http_port is not None
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = EventCreatorServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.start_listening(config.worker_listeners)
+
+ def start():
+ ss.get_state_handler().start_caching()
+ ss.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ _base.start_worker_reactor("synapse-event-creator", config)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 90a4816753..c1dc66dd17 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -13,43 +13,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
+from synapse import events
+from synapse.api.urls import FEDERATION_PREFIX
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.api.urls import FEDERATION_PREFIX
-from synapse.federation.transport.server import TransportLayerServer
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.federation_reader")
@@ -66,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@@ -98,19 +77,18 @@ class FederationReaderServer(HomeServer):
FEDERATION_PREFIX: TransportLayerServer(self),
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse federation reader now listening on port %d", port)
@@ -119,36 +97,22 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
+ self.get_tcp_replication().start_replication(self)
- while True:
- try:
- args = store.stream_positions()
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(5)
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
def start(config_options):
@@ -162,7 +126,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_reader"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -180,36 +144,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
- ss.replicate()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-federation-reader",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-federation-reader", config)
if __name__ == '__main__':
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 411e47d98d..a08af83a4c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -13,69 +13,69 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
-from synapse.http.site import SynapseSite
from synapse.federation import send_queue
-from synapse.federation.units import Edu
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
from synapse.storage.engines import create_engine
-from synapse.storage.presence import UserPresenceState
-from synapse.util.async import sleep
+from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
-import ujson as json
-
-logger = logging.getLogger("synapse.app.appservice")
+logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
- SlavedRegistrationStore, SlavedDeviceStore,
+ SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
):
- pass
+ def __init__(self, db_conn, hs):
+ super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+
+ # We pull out the current federation stream position now so that we
+ # always have a known value for the federation position in memory so
+ # that we don't have to bounce via a deferred once when we start the
+ # replication streams.
+ self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+ def _get_federation_out_pos(self, db_conn):
+ sql = (
+ "SELECT stream_id FROM federation_stream_position"
+ " WHERE type = ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+ txn = db_conn.cursor()
+ txn.execute(sql, ("federation",))
+ rows = txn.fetchall()
+ txn.close()
+
+ return rows[0][0] if rows else -1
-class FederationSenderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
+class FederationSenderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@@ -91,19 +91,18 @@ class FederationSenderServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse federation_sender now listening on port %d", port)
@@ -112,41 +111,39 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
- send_handler = FederationSenderHandler(self)
-
- send_handler.on_start()
-
- while True:
- try:
- args = store.stream_positions()
- args.update((yield send_handler.stream_positions()))
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- yield send_handler.process_replication(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(30)
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return FederationSenderReplicationHandler(self)
+
+
+class FederationSenderReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
+ self.send_handler = FederationSenderHandler(hs, self)
+
+ def on_rdata(self, stream_name, token, rows):
+ super(FederationSenderReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ self.send_handler.process_replication_rows(stream_name, token, rows)
+
+ def get_streams_to_replicate(self):
+ args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+ args.update(self.send_handler.stream_positions())
+ return args
def start(config_options):
@@ -160,7 +157,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_sender"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -192,42 +189,27 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
- ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
-
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-federation-sender",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-federation-sender", config)
class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""
- def __init__(self, hs):
+ def __init__(self, hs, replication_client):
self.store = hs.get_datastore()
self.federation_sender = hs.get_federation_sender()
+ self.replication_client = replication_client
+
+ self.federation_position = self.store.federation_out_pos_startup
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+ self._last_ack = self.federation_position
self._room_serials = {}
self._room_typing = {}
@@ -239,98 +221,38 @@ class FederationSenderHandler(object):
self.store.get_room_max_stream_ordering()
)
- @defer.inlineCallbacks
def stream_positions(self):
- stream_id = yield self.store.get_federation_out_pos("federation")
- defer.returnValue({
- "federation": stream_id,
-
- # Ack stuff we've "processed", this should only be called from
- # one process.
- "federation_ack": stream_id,
- })
+ return {"federation": self.federation_position}
- @defer.inlineCallbacks
- def process_replication(self, result):
+ def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
- fed_stream = result.get("federation")
- if fed_stream:
- latest_id = int(fed_stream["position"])
-
- # The federation stream containis a bunch of different types of
- # rows that need to be handled differently. We parse the rows, put
- # them into the appropriate collection and then send them off.
- presence_to_send = {}
- keyed_edus = {}
- edus = {}
- failures = {}
- device_destinations = set()
-
- # Parse the rows in the stream
- for row in fed_stream["rows"]:
- position, typ, content_js = row
- content = json.loads(content_js)
-
- if typ == send_queue.PRESENCE_TYPE:
- destination = content["destination"]
- state = UserPresenceState.from_dict(content["state"])
-
- presence_to_send.setdefault(destination, []).append(state)
- elif typ == send_queue.KEYED_EDU_TYPE:
- key = content["key"]
- edu = Edu(**content["edu"])
-
- keyed_edus.setdefault(
- edu.destination, {}
- )[(edu.destination, tuple(key))] = edu
- elif typ == send_queue.EDU_TYPE:
- edu = Edu(**content)
-
- edus.setdefault(edu.destination, []).append(edu)
- elif typ == send_queue.FAILURE_TYPE:
- destination = content["destination"]
- failure = content["failure"]
-
- failures.setdefault(destination, []).append(failure)
- elif typ == send_queue.DEVICE_MESSAGE_TYPE:
- device_destinations.add(content["destination"])
- else:
- raise Exception("Unrecognised federation type: %r", typ)
-
- # We've finished collecting, send everything off
- for destination, states in presence_to_send.items():
- self.federation_sender.send_presence(destination, states)
-
- for destination, edu_map in keyed_edus.items():
- for key, edu in edu_map.items():
- self.federation_sender.send_edu(
- edu.destination, edu.edu_type, edu.content, key=key,
- )
-
- for destination, edu_list in edus.items():
- for edu in edu_list:
- self.federation_sender.send_edu(
- edu.destination, edu.edu_type, edu.content, key=None,
- )
+ if stream_name == "federation":
+ send_queue.process_rows_for_federation(self.federation_sender, rows)
+ run_in_background(self.update_token, token)
- for destination, failure_list in failures.items():
- for failure in failure_list:
- self.federation_sender.send_failure(destination, failure)
-
- for destination in device_destinations:
- self.federation_sender.send_device_messages(destination)
+ # We also need to poke the federation sender when new events happen
+ elif stream_name == "events":
+ self.federation_sender.notify_new_events(token)
- # Record where we are in the stream.
- yield self.store.update_federation_out_pos(
- "federation", latest_id
- )
+ @defer.inlineCallbacks
+ def update_token(self, token):
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (yield self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ yield self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
- # We also need to poke the federation sender when new events happen
- event_stream = result.get("events")
- if event_stream:
- latest_pos = event_stream["position"]
- self.federation_sender.notify_new_events(latest_pos)
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(self.federation_position)
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
if __name__ == '__main__':
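The `update_token` path above replaces the old HTTP long-poll loop: the worker keeps the latest federation stream position in memory, persists it at most once per token behind a linearizer, and then ACKs the position back to the master so the master can drop its in-memory queues. A minimal, self-contained sketch of that bookkeeping pattern — synchronous, with a plain lock standing in for synapse's `Linearizer`, and `store`/`client` as hypothetical stand-ins for the slaved store and replication connection:

```python
import threading

class FederationPositionTracker(object):
    """Illustrative sketch only -- not synapse's actual API."""

    def __init__(self, store, replication_client, start_pos):
        self.store = store                 # assumed to expose update_federation_out_pos()
        self.client = replication_client   # assumed to expose send_federation_ack()
        self.position = start_pos
        self._last_ack = start_pos
        self._lock = threading.Lock()      # stands in for the Linearizer in the diff

    def update_token(self, token):
        self.position = token
        with self._lock:                   # serialise concurrent position updates
            if self._last_ack < self.position:
                # Persist before acking, so the master never drops queued work
                # that this worker has not durably recorded.
                self.store.update_federation_out_pos("federation", self.position)
                self.client.send_federation_ack(self.position)
                self._last_ack = self.position
```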
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
new file mode 100644
index 0000000000..b349e3e3ce
--- /dev/null
+++ b/synapse/app/frontend_proxy.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.api.errors import SynapseError
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request,
+)
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha._base import client_v2_patterns
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.frontend_proxy")
+
+
+class KeyUploadServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(KeyUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ if device_id is not None:
+ # passing the device_id here is deprecated; however, we allow it
+ # for now for compatibility with older clients.
+ if (requester.device_id is not None and
+ device_id != requester.device_id):
+ logger.warning("Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id, device_id)
+ else:
+ device_id = requester.device_id
+
+ if device_id is None:
+ raise SynapseError(
+ 400,
+ "To upload keys, you must pass device_id when authenticating"
+ )
+
+ if body:
+ # They're actually trying to upload something, proxy to main synapse.
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
+ headers = {
+ "Authorization": auth_headers,
+ }
+ result = yield self.http_client.post_json_get_json(
+ self.main_uri + request.uri,
+ body,
+ headers=headers,
+ )
+
+ defer.returnValue((200, result))
+ else:
+ # Just interested in counts.
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class FrontendProxySlavedStore(
+ SlavedDeviceStore,
+ SlavedClientIpStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class FrontendProxyServer(HomeServer):
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ KeyUploadServlet(self).register(resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ )
+ )
+
+        logger.info("Synapse frontend proxy now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse frontend proxy", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.frontend_proxy"
+
+ assert config.worker_main_http_uri is not None
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = FrontendProxyServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.start_listening(config.worker_listeners)
+
+ def start():
+ ss.get_state_handler().start_caching()
+ ss.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ _base.start_worker_reactor("synapse-frontend-proxy", config)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
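For orientation, the branch in `KeyUploadServlet.on_POST` above either proxies the upload to the main process (passing the `Authorization` header through so the access token is still honoured there) or answers the one-time-key count locally from the replicated store. A rough sketch of that decision, with `proxy_to_main` and `count_one_time_keys` as hypothetical callables standing in for the HTTP client and slaved store:

```python
def handle_key_upload(body, user_id, device_id, auth_headers,
                      proxy_to_main, count_one_time_keys):
    """Illustrative sketch of the frontend proxy's branching, not the real servlet."""
    if body:
        # A real upload: forward to the main synapse, preserving the auth
        # headers so the request is authenticated there as well.
        return proxy_to_main(body, headers={"Authorization": auth_headers})
    # Empty body: the client only wants its current one-time-key counts,
    # which the worker can answer from its replicated data.
    return {"one_time_key_counts": count_one_time_keys(user_id, device_id)}
```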
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e0b87468fe..a0e465d644 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,59 +13,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import synapse
-
import gc
import logging
import os
import sys
-from synapse.config._base import ConfigError
-
-from synapse.python_dependencies import (
- check_requirements, DEPENDENCY_LINKS
-)
-
-from synapse.rest import ClientRestResource
-from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
-from synapse.storage import are_all_users_on_domain
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
-
-from synapse.server import HomeServer
-from twisted.internet import reactor, task, defer
-from twisted.application import service
-from twisted.web.resource import Resource, EncodingResourceWrapper
-from twisted.web.static import File
-from twisted.web.server import GzipEncoderFactory
+import synapse
+import synapse.config.logger
+from synapse import events
+from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
+ LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
+ STATIC_PREFIX, WEB_CLIENT_PREFIX
+from synapse.app import _base
+from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.crypto import context_factory
+from synapse.federation.transport.server import TransportLayerServer
+from synapse.module_api import ModuleApi
+from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
-from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
+from synapse.http.site import SynapseSite
+from synapse.metrics import register_memory_metrics
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
+ check_requirements
+from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.api.urls import (
- FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
- SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
- SERVER_KEY_V2_PREFIX,
-)
-from synapse.config.homeserver import HomeServerConfig
-from synapse.crypto import context_factory
+from synapse.rest.media.v0.content_repository import ContentRepoResource
+from synapse.server import HomeServer
+from synapse.storage import are_all_users_on_domain
+from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
-from synapse.metrics import register_memory_metrics, get_metrics_for
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
-from synapse.federation.transport.server import TransportLayerServer
-
+from synapse.util.manhole import manhole
+from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-
-from synapse.http.site import SynapseSite
-
-from synapse import events
-
-from daemonize import Daemonize
+from twisted.application import service
+from twisted.internet import defer, reactor
+from twisted.web.resource import EncodingResourceWrapper, NoResource
+from twisted.web.server import GzipEncoderFactory
+from twisted.web.static import File
logger = logging.getLogger("synapse.app.homeserver")
@@ -90,7 +84,7 @@ def build_resource_for_web_client(hs):
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
- % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
+ % {"dep": CONDITIONAL_REQUIREMENTS["web_client"].keys()[0]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
@@ -117,90 +111,121 @@ class SynapseHomeServer(HomeServer):
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
- if name == "client":
- client_resource = ClientRestResource(self)
- if res["compress"]:
- client_resource = gz_wrap(client_resource)
-
- resources.update({
- "/_matrix/client/api/v1": client_resource,
- "/_matrix/client/r0": client_resource,
- "/_matrix/client/unstable": client_resource,
- "/_matrix/client/v2_alpha": client_resource,
- "/_matrix/client/versions": client_resource,
- })
-
- if name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
-
- if name in ["static", "client"]:
- resources.update({
- STATIC_PREFIX: File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- ),
- })
-
- if name in ["media", "federation", "client"]:
- media_repo = MediaRepositoryResource(self)
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
-
- if name in ["keys", "federation"]:
- resources.update({
- SERVER_KEY_PREFIX: LocalKey(self),
- SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
- })
-
- if name == "webclient":
- resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
-
- if name == "metrics" and self.get_config().enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(self)
-
- if name == "replication":
- resources[REPLICATION_PREFIX] = ReplicationResource(self)
+ resources.update(self._configure_named_resource(
+ name, res.get("compress", False),
+ ))
+
+ additional_resources = listener_config.get("additional_resources", {})
+ logger.debug("Configuring additional resources: %r",
+ additional_resources)
+ module_api = ModuleApi(self, self.get_auth_handler())
+ for path, resmodule in additional_resources.items():
+ handler_cls, config = load_module(resmodule)
+ handler = handler_cls(config, module_api)
+ resources[path] = AdditionalResource(self, handler.handle_request)
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
- root_resource = Resource()
+ root_resource = NoResource()
root_resource = create_resource_tree(resources, root_resource)
if tls:
- for address in bind_addresses:
- reactor.listenSSL(
- port,
- SynapseSite(
- "synapse.access.https.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- self.tls_server_context_factory,
- interface=address
- )
+ listen_ssl(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.https.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ self.tls_server_context_factory,
+ )
+
else:
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse now listening on port %d", port)
+ def _configure_named_resource(self, name, compress=False):
+ """Build a resource map for a named resource
+
+ Args:
+ name (str): named resource: one of "client", "federation", etc
+ compress (bool): whether to enable gzip compression for this
+ resource
+
+ Returns:
+ dict[str, Resource]: map from path to HTTP resource
+ """
+ resources = {}
+ if name == "client":
+ client_resource = ClientRestResource(self)
+ if compress:
+ client_resource = gz_wrap(client_resource)
+
+ resources.update({
+ "/_matrix/client/api/v1": client_resource,
+ "/_matrix/client/r0": client_resource,
+ "/_matrix/client/unstable": client_resource,
+ "/_matrix/client/v2_alpha": client_resource,
+ "/_matrix/client/versions": client_resource,
+ })
+
+ if name == "federation":
+ resources.update({
+ FEDERATION_PREFIX: TransportLayerServer(self),
+ })
+
+ if name in ["static", "client"]:
+ resources.update({
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ ),
+ })
+
+ if name in ["media", "federation", "client"]:
+ if self.get_config().enable_media_repo:
+ media_repo = self.get_media_repository_resource()
+ resources.update({
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path
+ ),
+ })
+ elif name == "media":
+ raise ConfigError(
+ "'media' resource conflicts with enable_media_repo=False",
+ )
+
+ if name in ["keys", "federation"]:
+ resources.update({
+ SERVER_KEY_PREFIX: LocalKey(self),
+ SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
+ })
+
+ if name == "webclient":
+ resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
+
+ if name == "metrics" and self.get_config().enable_metrics:
+ resources[METRICS_PREFIX] = MetricsResource(self)
+
+ if name == "replication":
+ resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
+ return resources
+
def start_listening(self):
config = self.get_config()
@@ -208,17 +233,24 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
+ listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ elif listener["type"] == "replication":
bind_addresses = listener["bind_addresses"]
-
for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ factory = ReplicationStreamProtocolFactory(self)
+ server_listener = reactor.listenTCP(
+ listener["port"], factory, interface=address
+ )
+ reactor.addSystemEventTrigger(
+ "before", "shutdown", server_listener.stopListening,
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -239,29 +271,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
-
-def quit_with_error(error_string):
- message_lines = error_string.split("\n")
- line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
- sys.stderr.write("*" * line_length + '\n')
- for line in message_lines:
- sys.stderr.write(" %s\n" % (line.rstrip(),))
- sys.stderr.write("*" * line_length + '\n')
- sys.exit(1)
-
def setup(config_options):
"""
@@ -286,7 +295,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- config.setup_logging()
+ synapse.config.logger.setup_logging(config, use_worker_options=False)
# check any extra requirements we have now we have a config
check_requirements(config)
@@ -340,7 +349,7 @@ def setup(config_options):
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
- hs.get_replication_layer().start_get_pdu_cache()
+ hs.get_federation_client().start_get_pdu_cache()
register_memory_metrics(hs)
@@ -389,10 +398,15 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
- start_time = hs.get_clock().time()
+ clock = hs.get_clock()
+ start_time = clock.time()
stats = {}
+ # Contains the list of processes we will be monitoring
+ # currently either 0 or 1
+ stats_process = []
+
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@@ -401,41 +415,36 @@ def run(hs):
if uptime < 0:
uptime = 0
- # If the stats directory is empty then this is the first time we've
- # reported stats.
- first_time = not stats
-
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
+ total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- daily_messages = yield hs.get_datastore().count_daily_messages()
- if daily_messages is not None:
- stats["daily_messages"] = daily_messages
- else:
- stats.pop("daily_messages", None)
-
- if first_time:
- # Add callbacks to report the synapse stats as metrics whenever
- # prometheus requests them, typically every 30s.
- # As some of the stats are expensive to calculate we only update
- # them when synapse phones home to matrix.org every 24 hours.
- metrics = get_metrics_for("synapse.usage")
- metrics.add_callback("timestamp", lambda: stats["timestamp"])
- metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
- metrics.add_callback("total_users", lambda: stats["total_users"])
- metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
- metrics.add_callback(
- "daily_active_users", lambda: stats["daily_active_users"]
- )
- metrics.add_callback(
- "daily_messages", lambda: stats.get("daily_messages", 0)
- )
+ stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in r30_results.iteritems():
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ if len(stats_process) > 0:
+ stats["memory_rss"] = 0
+ stats["cpu_average"] = 0
+ for process in stats_process:
+ stats["memory_rss"] += process.memory_info().rss
+ stats["cpu_average"] += int(process.cpu_percent(interval=None))
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
@@ -446,37 +455,48 @@ def run(hs):
except Exception as e:
logger.warn("Error reporting stats: %s", e)
- if hs.config.report_stats:
- phone_home_task = task.LoopingCall(phone_stats_home)
- logger.info("Scheduling stats reporting for 24 hour intervals")
- phone_home_task.start(60 * 60 * 24, now=False)
-
- def in_thread():
- # Uncomment to enable tracing of log context changes.
- # sys.settrace(logcontext_tracer)
- with LoggingContext("run"):
- change_resource_limit(hs.config.soft_file_limit)
- if hs.config.gc_thresholds:
- gc.set_threshold(*hs.config.gc_thresholds)
- reactor.run()
-
- if hs.config.daemonize:
-
- if hs.config.print_pidfile:
- print (hs.config.pid_file)
-
- daemon = Daemonize(
- app="synapse-homeserver",
- pid=hs.config.pid_file,
- action=lambda: in_thread(),
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
+ def performance_stats_init():
+ try:
+ import psutil
+ process = psutil.Process()
+ # Ensure we can fetch both, and make the initial request for cpu_percent
+ # so the next request will use this as the initial point.
+ process.memory_info().rss
+ process.cpu_percent(interval=None)
+ logger.info("report_stats can use psutil")
+ stats_process.append(process)
+ except (ImportError, AttributeError):
+ logger.warn(
+ "report_stats enabled but psutil is not installed or incorrect version."
+ " Disabling reporting of memory/cpu stats."
+ " Ensuring psutil is available will help matrix.org track performance"
+ " changes across releases."
+ )
- daemon.start()
- else:
- in_thread()
+ if hs.config.report_stats:
+ logger.info("Scheduling stats reporting for 3 hour intervals")
+ clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
+
+        # We need to defer this init in case we daemonize, otherwise the
+        # process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
+ # We wait 5 minutes to send the first set of stats as the server can
+ # be quite busy the first few minutes
+ clock.call_later(5 * 60, phone_stats_home)
+
+ if hs.config.daemonize and hs.config.print_pidfile:
+ print (hs.config.pid_file)
+
+ _base.start_reactor(
+ "synapse-homeserver",
+ hs.config.soft_file_limit,
+ hs.config.gc_thresholds,
+ hs.config.pid_file,
+ hs.config.daemonize,
+ hs.config.cpu_affinity,
+ logger,
+ )
def main():
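A side note on the reworked stats reporting above: memory/CPU figures are only collected when psutil imports cleanly, and the first `cpu_percent(interval=None)` call merely primes the counter so subsequent calls return a meaningful average. A small standalone illustration of that pattern (assumes psutil is installed):

```python
import psutil

def init_process_stats():
    """Prime psutil so later cpu_percent() calls return useful numbers."""
    process = psutil.Process()
    process.memory_info().rss            # fail fast if the call is unavailable
    process.cpu_percent(interval=None)   # first call returns 0.0 and sets the baseline
    return process

def collect_process_stats(process):
    """Gather the same fields the homeserver reports (memory_rss, cpu_average)."""
    return {
        "memory_rss": process.memory_info().rss,
        "cpu_average": int(process.cpu_percent(interval=None)),
    }
```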
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index ef17b158a5..fc8282bbc1 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -13,45 +13,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
+from synapse import events
+from synapse.api.urls import (
+ CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
+)
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
-from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.api.urls import (
- CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
-)
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.media_repository")
@@ -59,27 +51,15 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
+ SlavedClientIpStore,
+ TransactionStore,
BaseSlavedStore,
MediaRepositoryStore,
- ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@@ -95,7 +75,7 @@ class MediaRepositoryServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
- media_repo = MediaRepositoryResource(self)
+ media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
@@ -104,19 +84,18 @@ class MediaRepositoryServer(HomeServer):
),
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse media repository now listening on port %d", port)
@@ -125,36 +104,22 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
+ self.get_tcp_replication().start_replication(self)
- while True:
- try:
- args = store.stream_positions()
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(5)
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
def start(config_options):
@@ -168,7 +133,14 @@ def start(config_options):
assert config.worker_app == "synapse.app.media_repository"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ if config.enable_media_repo:
+ _base.quit_with_error(
+ "enable_media_repo must be disabled in the main synapse process\n"
+ "before the media repo can be run in a separate worker.\n"
+ "Please add ``enable_media_repo: false`` to the main config\n"
+ )
+
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -186,36 +158,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
- ss.replicate()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-media-repository",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-media-repository", config)
if __name__ == '__main__':
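One operational consequence of the media_repository changes above: the media repo can no longer be served by both the main process and this worker, so worker startup now refuses to run unless the shared config disables it on the main process. A tiny sketch of that guard, using a plain dict in place of the real config object (the real code calls `_base.quit_with_error()`):

```python
def check_media_worker_config(config):
    """Illustrative guard only; names and structure are simplified."""
    if config.get("enable_media_repo", True):
        raise SystemExit(
            "enable_media_repo must be disabled in the main synapse process "
            "before the media repo can be run in a separate worker. "
            "Add `enable_media_repo: false` to the main config."
        )

# Example: a worker-side check against the shared homeserver config
check_media_worker_config({"enable_media_repo": False})  # passes silently
```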
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 073f2c2489..26930d1b3b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,39 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.storage.roommember import RoomMemberStore
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.storage.engines import create_engine
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.util.async import sleep
+from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.pusher")
@@ -82,42 +74,15 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
- who_forgot_in_room = (
- RoomMemberStore.__dict__["who_forgot_in_room"]
- )
-
class PusherServer(HomeServer):
-
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def remove_pusher(self, app_id, push_key, user_id):
- http_client = self.get_simple_http_client()
- replication_url = self.config.worker_replication_url
- url = replication_url + "/remove_pushers"
- return http_client.post_json_get_json(url, {
- "remove": [{
- "app_id": app_id,
- "push_key": push_key,
- "user_id": user_id,
- }]
- })
+ self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -129,19 +94,18 @@ class PusherServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse pusher now listening on port %d", port)
@@ -150,88 +114,67 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return PusherReplicationHandler(self)
+
+
+class PusherReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(PusherReplicationHandler, self).__init__(hs.get_datastore())
+
+ self.pusher_pool = hs.get_pusherpool()
+
+ def on_rdata(self, stream_name, token, rows):
+ super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+ run_in_background(self.poke_pushers, stream_name, token, rows)
+
@defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
- pusher_pool = self.get_pusherpool()
-
- def stop_pusher(user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- pushers_for_user = pusher_pool.pushers.get(user_id, {})
- pusher = pushers_for_user.pop(key, None)
- if pusher is None:
- return
- logger.info("Stopping pusher %r / %r", user_id, key)
- pusher.on_stop()
-
- def start_pusher(user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- logger.info("Starting pusher %r / %r", user_id, key)
- return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
-
- @defer.inlineCallbacks
- def poke_pushers(results):
- pushers_rows = set(
- map(tuple, results.get("pushers", {}).get("rows", []))
- )
- deleted_pushers_rows = set(
- map(tuple, results.get("deleted_pushers", {}).get("rows", []))
- )
- for row in sorted(pushers_rows | deleted_pushers_rows):
- if row in deleted_pushers_rows:
- user_id, app_id, pushkey = row[1:4]
- stop_pusher(user_id, app_id, pushkey)
- elif row in pushers_rows:
- user_id = row[1]
- app_id = row[5]
- pushkey = row[8]
- yield start_pusher(user_id, app_id, pushkey)
-
- stream = results.get("events")
- if stream and stream["rows"]:
- min_stream_id = stream["rows"][0][0]
- max_stream_id = stream["position"]
- preserve_fn(pusher_pool.on_new_notifications)(
- min_stream_id, max_stream_id
+ def poke_pushers(self, stream_name, token, rows):
+ try:
+ if stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ elif stream_name == "events":
+ yield self.pusher_pool.on_new_notifications(
+ token, token,
)
-
- stream = results.get("receipts")
- if stream and stream["rows"]:
- rows = stream["rows"]
- affected_room_ids = set(row[1] for row in rows)
- min_stream_id = rows[0][0]
- max_stream_id = stream["position"]
- preserve_fn(pusher_pool.on_new_receipts)(
- min_stream_id, max_stream_id, affected_room_ids
+ elif stream_name == "receipts":
+ yield self.pusher_pool.on_new_receipts(
+ token, token, set(row.room_id for row in rows)
)
+ except Exception:
+ logger.exception("Error poking pushers")
+
+ def stop_pusher(self, user_id, app_id, pushkey):
+ key = "%s:%s" % (app_id, pushkey)
+ pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+ pusher = pushers_for_user.pop(key, None)
+ if pusher is None:
+ return
+ logger.info("Stopping pusher %r / %r", user_id, key)
+ pusher.on_stop()
- while True:
- try:
- args = store.stream_positions()
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- poke_pushers(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(30)
+ def start_pusher(self, user_id, app_id, pushkey):
+ key = "%s:%s" % (app_id, pushkey)
+ logger.info("Starting pusher %r / %r", user_id, key)
+ return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
def start(config_options):
@@ -245,7 +188,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.pusher"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -274,34 +217,14 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
- ps.replicate()
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-pusher",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-pusher", config)
if __name__ == '__main__':
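The new `PusherReplicationHandler.poke_pushers` above replaces the old long-poll result parsing with per-stream dispatch: pusher rows start or stop individual pushers, while event and receipt tokens simply poke the pusher pool. A compact sketch of that dispatch shape, with the pool operations replaced by hypothetical callbacks:

```python
def dispatch_replication_rows(stream_name, token, rows,
                              on_pusher_change, on_new_notifications, on_new_receipts):
    """Illustrative dispatcher; callback names are made up for the example."""
    if stream_name == "pushers":
        for row in rows:
            # Each row carries a `deleted` flag plus the pusher's identifying triple.
            on_pusher_change(row.user_id, row.app_id, row.pushkey, deleted=row.deleted)
    elif stream_name == "events":
        # The stream token serves as both the min and max stream id for the poke.
        on_new_notifications(token, token)
    elif stream_name == "receipts":
        on_new_receipts(token, token, {row.room_id for row in rows})
```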
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 3f29595256..7152b1deb4 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -13,105 +13,87 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
+import logging
+import sys
import synapse
-
-from synapse.api.constants import EventTypes, PresenceState
+from synapse.api.constants import EventTypes
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.events import FrozenEvent
-from synapse.handlers.presence import PresenceHandler
-from synapse.http.site import SynapseSite
+from synapse.handlers.presence import PresenceHandler, get_interested_parties
from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.rest.client.v2_alpha import sync
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
+from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
-from synapse.storage.presence import PresenceStore, UserPresenceState
+from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
-from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import contextlib
-import gc
-import ujson as json
+from six import iteritems
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore(
- SlavedPushRuleStore,
- SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
+ SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
+ SlavedPushRuleStore,
+ SlavedEventStore,
+ SlavedClientIpStore,
RoomStore,
BaseSlavedStore,
- ClientIpStore, # After BaseSlavedStore because the constructor is different
):
- who_forgot_in_room = (
- RoomMemberStore.__dict__["who_forgot_in_room"]
- )
-
did_forget = (
RoomMemberStore.__dict__["did_forget"]
)
- # XXX: This is a bit broken because we don't persist the accepted list in a
- # way that can be replicated. This means that we don't have a way to
- # invalidate the cache correctly.
- get_presence_list_accepted = PresenceStore.__dict__[
- "get_presence_list_accepted"
- ]
- get_presence_list_observers_accepted = PresenceStore.__dict__[
- "get_presence_list_observers_accepted"
- ]
-
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
+ self.hs = hs
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
- self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -121,17 +103,52 @@ class SynchrotronPresence(object):
for state in active_presence
}
- self.process_id = random_string(16)
- logger.info("Presence process_id is %r", self.process_id)
+ # user_id -> last_sync_ms. Lists the users that have stopped syncing
+ # but we haven't notified the master of that yet
+ self.users_going_offline = {}
- self._sending_sync = False
- self._need_to_send_sync = False
- self.clock.looping_call(
- self._send_syncing_users_regularly,
- UPDATE_SYNCING_USERS_MS,
+ self._send_stop_syncing_loop = self.clock.looping_call(
+ self.send_stop_syncing, 10 * 1000
)
- reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+ self.process_id = random_string(16)
+ logger.info("Presence process_id is %r", self.process_id)
+
+ def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
+
+ def mark_as_coming_online(self, user_id):
+ """A user has started syncing. Send a UserSync to the master, unless they
+ had recently stopped syncing.
+
+ Args:
+ user_id (str)
+ """
+ going_offline = self.users_going_offline.pop(user_id, None)
+ if not going_offline:
+ # Safe to skip because we haven't yet told the master they were offline
+ self.send_user_sync(user_id, True, self.clock.time_msec())
+
+ def mark_as_going_offline(self, user_id):
+ """A user has stopped syncing. We wait before notifying the master as
+ """A user has stopped syncing. We wait before notifying the master, as
+ it's likely they'll come back soon. This avoids sending a "stopped
+ syncing" notification immediately followed by a "started syncing" one
+ to the master.
+ Args:
+ user_id (str)
+ """
+ self.users_going_offline[user_id] = self.clock.time_msec()
+
+ def send_stop_syncing(self):
+ """Check for users who stopped syncing a while ago and haven't come
+ back yet; if there are any, poke the master about them.
+ """
+ now = self.clock.time_msec()
+ # Iterate over a copy, since we pop entries as we go
+ for user_id, last_sync_ms in list(self.users_going_offline.items()):
+ if now - last_sync_ms > 10 * 1000:
+ self.users_going_offline.pop(user_id, None)
+ self.send_user_sync(user_id, False, last_sync_ms)
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
@@ -139,18 +156,16 @@ class SynchrotronPresence(object):
get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
- _get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
- @defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
- prev_states = yield self.current_state_for_users([user_id])
- if prev_states[user_id].state == PresenceState.OFFLINE:
- # TODO: Don't block the sync request on this HTTP hit.
- yield self._send_syncing_users_now()
+
+ # If we went from no in-flight syncs to some, notify replication
+ if self.user_to_num_current_syncs[user_id] == 1:
+ self.mark_as_coming_online(user_id)
def _end():
# We check that the user_id is in user_to_num_current_syncs because
@@ -159,6 +174,10 @@ class SynchrotronPresence(object):
if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1
+ # If we went from one in-flight sync to none, notify replication
+ if self.user_to_num_current_syncs[user_id] == 0:
+ self.mark_as_going_offline(user_id)
+
@contextlib.contextmanager
def _user_syncing():
try:
@@ -166,56 +185,12 @@ class SynchrotronPresence(object):
finally:
_end()
- defer.returnValue(_user_syncing())
-
- @defer.inlineCallbacks
- def _on_shutdown(self):
- # When the synchrotron is shutdown tell the master to clear the in
- # progress syncs for this process
- self.user_to_num_current_syncs.clear()
- yield self._send_syncing_users_now()
-
- def _send_syncing_users_regularly(self):
- # Only send an update if we aren't in the middle of sending one.
- if not self._sending_sync:
- preserve_fn(self._send_syncing_users_now)()
-
- @defer.inlineCallbacks
- def _send_syncing_users_now(self):
- if self._sending_sync:
- # We don't want to race with sending another update.
- # Instead we wait for that update to finish and send another
- # update afterwards.
- self._need_to_send_sync = True
- return
-
- # Flag that we are sending an update.
- self._sending_sync = True
-
- yield self.http_client.post_json_get_json(self.syncing_users_url, {
- "process_id": self.process_id,
- "syncing_users": [
- user_id for user_id, count in self.user_to_num_current_syncs.items()
- if count > 0
- ],
- })
-
- # Unset the flag as we are no longer sending an update.
- self._sending_sync = False
- if self._need_to_send_sync:
- # If something happened while we were sending the update then
- # we might need to send another update.
- # TODO: Check if the update that was sent matches the current state
- # as we only need to send an update if they are different.
- self._need_to_send_sync = False
- yield self._send_syncing_users_now()
+ return defer.succeed(_user_syncing())
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
- parties = yield self._get_interested_parties(
- states, calculate_remote_hosts=False
- )
- room_ids_to_states, users_to_states, _ = parties
+ parties = yield get_interested_parties(self.store, states)
+ room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
@@ -223,26 +198,24 @@ class SynchrotronPresence(object):
)
@defer.inlineCallbacks
- def process_replication(self, result):
- stream = result.get("presence", {"rows": []})
- states = []
- for row in stream["rows"]:
- (
- position, user_id, state, last_active_ts,
- last_federation_update_ts, last_user_sync_ts, status_msg,
- currently_active
- ) = row
- state = UserPresenceState(
- user_id, state, last_active_ts,
- last_federation_update_ts, last_user_sync_ts, status_msg,
- currently_active
- )
- self.user_to_current_state[user_id] = state
- states.append(state)
+ def process_replication_rows(self, token, rows):
+ states = [UserPresenceState(
+ row.user_id, row.state, row.last_active_ts,
+ row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
+ row.currently_active
+ ) for row in rows]
+
+ for state in states:
+ self.user_to_current_state[state.user_id] = state
+
+ stream_id = token
+ yield self.notify_from_replication(states, stream_id)
- if states and "position" in stream:
- stream_id = int(stream["position"])
- yield self.notify_from_replication(states, stream_id)
+ def get_currently_syncing_users(self):
+ return [
+ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+ if count > 0
+ ]
class SynchrotronTyping(object):
@@ -257,16 +230,12 @@ class SynchrotronTyping(object):
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
- def process_replication(self, result):
- stream = result.get("typing")
- if stream:
- self._latest_room_serial = int(stream["position"])
+ def process_replication_rows(self, token, rows):
+ self._latest_room_serial = token
- for row in stream["rows"]:
- position, room_id, typing_json = row
- typing = json.loads(typing_json)
- self._room_serials[room_id] = position
- self._room_typing[room_id] = typing
+ for row in rows:
+ self._room_serials[row.room_id] = token
+ self._room_typing[row.room_id] = row.user_ids
class SynchrotronApplicationService(object):
@@ -275,19 +244,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@@ -315,19 +271,18 @@ class SynchrotronServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
)
+ )
logger.info("Synapse synchrotron now listening on port %d", port)
@@ -336,135 +291,106 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
- notifier = self.get_notifier()
- presence_handler = self.get_presence_handler()
- typing_handler = self.get_typing_handler()
-
- def notify_from_stream(
- result, stream_name, stream_key, room=None, user=None
- ):
- stream = result.get(stream_name)
- if stream:
- position_index = stream["field_names"].index("position")
- if room:
- room_index = stream["field_names"].index(room)
- if user:
- user_index = stream["field_names"].index(user)
-
- users = ()
- rooms = ()
- for row in stream["rows"]:
- position = row[position_index]
-
- if user:
- users = (row[user_index],)
-
- if room:
- rooms = (row[room_index],)
-
- notifier.on_new_event(
- stream_key, position, users=users, rooms=rooms
- )
+ self.get_tcp_replication().start_replication(self)
- @defer.inlineCallbacks
- def notify_device_list_update(result):
- stream = result.get("device_lists")
- if not stream:
- return
+ def build_tcp_replication(self):
+ return SyncReplicationHandler(self)
- position_index = stream["field_names"].index("position")
- user_index = stream["field_names"].index("user_id")
+ def build_presence_handler(self):
+ return SynchrotronPresence(self)
- for row in stream["rows"]:
- position = row[position_index]
- user_id = row[user_index]
+ def build_typing_handler(self):
+ return SynchrotronTyping(self)
- rooms = yield store.get_rooms_for_user(user_id)
- room_ids = [r.room_id for r in rooms]
- notifier.on_new_event(
- "device_list_key", position, rooms=room_ids,
- )
+class SyncReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(SyncReplicationHandler, self).__init__(hs.get_datastore())
- @defer.inlineCallbacks
- def notify(result):
- stream = result.get("events")
- if stream:
- max_position = stream["position"]
- for row in stream["rows"]:
- position = row[0]
- internal = json.loads(row[1])
- event_json = json.loads(row[2])
- event = FrozenEvent(event_json, internal_metadata_dict=internal)
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- notifier.on_new_room_event(
- event, position, max_position, extra_users
- )
+ self.store = hs.get_datastore()
+ self.typing_handler = hs.get_typing_handler()
+ # NB this is a SynchrotronPresence, not a normal PresenceHandler
+ self.presence_handler = hs.get_presence_handler()
+ self.notifier = hs.get_notifier()
- notify_from_stream(
- result, "push_rules", "push_rules_key", user="user_id"
- )
- notify_from_stream(
- result, "user_account_data", "account_data_key", user="user_id"
- )
- notify_from_stream(
- result, "room_account_data", "account_data_key", user="user_id"
- )
- notify_from_stream(
- result, "tag_account_data", "account_data_key", user="user_id"
- )
- notify_from_stream(
- result, "receipts", "receipt_key", room="room_id"
- )
- notify_from_stream(
- result, "typing", "typing_key", room="room_id"
- )
- notify_from_stream(
- result, "to_device", "to_device_key", user="user_id"
- )
- yield notify_device_list_update(result)
+ def on_rdata(self, stream_name, token, rows):
+ super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+ run_in_background(self.process_and_notify, stream_name, token, rows)
- while True:
- try:
- args = store.stream_positions()
- args.update(typing_handler.stream_positions())
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- typing_handler.process_replication(result)
- yield presence_handler.process_replication(result)
- yield notify(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(5)
+ def get_streams_to_replicate(self):
+ args = super(SyncReplicationHandler, self).get_streams_to_replicate()
+ args.update(self.typing_handler.stream_positions())
+ return args
- def build_presence_handler(self):
- return SynchrotronPresence(self)
+ def get_currently_syncing_users(self):
+ return self.presence_handler.get_currently_syncing_users()
- def build_typing_handler(self):
- return SynchrotronTyping(self)
+ @defer.inlineCallbacks
+ def process_and_notify(self, stream_name, token, rows):
+ try:
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ event = yield self.store.get_event(row.event_id)
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows],
+ )
+ elif stream_name in ("account_data", "tag_account_data",):
+ self.notifier.on_new_event(
+ "account_data_key", token, users=[row.user_id for row in rows],
+ )
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event(
+ "to_device_key", token, users=entities,
+ )
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = yield self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event(
+ "device_list_key", token, rooms=all_room_ids,
+ )
+ elif stream_name == "presence":
+ yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "groups":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows],
+ )
+ except Exception:
+ logger.exception("Error processing replication")
def start(config_options):
@@ -478,7 +404,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.synchrotron"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -496,33 +422,13 @@ def start(config_options):
ss.setup()
ss.start_listening(config.worker_listeners)
- def run():
- with LoggingContext("run"):
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_datastore().start_profiling()
- ss.replicate()
ss.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-synchrotron",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-synchrotron", config)
if __name__ == '__main__':
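
The presence changes above replace the periodic bulk POST of syncing users with per-user notifications over the TCP replication connection, debounced so that a user who stops and quickly resumes syncing does not produce a stop/start pair. A minimal sketch of that debounce pattern, detached from Synapse's clock and replication classes (the names here are stand-ins, not Synapse APIs):

import time

DEBOUNCE_MS = 10 * 1000


class SyncDebouncer(object):
    """Tracks how many syncs each user has in flight and reports transitions,
    holding back "stopped syncing" for DEBOUNCE_MS in case the user starts
    syncing again straight away."""

    def __init__(self, send_user_sync):
        # send_user_sync(user_id, is_syncing, last_sync_ms) notifies the master
        self.send_user_sync = send_user_sync
        self.user_to_num_current_syncs = {}
        self.users_going_offline = {}

    def _now_ms(self):
        return int(time.time() * 1000)

    def sync_started(self, user_id):
        count = self.user_to_num_current_syncs.get(user_id, 0) + 1
        self.user_to_num_current_syncs[user_id] = count
        if count == 1:
            # Only poke the master if we hadn't held back an "offline"
            # notification for this user (popping it is enough otherwise)
            if self.users_going_offline.pop(user_id, None) is None:
                self.send_user_sync(user_id, True, self._now_ms())

    def sync_ended(self, user_id):
        count = self.user_to_num_current_syncs.get(user_id, 0) - 1
        self.user_to_num_current_syncs[user_id] = max(count, 0)
        if count <= 0:
            # Don't tell the master yet; wait to see if they come back
            self.users_going_offline[user_id] = self._now_ms()

    def flush_stopped(self):
        # Run this periodically, e.g. every 10 seconds from a looping call
        now = self._now_ms()
        for user_id, last_sync_ms in list(self.users_going_offline.items()):
            if now - last_sync_ms > DEBOUNCE_MS:
                self.users_going_offline.pop(user_id, None)
                self.send_user_sync(user_id, False, last_sync_ms)


sent = []
debouncer = SyncDebouncer(lambda user_id, is_syncing, ts: sent.append((user_id, is_syncing)))
debouncer.sync_started("@alice:example.com")
debouncer.sync_ended("@alice:example.com")
debouncer.sync_started("@alice:example.com")  # back within the window
debouncer.flush_stopped()
print(sent)  # [('@alice:example.com', True)] - no stop/start churn
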
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index c045588866..712dfa870e 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -23,14 +23,27 @@ import signal
import subprocess
import sys
import yaml
+import errno
+import time
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m"
+def pid_running(pid):
+ try:
+ os.kill(pid, 0)
+ return True
+ except OSError as err:
+ if err.errno == errno.EPERM:
+ return True
+ return False
+
+
def write(message, colour=NORMAL, stream=sys.stdout):
if colour == NORMAL:
stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n")
+def abort(message, colour=RED, stream=sys.stderr):
+ write(message, colour, stream)
+ sys.exit(1)
+
+
def start(configfile):
write("Starting ...")
args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
try:
subprocess.check_call(args)
- write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+ write("started synapse.app.homeserver(%r)" %
+ (configfile,), colour=GREEN)
except subprocess.CalledProcessError as e:
write(
"error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
def stop(pidfile, app):
if os.path.exists(pidfile):
pid = int(open(pidfile).read())
- os.kill(pid, signal.SIGTERM)
- write("stopped %s" % (app,), colour=GREEN)
+ try:
+ os.kill(pid, signal.SIGTERM)
+ write("stopped %s" % (app,), colour=GREEN)
+ except OSError as err:
+ if err.errno == errno.ESRCH:
+ write("%s not running" % (app,), colour=YELLOW)
+ elif err.errno == errno.EPERM:
+ abort("Cannot stop %s: Operation not permitted" % (app,))
+ else:
+ abort("Cannot stop %s: Unknown error" % (app,))
Worker = collections.namedtuple("Worker", [
@@ -98,7 +125,7 @@ def main():
"configfile",
nargs="?",
default="homeserver.yaml",
- help="the homeserver config file, defaults to homserver.yaml",
+ help="the homeserver config file, defaults to homeserver.yaml",
)
parser.add_argument(
"-w", "--worker",
@@ -157,6 +184,9 @@ def main():
worker_configfiles.append(worker_configfile)
if options.all_processes:
+ # To start the main synapse with -a you need to add a worker file
+ # with worker_app == "synapse.app.homeserver"
+ start_stop_synapse = False
worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir):
write(
@@ -173,10 +203,29 @@ def main():
with open(worker_configfile) as stream:
worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"]
- worker_pidfile = worker_config["worker_pid_file"]
- worker_daemonize = worker_config["worker_daemonize"]
- assert worker_daemonize # TODO print something more user friendly
- worker_cache_factor = worker_config.get("synctl_cache_factor")
+ if worker_app == "synapse.app.homeserver":
+ # We need to special case all of this to pick up options that may
+ # be set in the main config file or in this worker config file.
+ worker_pidfile = (
+ worker_config.get("pid_file")
+ or pidfile
+ )
+ worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
+ daemonize = worker_config.get("daemonize") or config.get("daemonize")
+ assert daemonize, "Main process must have daemonize set to true"
+
+ # The master process doesn't support using worker_* config.
+ for key in worker_config:
+ if key == "worker_app": # But we allow worker_app
+ continue
+ assert not key.startswith("worker_"), \
+ "Main process cannot use worker_* config"
+ else:
+ worker_pidfile = worker_config["worker_pid_file"]
+ worker_daemonize = worker_config["worker_daemonize"]
+ assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+ worker_configfile, "worker_daemonize")
+ worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
))
@@ -190,10 +239,26 @@ def main():
if start_stop_synapse:
stop(pidfile, "synapse.app.homeserver")
- # TODO: Wait for synapse to actually shutdown before starting it again
+ # Wait for synapse to actually shutdown before starting it again
+ if action == "restart":
+ running_pids = []
+ if start_stop_synapse and os.path.exists(pidfile):
+ running_pids.append(int(open(pidfile).read()))
+ for worker in workers:
+ if os.path.exists(worker.pidfile):
+ running_pids.append(int(open(worker.pidfile).read()))
+ if len(running_pids) > 0:
+ write("Waiting for process to exit before restarting...")
+ for running_pid in running_pids:
+ while pid_running(running_pid):
+ time.sleep(0.2)
+ write("All processes exited; now restarting...")
if action == "start" or action == "restart":
if start_stop_synapse:
+ # Check if synapse is already running
+ if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
+ abort("synapse.app.homeserver already running")
start(configfile)
for worker in workers:
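
synctl's restart handling now probes processes with signal 0 before starting anything, so a stale pidfile or a slow shutdown no longer races the new start. A self-contained sketch of the same stop-and-wait idiom (the pidfile path and timeout are illustrative):

import errno
import os
import signal
import time


def pid_running(pid):
    """Return True if a process with this pid exists (EPERM counts: the
    process is there, we just aren't allowed to signal it)."""
    try:
        os.kill(pid, 0)  # signal 0 only performs error checking
        return True
    except OSError as err:
        return err.errno == errno.EPERM


def stop_and_wait(pidfile, timeout=30.0):
    """Send SIGTERM to the pid recorded in `pidfile` and wait for it to exit."""
    if not os.path.exists(pidfile):
        return True
    pid = int(open(pidfile).read())
    try:
        os.kill(pid, signal.SIGTERM)
    except OSError as err:
        if err.errno == errno.ESRCH:
            return True  # already gone
        raise
    deadline = time.time() + timeout
    while pid_running(pid):
        if time.time() > deadline:
            return False
        time.sleep(0.2)
    return True


# e.g. stop_and_wait("/var/run/matrix-synapse.pid") before starting again
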
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
new file mode 100644
index 0000000000..5ba7e9b416
--- /dev/null
+++ b/synapse/app/user_dir.py
@@ -0,0 +1,231 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha import user_directory
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.storage.user_directory import UserDirectoryStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext, run_in_background
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import reactor, defer
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.user_dir")
+
+
+class UserDirectorySlaveStore(
+ SlavedEventStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedClientIpStore,
+ UserDirectoryStore,
+ BaseSlavedStore,
+):
+ def __init__(self, db_conn, hs):
+ super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
+ self._current_state_delta_pos = events_max
+
+ def stream_positions(self):
+ result = super(UserDirectorySlaveStore, self).stream_positions()
+ result["current_state_deltas"] = self._current_state_delta_pos
+ return result
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "current_state_deltas":
+ self._current_state_delta_pos = token
+ for row in rows:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.room_id, token
+ )
+ return super(UserDirectorySlaveStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
+
+
+class UserDirectoryServer(HomeServer):
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ user_directory.register_servlets(self, resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ )
+ )
+
+ logger.info("Synapse user_dir now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return UserDirectoryReplicationHandler(self)
+
+
+class UserDirectoryReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
+ self.user_directory = hs.get_user_directory_handler()
+
+ def on_rdata(self, stream_name, token, rows):
+ super(UserDirectoryReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ if stream_name == "current_state_deltas":
+ run_in_background(self._notify_directory)
+
+ @defer.inlineCallbacks
+ def _notify_directory(self):
+ try:
+ yield self.user_directory.notify_new_event()
+ except Exception:
+ logger.exception("Error notifying user directory of state update")
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse user directory", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.user_dir"
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ if config.update_user_directory:
+ sys.stderr.write(
+ "\nThe update_user_directory option must be disabled in the main synapse"
+ "\nprocess before user directory updates can be run in a separate worker."
+ "\nPlease add ``update_user_directory: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the user directory updater to start, as it will be disabled in the main config
+ config.update_user_directory = True
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ps = UserDirectoryServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ps.setup()
+ ps.start_listening(config.worker_listeners)
+
+ def start():
+ ps.get_datastore().start_profiling()
+ ps.get_state_handler().start_caching()
+
+ reactor.callWhenRunning(start)
+
+ _base.start_worker_reactor("synapse-user-dir", config)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
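
The user_dir worker prefills a StreamChangeCache for current_state_deltas so the directory handler can cheaply ask whether anything in a room changed since a given stream position. A simplified stand-in for that idea, written from scratch rather than using Synapse's StreamChangeCache (treat the exact semantics as an assumption):

class SimpleStreamChangeCache(object):
    """Remembers the last stream position at which each entity changed, so
    callers can skip work for entities untouched since a given position.
    Positions older than `earliest_known_position` are unknowable, so we
    conservatively answer "maybe changed"."""

    def __init__(self, earliest_known_position):
        self.earliest_known = earliest_known_position
        self.entity_to_last_change = {}

    def entity_has_changed(self, entity, position):
        # Fed from replication rows, cf. process_replication_rows above
        if position > self.entity_to_last_change.get(entity, 0):
            self.entity_to_last_change[entity] = position

    def has_entity_changed(self, entity, since_position):
        if since_position < self.earliest_known:
            return True  # can't prove it didn't change before we started
        return self.entity_to_last_change.get(entity, 0) > since_position


cache = SimpleStreamChangeCache(earliest_known_position=100)
cache.entity_has_changed("!room:example.com", 105)
print(cache.has_entity_changed("!room:example.com", 102))   # True
print(cache.has_entity_changed("!other:example.com", 102))  # False
print(cache.has_entity_changed("!other:example.com", 50))   # True - too old to know
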
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..5fdb579723 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import GroupID, get_domain_from_id
from twisted.internet import defer
import logging
import re
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -80,12 +84,13 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
- def __init__(self, token, url=None, namespaces=None, hs_token=None,
+ def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
+ self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
@@ -124,29 +129,41 @@ class ApplicationService(object):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
- if not isinstance(regex_obj.get("regex"), basestring):
+ group_id = regex_obj.get("group_id")
+ if group_id:
+ if not isinstance(group_id, str):
+ raise ValueError(
+ "Expected string for 'group_id' in ns '%s'" % ns
+ )
+ try:
+ GroupID.from_string(group_id)
+ except Exception:
+ raise ValueError(
+ "Expected valid group ID for 'group_id' in ns '%s'" % ns
+ )
+
+ if get_domain_from_id(group_id) != self.server_name:
+ raise ValueError(
+ "Expected 'group_id' to be this host in ns '%s'" % ns
+ )
+
+ regex = regex_obj.get("regex")
+ if isinstance(regex, string_types):
+ regex_obj["regex"] = re.compile(regex) # Pre-compile regex
+ else:
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
return namespaces
- def _matches_regex(self, test_string, namespace_key, return_obj=False):
- if not isinstance(test_string, basestring):
- logger.error(
- "Expected a string to test regex against, but got %s",
- test_string
- )
- return False
-
+ def _matches_regex(self, test_string, namespace_key):
for regex_obj in self.namespaces[namespace_key]:
- if re.match(regex_obj["regex"], test_string):
- if return_obj:
- return regex_obj
- return True
- return False
+ if regex_obj["regex"].match(test_string):
+ return regex_obj
+ return None
def _is_exclusive(self, ns_key, test_string):
- regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+ regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
return False
@@ -166,7 +183,14 @@ class ApplicationService(object):
if not store:
defer.returnValue(False)
- member_list = yield store.get_users_in_room(event.room_id)
+ does_match = yield self._matches_user_in_member_list(event.room_id, store)
+ defer.returnValue(does_match)
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def _matches_user_in_member_list(self, room_id, store, cache_context):
+ member_list = yield store.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
# check joined member events
for user_id in member_list:
@@ -219,10 +243,10 @@ class ApplicationService(object):
)
def is_interested_in_alias(self, alias):
- return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+ return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
def is_interested_in_room(self, room_id):
- return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+ return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
def is_exclusive_user(self, user_id):
return (
@@ -239,6 +263,31 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
+ def get_exlusive_user_regexes(self):
+ """Get the list of regexes used to determine if a user is exclusively
+ registered by the AS
+ """
+ return [
+ regex_obj["regex"]
+ for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+ if regex_obj["exclusive"]
+ ]
+
+ def get_groups_for_user(self, user_id):
+ """Get the groups that this user is associated with by this AS
+
+ Args:
+ user_id (str): The ID of the user.
+
+ Returns:
+ iterable[str]: an iterable that yields group_id strings.
+ """
+ return (
+ regex_obj["group_id"]
+ for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+ if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+ )
+
def is_rate_limited(self):
return self.rate_limited
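
The namespace handling above pre-compiles each 'regex' at load time, and group_id entries are now handed out for users matched by those regexes. A short sketch of the resulting matching rules (the namespace entries below are made-up examples):

import re

# A made-up "users" namespace, as it looks after _check_namespaces has
# pre-compiled the regex strings and validated any group_id entries.
user_namespaces = [
    {"regex": re.compile(r"@irc_.*:example\.com"), "exclusive": True,
     "group_id": "+irc:example.com"},
    {"regex": re.compile(r"@gitter_.*:example\.com"), "exclusive": False},
]


def matches(test_string):
    """Return the first matching regex_obj, or None (cf. _matches_regex)."""
    for regex_obj in user_namespaces:
        if regex_obj["regex"].match(test_string):
            return regex_obj
    return None


def groups_for_user(user_id):
    """Yield group_ids from entries that both declare one and match."""
    return (
        obj["group_id"] for obj in user_namespaces
        if "group_id" in obj and obj["regex"].match(user_id)
    )


print(bool(matches("@irc_alice:example.com")))            # True
print(matches("@irc_alice:example.com")["exclusive"])     # True
print(list(groups_for_user("@irc_alice:example.com")))    # ['+irc:example.com']
print(list(groups_for_user("@gitter_bob:example.com")))   # [] - no group_id declared
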
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6893610e71..00efff1464 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -72,7 +72,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
+ self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
+ timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks
def query_user(self, service, user_id):
@@ -192,9 +193,7 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(None)
key = (service.id, protocol)
- return self.protocol_meta_cache.get(key) or (
- self.protocol_meta_cache.set(key, _get())
- )
+ return self.protocol_meta_cache.wrap(key, _get)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
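
ResponseCache.wrap collapses the old get()-or-set() dance into a single call that deduplicates concurrent lookups. A rough stand-in for that idiom using plain Deferreds (this is not Synapse's ResponseCache and omits its timeout handling):

from twisted.internet import defer
from twisted.python import failure


class MiniResponseCache(object):
    """Deduplicate concurrent requests: the first caller for a key runs
    `callback`; callers arriving while it is still in flight share the
    result instead of issuing the request again."""

    def __init__(self):
        self.in_flight = {}  # key -> list of Deferreds waiting on the result

    def wrap(self, key, callback, *args, **kwargs):
        waiter = defer.Deferred()

        if key in self.in_flight:
            # Someone is already computing this: just queue up for the result
            self.in_flight[key].append(waiter)
            return waiter

        self.in_flight[key] = [waiter]

        def _broadcast(result):
            # Fan the (possibly failed) result out to everyone who asked
            for w in self.in_flight.pop(key, []):
                if isinstance(result, failure.Failure):
                    w.errback(result)
                else:
                    w.callback(result)

        defer.maybeDeferred(callback, *args, **kwargs).addBoth(_broadcast)
        return waiter


cache = MiniResponseCache()
calls = []
pending = defer.Deferred()


def fetch_protocol_meta(protocol):
    calls.append(protocol)
    return pending


d1 = cache.wrap(("as1", "irc"), fetch_protocol_meta, "irc")
d2 = cache.wrap(("as1", "irc"), fetch_protocol_meta, "irc")
pending.callback({"protocol": "irc"})
print(calls)  # ['irc'] - the callback only ran once
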
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 68a9de17b8..6eddbc0828 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -51,7 +51,7 @@ components.
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
import logging
@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
- preserve_fn(self._send_request)(service)
+ run_in_background(self._send_request, service)
@defer.inlineCallbacks
def _send_request(self, service):
@@ -123,7 +123,7 @@ class _ServiceQueuer(object):
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
- except:
+ except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
@@ -152,10 +152,10 @@ class _TransactionController(object):
if sent:
yield txn.complete(self.store)
else:
- preserve_fn(self._start_recoverer)(service)
- except Exception as e:
- logger.exception(e)
- preserve_fn(self._start_recoverer)(service)
+ run_in_background(self._start_recoverer, service)
+ except Exception:
+ logger.exception("Error creating appservice transaction")
+ run_in_background(self._start_recoverer, service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
@@ -176,17 +176,20 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- logger.info(
- "Application service falling behind. Starting recoverer. AS ID %s",
- service.id
- )
- recoverer = self.recoverer_fn(service, self.on_recovered)
- self.add_recoverers([recoverer])
- recoverer.recover()
+ try:
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ logger.info(
+ "Application service falling behind. Starting recoverer. AS ID %s",
+ service.id
+ )
+ recoverer = self.recoverer_fn(service, self.on_recovered)
+ self.add_recoverers([recoverer])
+ recoverer.recover()
+ except Exception:
+ logger.exception("Error starting AS recoverer")
@defer.inlineCallbacks
def _is_service_up(self, service):
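
The preserve_fn to run_in_background changes in this scheduler fire work off without waiting for it while still logging failures. A minimal stand-in for that helper, ignoring the logcontext bookkeeping that the real run_in_background exists to do:

import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


def run_in_background_sketch(f, *args, **kwargs):
    """Kick off `f` without waiting for it, but make sure an eventual
    failure is logged rather than vanishing as an unhandled Deferred."""
    d = defer.maybeDeferred(f, *args, **kwargs)

    def _log_failure(fail):
        logger.error("Background task %s failed: %s",
                     getattr(f, "__name__", f), fail)

    d.addErrback(_log_failure)
    return d


# e.g. enqueue() kicks off a send without blocking the caller:
#     run_in_background_sketch(self._send_request, service)
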
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1ab5593c6e..b748ed2b0a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,6 +19,8 @@ import os
import yaml
from textwrap import dedent
+from six import integer_types
+
class ConfigError(Exception):
pass
@@ -49,7 +51,7 @@ Missing mandatory `server_name` config option.
class Config(object):
@staticmethod
def parse_size(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
@@ -61,7 +63,7 @@ class Config(object):
@staticmethod
def parse_duration(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
second = 1000
minute = 60 * second
@@ -82,21 +84,37 @@ class Config(object):
return os.path.abspath(file_path) if file_path else file_path
@classmethod
+ def path_exists(cls, file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+ @classmethod
def check_file(cls, file_path, config_name):
if file_path is None:
raise ConfigError(
"Missing config for %s."
- " You must specify a path for the config file. You can "
- "do this with the -c or --config-path option. "
- "Adding --generate-config along with --server-name "
- "<server name> will generate a config file at the given path."
% (config_name,)
)
- if not os.path.exists(file_path):
+ try:
+ os.stat(file_path)
+ except OSError as e:
raise ConfigError(
- "File %s config for %s doesn't exist."
- " Try running again with --generate-config"
- % (file_path, config_name,)
+ "Error accessing file '%s' (config for %s): %s"
+ % (file_path, config_name, e.strerror)
)
return cls.abspath(file_path)
@@ -248,7 +266,7 @@ class Config(object):
" -c CONFIG-FILE\""
)
(config_path,) = config_files
- if not os.path.exists(config_path):
+ if not cls.path_exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
@@ -261,33 +279,33 @@ class Config(object):
"Must specify a server_name to a generate config for."
" Pass -H server.name."
)
- if not os.path.exists(config_dir_path):
+ if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
- with open(config_path, "wb") as config_file:
- config_bytes, config = obj.generate_config(
+ with open(config_path, "w") as config_file:
+ config_str, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
)
obj.invoke_all("generate_files", config)
- config_file.write(config_bytes)
- print (
+ config_file.write(config_str)
+ print((
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
- ) % (config_path, server_name)
- print (
+ ) % (config_path, server_name))
+ print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
- print (
+ print((
"Config file %r already exists. Generating any missing key"
" files."
- ) % (config_path,)
+ ) % (config_path,))
generate_keys = True
parser = argparse.ArgumentParser(
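
The integer_types change keeps Config.parse_size and Config.parse_duration behaving the same on Python 2 and 3. A standalone sketch of those parsing rules, mirroring the code shown in the hunk (the duration suffix table beyond 's' and 'm' is an assumption):

def parse_size(value):
    """'10M' -> 10485760 bytes; bare integers pass through unchanged."""
    if isinstance(value, int):
        return value
    sizes = {"K": 1024, "M": 1024 * 1024}
    size = 1
    suffix = value[-1]
    if suffix in sizes:
        value = value[:-1]
        size = sizes[suffix]
    return int(value) * size


def parse_duration(value):
    """'1h' -> 3600000 milliseconds; bare integers pass through unchanged."""
    if isinstance(value, int):
        return value
    second = 1000
    units = {
        "s": second,
        "m": 60 * second,
        "h": 60 * 60 * second,
        "d": 24 * 60 * 60 * second,
        "w": 7 * 24 * 60 * 60 * second,
        "y": 365 * 24 * 60 * 60 * second,
    }
    size = 1
    suffix = value[-1]
    if suffix in units:
        value = value[:-1]
        size = units[suffix]
    return int(value) * size


print(parse_size("10M"))     # 10485760
print(parse_duration("1h"))  # 3600000
print(parse_duration(500))   # 500
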
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 82c50b8240..277305e184 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -17,10 +17,12 @@ from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import urllib
import yaml
import logging
+from six import string_types
+from six.moves.urllib import parse as urlparse
+
logger = logging.getLogger(__name__)
@@ -89,21 +91,21 @@ def _load_appservice(hostname, as_info, config_filename):
"id", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
- if not isinstance(as_info.get(field), basestring):
+ if not isinstance(as_info.get(field), string_types):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (not isinstance(as_info.get("url"), basestring) and
+ if (not isinstance(as_info.get("url"), string_types) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
- if urllib.quote(localpart) != localpart:
+ if urlparse.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
@@ -128,7 +130,7 @@ def _load_appservice(hostname, as_info, config_filename):
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
- if not isinstance(regex_obj.get("regex"), basestring):
+ if not isinstance(regex_obj.get("regex"), string_types):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
@@ -154,6 +156,7 @@ def _load_appservice(hostname, as_info, config_filename):
)
return ApplicationService(
token=as_info["as_token"],
+ hostname=hostname,
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 938f6f25f8..8109e5f95e 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -41,7 +41,7 @@ class CasConfig(Config):
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
- # service_url: "https://homesever.domain.com:8448"
+ # service_url: "https://homeserver.domain.com:8448"
# #required_attributes:
# # name: value
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 0030b5db1e..fe156b6930 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -71,6 +71,15 @@ class EmailConfig(Config):
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
+ self.email_smtp_user = email_config.get(
+ "smtp_user", None
+ )
+ self.email_smtp_pass = email_config.get(
+ "smtp_pass", None
+ )
+ self.require_transport_security = email_config.get(
+ "require_transport_security", False
+ )
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
@@ -91,10 +100,17 @@ class EmailConfig(Config):
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
+ #
+ # If your SMTP server requires authentication, the optional smtp_user &
+ # smtp_pass variables should be used
+ #
#email:
# enable_notifs: false
# smtp_host: "localhost"
# smtp_port: 25
+ # smtp_user: "exampleusername"
+ # smtp_pass: "examplepassword"
+ # require_transport_security: False
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
# app_name: Matrix
# template_dir: res/templates
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
new file mode 100644
index 0000000000..997fa2881f
--- /dev/null
+++ b/synapse/config/groups.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class GroupsConfig(Config):
+ def read_config(self, config):
+ self.enable_group_creation = config.get("enable_group_creation", False)
+ self.group_creation_prefix = config.get("group_creation_prefix", "")
+
+ def default_config(self, **kwargs):
+ return """\
+ # Whether to allow non server admins to create groups on this server
+ # Whether to allow non-server admins to create groups on this server
+ enable_group_creation: false
+
+ # If enabled, non-server admins can only create groups with local parts
+ # group_creation_prefix: "unofficial/"
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 0f890fc04a..bf19cfee29 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -33,6 +33,10 @@ from .jwt import JWTConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig
from .workers import WorkerConfig
+from .push import PushConfig
+from .spam_checker import SpamCheckerConfig
+from .groups import GroupsConfig
+from .user_directory import UserDirectoryConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -40,7 +44,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig,):
+ WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+ SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
pass
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 6ee643793e..4b8fc063d0 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -118,10 +118,9 @@ class KeyConfig(Config):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return read_signing_keys(signing_keys.splitlines(True))
- except Exception:
+ except Exception as e:
raise ConfigError(
- "Error reading signing_key."
- " Try running again with --generate-config"
+ "Error reading signing_key: %s" % (str(e))
)
def read_old_signing_keys(self, old_signing_keys):
@@ -141,7 +140,8 @@ class KeyConfig(Config):
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
- if not os.path.exists(signing_key_path):
+
+ if not self.path_exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 77ded0ad25..6a7228dc2f 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -28,34 +28,35 @@ DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
- precise:
- format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
-- %(message)s'
+ precise:
+ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
+%(request)s - %(message)s'
filters:
- context:
- (): synapse.util.logcontext.LoggingContextFilter
- request: ""
+ context:
+ (): synapse.util.logcontext.LoggingContextFilter
+ request: ""
handlers:
- file:
- class: logging.handlers.RotatingFileHandler
- formatter: precise
- filename: ${log_file}
- maxBytes: 104857600
- backupCount: 10
- filters: [context]
- level: INFO
- console:
- class: logging.StreamHandler
- formatter: precise
- filters: [context]
+ file:
+ class: logging.handlers.RotatingFileHandler
+ formatter: precise
+ filename: ${log_file}
+ maxBytes: 104857600
+ backupCount: 10
+ filters: [context]
+ console:
+ class: logging.StreamHandler
+ formatter: precise
+ filters: [context]
loggers:
synapse:
level: INFO
synapse.storage.SQL:
+ # beware: increasing this to DEBUG will make synapse log sensitive
+ # information such as access tokens.
level: INFO
root:
@@ -68,21 +69,15 @@ class LoggingConfig(Config):
def read_config(self, config):
self.verbosity = config.get("verbose", 0)
+ self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
- log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
- # Logging verbosity level.
- verbose: 0
-
- # File to write logging to
- log_file: "%(log_file)s"
-
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
@@ -90,6 +85,8 @@ class LoggingConfig(Config):
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
+ if args.no_redirect_stdio is not None:
+ self.no_redirect_stdio = args.no_redirect_stdio
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
@@ -99,48 +96,68 @@ class LoggingConfig(Config):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
- help="The verbosity level."
+ help="The verbosity level. Specify multiple times to increase "
+ "verbosity. (Ignored if --log-config is specified.)"
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
- help="File to log to."
+ help="File to log to. (Ignored if --log-config is specified.)"
)
logging_group.add_argument(
'--log-config', dest="log_config", default=None,
help="Python logging config file"
)
+ logging_group.add_argument(
+ '-n', '--no-redirect-stdio',
+ action='store_true', default=None,
+ help="Do not redirect stdout/stderr to the log"
+ )
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
- with open(log_config, "wb") as log_config_file:
+ log_file = self.abspath("homeserver.log")
+ with open(log_config, "w") as log_config_file:
log_config_file.write(
- DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+ DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
- def setup_logging(self):
- setup_logging(self.log_config, self.log_file, self.verbosity)
+def setup_logging(config, use_worker_options=False):
+ """ Set up python logging
+
+ Args:
+ config (LoggingConfig | synapse.config.workers.WorkerConfig):
+ configuration data
+
+ use_worker_options (bool): True to use 'worker_log_config' and
+ 'worker_log_file' options instead of 'log_config' and 'log_file'.
+ """
+ log_config = (config.worker_log_config if use_worker_options
+ else config.log_config)
+ log_file = (config.worker_log_file if use_worker_options
+ else config.log_file)
-def setup_logging(log_config=None, log_file=None, verbosity=None):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
)
- if log_config is None:
+ if log_config is None:
+ # We don't have a log config, so fall back to the 'verbosity' param from
+ # the config or cmdline. (Note that we generate a log config for new
+ # installs, so this will be an unusual case)
level = logging.INFO
level_for_storage = logging.INFO
- if verbosity:
+ if config.verbosity:
level = logging.DEBUG
- if verbosity > 1:
+ if config.verbosity > 1:
level_for_storage = logging.DEBUG
- # FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('')
logger.setLevel(level)
- logging.getLogger('synapse.storage').setLevel(level_for_storage)
+ logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:
@@ -153,24 +170,37 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
-
- # TODO(paul): obviously this is a terrible mechanism for
- # stealing SIGHUP, because it means no other part of synapse
- # can use it instead. If we want to catch SIGHUP anywhere
- # else as well, I'd suggest we find a nicer way to broadcast
- # it around.
- if getattr(signal, "SIGHUP"):
- signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
+
+ def sighup(signum, stack):
+ pass
+
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
else:
- with open(log_config, 'r') as f:
- logging.config.dictConfig(yaml.load(f))
+ def load_log_config():
+ with open(log_config, 'r') as f:
+ logging.config.dictConfig(yaml.load(f))
+
+ def sighup(signum, stack):
+ # it might be better to use a file watcher or something for this.
+ logging.info("Reloading log config from %s due to SIGHUP",
+ log_config)
+ load_log_config()
+
+ load_log_config()
+
+ # TODO(paul): obviously this is a terrible mechanism for
+ # stealing SIGHUP, because it means no other part of synapse
+ # can use it instead. If we want to catch SIGHUP anywhere
+ # else as well, I'd suggest we find a nicer way to broadcast
+ # it around.
+ if getattr(signal, "SIGHUP"):
+ signal.signal(signal.SIGHUP, sighup)
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks up to 64K objects;
@@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
#
# However this may not be too much of a problem if we are just writing to a file.
observer = STDLibLogObserver()
- globalLogBeginner.beginLoggingTo([observer])
+ globalLogBeginner.beginLoggingTo(
+ [observer],
+ redirectStandardIO=not config.no_redirect_stdio,
+ )
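
setup_logging now reloads the YAML logging config on SIGHUP rather than merely rotating the log file. A cut-down illustration of that reload-on-SIGHUP pattern using only yaml and the standard library (the config path is an example):

import logging
import logging.config
import signal

import yaml

LOG_CONFIG_PATH = "/etc/matrix-synapse/homeserver.log.config"  # example path


def load_log_config():
    with open(LOG_CONFIG_PATH) as f:
        logging.config.dictConfig(yaml.safe_load(f))


def sighup(signum, stack):
    # A file watcher might be nicer, but SIGHUP is the traditional hook
    logging.info("Reloading log config from %s due to SIGHUP", LOG_CONFIG_PATH)
    load_log_config()


load_log_config()
if getattr(signal, "SIGHUP", None):
    signal.signal(signal.SIGHUP, sighup)
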
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83762d089a..6602c5b4c7 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -13,44 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config
-import importlib
+from synapse.util.module_loader import load_module
+
+LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
+ providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
- self.ldap_enabled = ldap_config.get("enabled", False)
- if self.ldap_enabled:
- from ldap_auth_provider import LdapAuthProvider
- parsed_config = LdapAuthProvider.parse_config(ldap_config)
- self.password_providers.append((LdapAuthProvider, parsed_config))
+ if ldap_config.get("enabled", False):
+ providers.append({
+ 'module': LDAP_PROVIDER,
+ 'config': ldap_config,
+ })
- providers = config.get("password_providers", [])
+ providers.extend(config.get("password_providers", []))
for provider in providers:
+ mod_name = provider['module']
+
# This is for backwards compat when the ldap auth provider resided
# in this package.
- if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
- from ldap_auth_provider import LdapAuthProvider
- provider_class = LdapAuthProvider
- else:
- # We need to import the module, and then pick the class out of
- # that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
- module = importlib.import_module(module)
- provider_class = getattr(module, clz)
+ if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
+ mod_name = LDAP_PROVIDER
+
+ (provider_class, provider_config) = load_module({
+ "module": mod_name,
+ "config": provider['config'],
+ })
- try:
- provider_config = provider_class.parse_config(provider["config"])
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
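
load_module replaces the inline importlib logic that previously lived in this file. A sketch of what such a helper presumably does, reconstructed from the code it replaces (the exact behaviour of synapse.util.module_loader.load_module is an assumption):

import importlib


def load_module_sketch(provider):
    """Import the class named by provider['module'] and let it parse its
    own config; returns (provider_class, parsed_config)."""
    # "pkg.mod.ClassName" -> import pkg.mod, then pick ClassName off it
    module_name, class_name = provider["module"].rsplit(".", 1)
    module = importlib.import_module(module_name)
    provider_class = getattr(module, class_name)

    try:
        parsed_config = provider_class.parse_config(provider["config"])
    except Exception as e:
        raise Exception(
            "Failed to parse config for %r: %r" % (provider["module"], e)
        )
    return provider_class, parsed_config


# e.g. load_module_sketch({
#     "module": "ldap_auth_provider.LdapAuthProvider",
#     "config": {"enabled": True, "uri": "ldap://ldap.example.com"},
# })
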
diff --git a/synapse/config/push.py b/synapse/config/push.py
new file mode 100644
index 0000000000..b7e0d46afa
--- /dev/null
+++ b/synapse/config/push.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class PushConfig(Config):
+ def read_config(self, config):
+ push_config = config.get("push", {})
+ self.push_include_content = push_config.get("include_content", True)
+
+ # There was a 'redact_content' setting, but it was mistakenly read from
+ # the 'email' section. Check for the flag in the 'push' section, and log,
+ # but do not honour it to avoid nasty surprises when people upgrade.
+ if push_config.get("redact_content") is not None:
+ print(
+ "The push.redact_content option has never worked. "
+ "Please set push.include_content if you want this behaviour"
+ )
+
+ # Now check for the one in the 'email' section and honour it,
+ # with a warning.
+ push_config = config.get("email", {})
+ redact_content = push_config.get("redact_content")
+ if redact_content is not None:
+ print(
+ "The 'email.redact_content' option is deprecated: "
+ "please set push.include_content instead"
+ )
+ self.push_include_content = not redact_content
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+
+ # On modern Android devices the notification content will still appear
+ # because it is loaded by the app. iPhones, however, will only show a
+ # notification saying that a message arrived and who it came from.
+ #
+ #push:
+ # include_content: true
+ """
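The precedence implemented by PushConfig.read_config above can be summarised in a standalone sketch (the helper name is made up for illustration):

def resolve_include_content(config):
    # push.include_content is read first, defaulting to True ...
    include = config.get("push", {}).get("include_content", True)
    # ... but a legacy email.redact_content, if present, is still honoured
    # (inverted), with a deprecation warning.
    redact = config.get("email", {}).get("redact_content")
    if redact is not None:
        include = not redact
    return include

assert resolve_include_content({}) is True
assert resolve_include_content({"email": {"redact_content": True}}) is False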
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 87e500c97a..c5384b3ad4 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
+ self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+ self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -41,6 +43,8 @@ class RegistrationConfig(Config):
self.allow_guest_access and config.get("invite_3pid_guest", False)
)
+ self.auto_join_rooms = config.get("auto_join_rooms", [])
+
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -50,13 +54,32 @@ class RegistrationConfig(Config):
# Enable registration for new users.
enable_registration: False
+ # The user must provide all of the below types of 3PID when registering.
+ #
+ # registrations_require_3pid:
+ # - email
+ # - msisdn
+
+ # Mandate that users are only allowed to associate certain formats of
+ # 3PIDs with accounts on this server.
+ #
+ # allowed_local_3pids:
+ # - medium: email
+ # pattern: ".*@matrix\\.org"
+ # - medium: email
+ # pattern: ".*@vector\\.im"
+ # - medium: msisdn
+ # pattern: "\\+44"
+
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
- # The default number of rounds is 12.
+ # The default number is 12 (which equates to 2^12 rounds).
+ # N.B. that increasing this will exponentially increase the time required
+ # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
@@ -69,6 +92,12 @@ class RegistrationConfig(Config):
trusted_third_party_id_servers:
- matrix.org
- vector.im
+ - riot.im
+
+ # Users who register on this homeserver will automatically be joined
+ # to these rooms
+ #auto_join_rooms:
+ # - "#example:example.com"
""" % locals()
def add_arguments(self, parser):
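The allowed_local_3pids entries above pair a medium with a regex pattern. The enforcement itself lives elsewhere in synapse; purely as an illustration of how such patterns could be matched (the function name is made up):

import re

def is_3pid_allowed(allowed_local_3pids, medium, address):
    # With no constraints configured, any 3PID is acceptable.
    if not allowed_local_3pids:
        return True
    return any(
        constraint["medium"] == medium
        and re.match(constraint["pattern"], address)
        for constraint in allowed_local_3pids
    )

allowed = [{"medium": "email", "pattern": r".*@matrix\.org"}]
assert is_3pid_allowed(allowed, "email", "alice@matrix.org")
assert not is_3pid_allowed(allowed, "email", "alice@example.com")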
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 2c6f57168e..25ea77738a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -16,6 +16,8 @@
from ._base import Config, ConfigError
from collections import namedtuple
+from synapse.util.module_loader import load_module
+
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
+MediaStorageProviderConfig = namedtuple(
+ "MediaStorageProviderConfig", (
+ "store_local", # Whether to store newly uploaded local files
+ "store_remote", # Whether to store newly downloaded remote files
+ "store_synchronous", # Whether to wait for successful storage for local uploads
+ ),
+)
+
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
@@ -70,7 +80,64 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"])
+
self.media_store_path = self.ensure_directory(config["media_store_path"])
+
+ backup_media_store_path = config.get("backup_media_store_path")
+
+ synchronous_backup_media_store = config.get(
+ "synchronous_backup_media_store", False
+ )
+
+ storage_providers = config.get("media_storage_providers", [])
+
+ if backup_media_store_path:
+ if storage_providers:
+ raise ConfigError(
+ "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+ )
+
+ storage_providers = [{
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {
+ "directory": backup_media_store_path,
+ }
+ }]
+
+ # This is a list of configs that can be used to create the storage
+ # providers. The entries are tuples of (Class, class_config,
+ # MediaStorageProviderConfig), where Class is the class of the provider,
+ # class_config is the config to pass to it, and MediaStorageProviderConfig
+ # holds the options for StorageProviderWrapper.
+ #
+ # We don't create the storage providers here as not all workers need
+ # them to be started.
+ self.media_storage_providers = []
+
+ for provider_config in storage_providers:
+ # We special case the module "file_system" so as not to need to
+ # expose FileStorageProviderBackend
+ if provider_config["module"] == "file_system":
+ provider_config["module"] = (
+ "synapse.rest.media.v1.storage_provider"
+ ".FileStorageProviderBackend"
+ )
+
+ provider_class, parsed_config = load_module(provider_config)
+
+ wrapper_config = MediaStorageProviderConfig(
+ provider_config.get("store_local", False),
+ provider_config.get("store_remote", False),
+ provider_config.get("store_synchronous", False),
+ )
+
+ self.media_storage_providers.append(
+ (provider_class, parsed_config, wrapper_config,)
+ )
+
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -115,6 +182,20 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
+ # Media storage providers allow media to be stored in different
+ # locations.
+ # media_storage_providers:
+ # - module: file_system
+ # # Whether to write new local files.
+ # store_local: false
+ # # Whether to write new remote media
+ # store_remote: false
+ # # Whether to block upload requests waiting for write to this
+ # # provider to complete
+ # store_synchronous: false
+ # config:
+ # directory: /mnt/some/other/directory
+
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
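The backwards-compatibility handling for backup_media_store_path in read_config above boils down to the following transformation (a standalone sketch, not part of the patch):

def normalise_storage_providers(config):
    providers = list(config.get("media_storage_providers", []))
    backup_path = config.get("backup_media_store_path")
    if backup_path:
        if providers:
            raise ValueError(
                "Cannot use both 'backup_media_store_path' and 'storage_providers'"
            )
        # The old option becomes a file_system provider which stores both
        # local and remote media, synchronously or not as configured.
        providers = [{
            "module": "file_system",
            "store_local": True,
            "store_remote": True,
            "store_synchronous": config.get(
                "synchronous_backup_media_store", False
            ),
            "config": {"directory": backup_path},
        }]
    return providers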
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 1f9999d57a..8f0b6d1f28 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,12 +30,42 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
+ self.cpu_affinity = config.get("cpu_affinity")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
# "disable" federation
self.send_federation = config.get("send_federation", True)
+ # Whether to update the user directory or not. This should be set to
+ # false only if we are updating the user directory in a worker
+ self.update_user_directory = config.get("update_user_directory", True)
+
+ # Whether to enable the media repository endpoints. This should be set
+ # to false if the media repository is running as a separate endpoint;
+ # doing so ensures that we will not run cache cleanup jobs on the
+ # master, potentially causing inconsistency.
+ self.enable_media_repo = config.get("enable_media_repo", True)
+
+ self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+
+ # Whether we should block invites sent to users on this server
+ # (other than those sent by local server admins)
+ self.block_non_admin_invites = config.get(
+ "block_non_admin_invites", False,
+ )
+
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None
+ federation_domain_whitelist = config.get(
+ "federation_domain_whitelist", None
+ )
+ # turn the whitelist into a hash for speed of lookup
+ if federation_domain_whitelist is not None:
+ self.federation_domain_whitelist = {}
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -141,9 +172,36 @@ class ServerConfig(Config):
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
+ # CPU affinity mask. Setting this restricts the CPUs on which the
+ # process will be scheduled. It is represented as a bitmask, with the
+ # lowest order bit corresponding to the first logical CPU and the
+ # highest order bit corresponding to the last logical CPU. Not all CPUs
+ # may exist on a given system but a mask may specify more CPUs than are
+ # present.
+ #
+ # For example:
+ # 0x00000001 is processor #0,
+ # 0x00000003 is processors #0 and #1,
+ # 0xFFFFFFFF is all processors (#0 through #31).
+ #
+ # Pinning a Python process to a single CPU is desirable, because Python
+ # is inherently single-threaded due to the GIL, and can suffer a
+ # 30-40%% slowdown due to cache blow-out and thread context switching
+ # if the scheduler happens to schedule the underlying threads across
+ # different cores. See
+ # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
+ #
+ # cpu_affinity: 0xFFFFFFFF
+
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
+ # The root directory to serve for the above web client.
+ # If left undefined, synapse will serve the matrix-angular-sdk web client.
+ # Make sure matrix-angular-sdk is installed with pip if web_client is True
+ # and web_client_location is undefined
+ # web_client_location: "/path/to/web/root"
+
# The public-facing base URL for the client API (not including _matrix/...)
# public_baseurl: https://example.com:8448/
@@ -155,6 +213,25 @@ class ServerConfig(Config):
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
+ # Set the limit on the returned events in the timeline in the get
+ # and sync operations. The default value is -1, which means no upper limit.
+ # filter_timeline_limit: 5000
+
+ # Whether room invites to users on this server should be blocked
+ # (except those sent by local server admins). The default is False.
+ # block_non_admin_invites: True
+
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ # federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
@@ -165,13 +242,12 @@ class ServerConfig(Config):
port: %(bind_port)s
# Local addresses to listen on.
- # This will listen on all IPv4 addresses by default.
+ # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
+ # addresses by default. For most other OSes, this will only listen
+ # on IPv6.
bind_addresses:
+ - '::'
- '0.0.0.0'
- # Uncomment to listen on all IPv6 interfaces
- # N.B: On at least Linux this will also listen on all IPv4
- # addresses, so you will need to comment out the line above.
- # - '::'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
@@ -198,11 +274,18 @@ class ServerConfig(Config):
- names: [federation] # Federation APIs
compress: false
+ # optional list of additional endpoints which can be loaded via
+ # dynamic modules
+ # additional_resources:
+ # "/_matrix/my/custom/endpoint":
+ # module: my_module.CustomRequestHandler
+ # config: {}
+
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
- bind_addresses: ['0.0.0.0']
+ bind_addresses: ['::', '0.0.0.0']
type: http
x_forwarded: false
@@ -216,7 +299,7 @@ class ServerConfig(Config):
# Turn on the twisted ssh manhole service on localhost on the given
# port.
# - port: 9000
- # bind_address: 127.0.0.1
+ # bind_addresses: ['::1', '127.0.0.1']
# type: manhole
""" % locals()
@@ -254,7 +337,7 @@ def read_gc_thresholds(thresholds):
return (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
- except:
+ except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
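As a small illustration of the cpu_affinity bitmask documented in the sample config above (the helper is hypothetical):

def cpu_mask(*cpus):
    # Bit N set means the process may be scheduled on logical CPU N.
    mask = 0
    for cpu in cpus:
        mask |= 1 << cpu
    return mask

assert cpu_mask(0) == 0x00000001            # processor #0
assert cpu_mask(0, 1) == 0x00000003         # processors #0 and #1
assert cpu_mask(*range(32)) == 0xFFFFFFFF   # all processors (#0 through #31)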
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
new file mode 100644
index 0000000000..3fec42bdb0
--- /dev/null
+++ b/synapse/config/spam_checker.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class SpamCheckerConfig(Config):
+ def read_config(self, config):
+ self.spam_checker = None
+
+ provider = config.get("spam_checker", None)
+ if provider is not None:
+ self.spam_checker = load_module(provider)
+
+ def default_config(self, **kwargs):
+ return """\
+ # spam_checker:
+ # module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e081840a83..b66154bc7c 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
- # to the list. So if federation traffic is handle directly by synapse
+ # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
@@ -109,6 +109,12 @@ class TlsConfig(Config):
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
+ #
+ # You can calculate a fingerprint from a given TLS listener via:
+ # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
+ # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
+ # or by checking matrix.org/federationtester/api/report?server_name=$host
+ #
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()
@@ -126,8 +132,8 @@ class TlsConfig(Config):
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
- if not os.path.exists(tls_private_key_path):
- with open(tls_private_key_path, "w") as private_key_file:
+ if not self.path_exists(tls_private_key_path):
+ with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -141,8 +147,8 @@ class TlsConfig(Config):
crypto.FILETYPE_PEM, private_key_pem
)
- if not os.path.exists(tls_certificate_path):
- with open(tls_certificate_path, "w") as certificate_file:
+ if not self.path_exists(tls_certificate_path):
+ with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
@@ -159,7 +165,7 @@ class TlsConfig(Config):
certificate_file.write(cert_pem)
- if not os.path.exists(tls_dh_params_path):
+ if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
new file mode 100644
index 0000000000..38e8947843
--- /dev/null
+++ b/synapse/config/user_directory.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class UserDirectoryConfig(Config):
+ """User Directory Configuration
+ Configuration for the behaviour of the /user_directory API
+ """
+
+ def read_config(self, config):
+ self.user_directory_search_all_users = False
+ user_directory_config = config.get("user_directory", None)
+ if user_directory_config:
+ self.user_directory_search_all_users = (
+ user_directory_config.get("search_all_users", False)
+ )
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # User Directory configuration
+ #
+ # 'search_all_users' defines whether to search all users visible to your HS
+ # when searching the user directory, rather than limiting to users visible
+ # in public rooms. Defaults to false. If you set it to true, you'll have to run
+ # UPDATE user_directory_stream_pos SET stream_id = NULL;
+ # on your database to tell it to rebuild the user_directory search indexes.
+ #
+ #user_directory:
+ # search_all_users: false
+ """
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index eeb693027b..3a4e16fa96 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -23,6 +23,7 @@ class VoipConfig(Config):
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+ self.turn_allow_guests = config.get("turn_allow_guests", True)
def default_config(self, **kwargs):
return """\
@@ -41,4 +42,11 @@ class VoipConfig(Config):
# How long generated TURN credentials last
turn_user_lifetime: "1h"
+
+ # Whether guests should be allowed to use the TURN server.
+ # This defaults to True; otherwise VoIP will be unreliable for guests.
+ # However, it does introduce a slight security risk as it allows users to
+ # connect to arbitrary endpoints without having first signed up for a
+ # valid account (e.g. by passing a CAPTCHA).
+ turn_allow_guests: True
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index b165c67ee7..80baf0ce0e 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -23,12 +23,30 @@ class WorkerConfig(Config):
def read_config(self, config):
self.worker_app = config.get("worker_app")
+
+ # Canonicalise worker_app so that master always has None
+ if self.worker_app == "synapse.app.homeserver":
+ self.worker_app = None
+
self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
- self.worker_replication_url = config.get("worker_replication_url")
+
+ # The host used to connect to the main synapse
+ self.worker_replication_host = config.get("worker_replication_host", None)
+
+ # The port on the main synapse for TCP replication
+ self.worker_replication_port = config.get("worker_replication_port", None)
+
+ # The port on the main synapse for HTTP replication endpoint
+ self.worker_replication_http_port = config.get("worker_replication_http_port")
+
+ self.worker_name = config.get("worker_name", self.worker_app)
+
+ self.worker_main_http_uri = config.get("worker_main_http_uri", None)
+ self.worker_cpu_affinity = config.get("worker_cpu_affinity")
if self.worker_listeners:
for listener in self.worker_listeners:
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index aad4752fe7..0397f73ab4 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -13,8 +13,8 @@
# limitations under the License.
from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
+from OpenSSL import SSL, crypto
+from twisted.internet._sslverify import _defaultCurveName
import logging
@@ -32,9 +32,10 @@ class ServerContextFactory(ssl.ContextFactory):
@staticmethod
def configure_context(context, config):
try:
- _ecCurve = _OpenSSLECCurve(_defaultCurveName)
- _ecCurve.addECKeyToContext(context)
- except:
+ _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+ context.set_tmp_ecdh(_ecCurve)
+
+ except Exception:
logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate_chain_file(config.tls_certificate_file)
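The replacement API used above comes straight from pyOpenSSL. Standalone, the equivalent curve setup looks roughly like this (a sketch; "prime256v1" is assumed to be the value of Twisted's _defaultCurveName in this era):

from OpenSSL import SSL, crypto

ctx = SSL.Context(SSL.TLSv1_METHOD)
# Assumed default curve name; substitute whatever _defaultCurveName resolves to.
curve = crypto.get_elliptic_curve("prime256v1")
ctx.set_tmp_ecdh(curve)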
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index ec7711ba7d..aaa3efaca3 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
- if name not in event.hashes:
+
+ # some malformed events lack a 'hashes'. Protect against it being missing
+ # or a weird type by basically treating it the same as an unhashed event.
+ hashes = event.get("hashes")
+ if not isinstance(hashes, dict):
+ raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+
+ if name not in hashes:
raise SynapseError(
400,
"Algorithm %s not in hashes %s" % (
- name, list(event.hashes),
+ name, list(hashes),
),
Codes.UNAUTHORIZED,
)
- message_hash_base64 = event.hashes[name]
+ message_hash_base64 = hashes[name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
- except:
+ except Exception:
raise SynapseError(
400,
"Invalid base64: %s" % (message_hash_base64,),
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c2bd64d6c2..f1fd488b90 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from synapse.util import logcontext
from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import (
- preserve_context_over_fn, preserve_context_over_deferred
-)
import simplejson as json
import logging
@@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
- protocol = yield preserve_context_over_fn(
- endpoint.connect, factory
- )
- server_response, server_certificate = yield preserve_context_over_deferred(
- protocol.remote_key
- )
- defer.returnValue((server_response, server_certificate))
- return
+ with logcontext.PreserveLoggingContext():
+ protocol = yield endpoint.connect(factory)
+ server_response, server_certificate = yield protocol.remote_key
+ defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9b3..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,12 +16,11 @@
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
- preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
- preserve_fn
+ PreserveLoggingContext,
+ preserve_fn,
+ run_in_background,
)
from synapse.util.metrics import Measure
@@ -58,7 +58,8 @@ Attributes:
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
"""
@@ -75,31 +76,41 @@ class Keyring(object):
self.perspective_servers = self.config.perspectives
self.hs = hs
+ # map from server name to Deferred. Has an entry for each server with
+ # an ongoing key download; the Deferred completes once the download
+ # completes.
+ #
+ # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
- return self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ return logcontext.make_deferred_yieldable(
+ self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+ )
def verify_json_objects_for_server(self, server_and_json):
- """Bulk verfies signatures of json objects, bulk fetching keys as
+ """Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
- list of deferreds indicating success or failure to verify each
- json object's signature for the given server_name.
+ List<Deferred>: for each input pair, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
"""
verify_requests = []
for server_name, json_object in server_and_json:
- logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
+ logger.warn("Request from %s: no supported signature keys",
+ server_name)
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
@@ -108,76 +119,69 @@ class Keyring(object):
else:
deferred = defer.Deferred()
+ logger.debug("Verifying for %s with key_ids %s",
+ server_name, key_ids)
+
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
verify_requests.append(verify_request)
- @defer.inlineCallbacks
- def handle_key_deferred(verify_request):
- server_name = verify_request.server_name
- try:
- _, key_id, verify_key = yield verify_request.deferred
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ run_in_background(self._start_key_lookups, verify_requests)
- json_object = verify_request.json_object
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ handle = preserve_fn(_handle_key_deferred)
+ return [
+ handle(rq) for rq in verify_requests
+ ]
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ @defer.inlineCallbacks
+ def _start_key_lookups(self, verify_requests):
+ """Sets off the key fetches for each verify request
- server_to_deferred = {
- server_name: defer.Deferred()
- for server_name, _ in server_and_json
- }
+ Once each fetch completes, verify_request.deferred will be resolved.
- with PreserveLoggingContext():
+ Args:
+ verify_requests (List[VerifyKeyRequest]):
+ """
+
+ try:
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
+ server_to_deferred = {
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
+ }
# We want to wait for any previous lookups to complete before
# proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
server_to_deferred,
)
# Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(verify_requests)
- )
+ self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
server_to_request_ids = {}
- def remove_deferreds(res, server_name, verify_request):
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
@@ -187,17 +191,11 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
- deferred.addBoth(remove_deferreds, server_name, verify_request)
-
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- return [
- preserve_context_over_fn(handle_key_deferred, verify_request)
- for verify_request in verify_requests
- ]
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
+ except Exception:
+ logger.exception("Error starting key lookups")
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -206,7 +204,13 @@ class Keyring(object):
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
- resolved once we've finished looking up keys for that server
+ resolved once we've finished looking up keys for that server.
+ The Deferreds should be regular twisted ones which call their
+ callbacks with no logcontext.
+
+ Returns: a Deferred which resolves once all key lookups for the given
+ servers have completed. Follows the synapse rules of logcontext
+ preservation.
"""
while True:
wait_on = [
@@ -220,19 +224,23 @@ class Keyring(object):
else:
break
+ def rm(r, server_name_):
+ self.key_downloads.pop(server_name_, None)
+ return r
+
for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(preserve_context_over_deferred(deferred))
- self.key_downloads[server_name] = d
+ self.key_downloads[server_name] = deferred
+ deferred.addBoth(rm, server_name)
- def rm(r, server_name):
- self.key_downloads.pop(server_name, None)
- return r
+ def _get_server_verify_keys(self, verify_requests):
+ """Tries to find at least one key for each verify request
- d.addBoth(rm, server_name)
+ For each verify_request, verify_request.deferred is called back with
+ params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+ with a SynapseError if none of the keys are found.
- def get_server_verify_keys(self, verify_requests):
- """Takes a dict of KeyGroups and tries to find at least one key for
- each group.
+ Args:
+ verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
# These are functions that produce keys given a list of key ids
@@ -245,8 +253,11 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
+ # dict[str, dict[str, VerifyKey]]: results so far.
+ # map server_name -> key_id -> VerifyKey
merged_results = {}
+ # dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -290,33 +301,46 @@ class Keyring(object):
if not missing_keys:
break
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ with PreserveLoggingContext():
+ for verify_request in requests_missing_keys:
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ with PreserveLoggingContext():
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
- do_iterations().addErrback(on_err)
+ run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
- res = yield preserve_context_over_deferred(defer.gatherResults(
+ """
+
+ Args:
+ server_name_and_key_ids (list[(str, iterable[str])]):
+ list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+ Returns:
+ Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+ server_name -> key_id -> VerifyKey
+ """
+ res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.get_server_verify_keys)(
- server_name, key_ids
+ run_in_background(
+ self.store.get_server_verify_keys,
+ server_name, key_ids,
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(dict(res))
@@ -333,17 +357,17 @@ class Keyring(object):
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
defer.returnValue({})
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(p_name, p_keys)
+ run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
union_of_keys = {}
for result in results:
@@ -356,40 +380,34 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
- with limiter:
- keys = None
- try:
- keys = yield self.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except Exception as e:
- logger.info(
- "Unable to get key %r for %r directly: %s %s",
- key_ids, server_name,
- type(e).__name__, str(e.message),
- )
+ keys = None
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except Exception as e:
+ logger.info(
+ "Unable to get key %r for %r directly: %s %s",
+ key_ids, server_name,
+ type(e).__name__, str(e),
+ )
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
- keys = {server_name: keys}
+ keys = {server_name: keys}
defer.returnValue(keys)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(server_name, key_ids)
+ run_in_background(get_key, server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
merged = {}
for result in results:
@@ -466,9 +484,10 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -476,7 +495,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -524,9 +543,10 @@ class Keyring(object):
keys.update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -534,7 +554,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -600,9 +620,10 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_keys_json)(
+ run_in_background(
+ self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -613,7 +634,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
results[server_name] = response_keys
@@ -691,7 +712,6 @@ class Keyring(object):
defer.returnValue(verify_keys)
- @defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
@@ -702,12 +722,57 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
- yield preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_verify_key)(
+ run_in_background(
+ self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
+ try:
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = verify_request.json_object
+
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except Exception:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
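The recurring pattern in this file is: start each fetch with run_in_background, then wait for the whole batch with make_deferred_yieldable(gatherResults(...)) so the calling logcontext is neither lost nor leaked. Distilled into a generic helper (a sketch for illustration, not part of the patch):

from twisted.internet import defer

from synapse.util import logcontext
from synapse.util.logcontext import run_in_background

@defer.inlineCallbacks
def gather_in_background(func, args_list):
    # Kick every call off in the background, then wait for all of them
    # from the current logcontext.
    results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
        [run_in_background(func, *args) for args in args_list],
        consumeErrors=True,
    ))
    defer.returnValue(results)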
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4096c606f1..cd5627e36a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
- 403, "You cannot unban user &s." % (target_user_id,)
+ 403, "You cannot unban user %s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
@@ -443,12 +443,12 @@ def _check_power_levels(event, auth_events):
for k, v in user_list.items():
try:
UserID.from_string(k)
- except:
+ except Exception:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
- except:
+ except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
@@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
("invite", None),
]
- old_list = current_state.content.get("users")
+ old_list = current_state.content.get("users", {})
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
- old_list = current_state.content.get("events")
- new_list = event.content.get("events")
+ old_list = current_state.content.get("events", {})
+ new_list = event.content.get("events", {})
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..c3ff85c49a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -47,14 +47,26 @@ class _EventInternalMetadata(object):
def _event_dict_property(key):
+ # We want to be able to use hasattr with the event dict properties.
+ # However, on Python 3, hasattr expects an AttributeError to be raised.
+ # Hence, we need to transform the KeyError into an AttributeError.
def getter(self):
- return self._event_dict[key]
+ try:
+ return self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
def setter(self, v):
- self._event_dict[key] = v
+ try:
+ self._event_dict[key] = v
+ except KeyError:
+ raise AttributeError(key)
def delete(self):
- del self._event_dict[key]
+ try:
+ del self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
return property(
getter,
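A tiny illustration of why the KeyError-to-AttributeError translation above matters (the demo class and dicts are made up; it assumes _event_dict_property as defined above is in scope):

class _Demo(object):
    body = _event_dict_property("body")

    def __init__(self, event_dict):
        self._event_dict = event_dict

assert hasattr(_Demo({"body": "hi"}), "body")
# Without the translation, the bare KeyError would escape hasattr() on
# Python 3; with it, a missing key is reported as a missing attribute.
assert not hasattr(_Demo({}), "body")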
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 365fd96bd2..13fbba68c0 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -55,7 +55,7 @@ class EventBuilderFactory(object):
local_part = str(int(self.clock.time())) + i + random_string(5)
- e_id = EventID.create(local_part, self.hostname)
+ e_id = EventID(local_part, self.hostname)
return e_id.to_string()
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 11605b34a3..8e684d91b5 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -13,17 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from frozendict import frozendict
+
class EventContext(object):
+ """
+ Attributes:
+ current_state_ids (dict[(str, str), str]):
+ The current state map including the current event.
+ (type, state_key) -> event_id
+
+ prev_state_ids (dict[(str, str), str]):
+ The current state map excluding the current event.
+ (type, state_key) -> event_id
+
+ state_group (int|None): state group id, if the state has been stored
+ as a state group. This is usually only None if e.g. the event is
+ an outlier.
+ rejected (bool|str): A rejection reason if the event was rejected, else
+ False
+
+ push_actions (list[(str, list[object])]): list of (user_id, actions)
+ tuples
+
+ prev_group (int): Previously persisted state group. ``None`` for an
+ outlier.
+ delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
+ (type, state_key) -> event_id. ``None`` for an outlier.
+
+ prev_state_events (?): XXX: is this ever set to anything other than
+ the empty list?
+ """
+
__slots__ = [
"current_state_ids",
"prev_state_ids",
"state_group",
"rejected",
- "push_actions",
"prev_group",
"delta_ids",
"prev_state_events",
+ "app_service",
]
def __init__(self):
@@ -34,7 +66,6 @@ class EventContext(object):
self.state_group = None
self.rejected = False
- self.push_actions = []
# A previously persisted state group and a delta between that
# and this state.
@@ -42,3 +73,100 @@ class EventContext(object):
self.delta_ids = None
self.prev_state_events = None
+
+ self.app_service = None
+
+ def serialize(self, event):
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Args:
+ event (FrozenEvent): The event that this context relates to
+
+ Returns:
+ dict
+ """
+
+ # We don't serialize the full state dicts, instead they get pulled out
+ # of the DB on the other side. However, the other side can't figure out
+ # the prev_state_ids, so if we're a state event we include the event
+ # id that we replaced in the state.
+ if event.is_state():
+ prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
+ else:
+ prev_state_id = None
+
+ return {
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None
+ }
+
+ @staticmethod
+ @defer.inlineCallbacks
+ def deserialize(store, input):
+ """Converts a dict that was produced by `serialize` back into a
+ EventContext.
+
+ Args:
+ store (DataStore): Used to convert AS ID to AS object
+ input (dict): A dict produced by `serialize`
+
+ Returns:
+ EventContext
+ """
+ context = EventContext()
+ context.state_group = input["state_group"]
+ context.rejected = input["rejected"]
+ context.prev_group = input["prev_group"]
+ context.delta_ids = _decode_state_dict(input["delta_ids"])
+ context.prev_state_events = input["prev_state_events"]
+
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ prev_state_id = input["prev_state_id"]
+ event_type = input["event_type"]
+ event_state_key = input["event_state_key"]
+
+ context.current_state_ids = yield store.get_state_ids_for_group(
+ context.state_group,
+ )
+ if prev_state_id and event_state_key:
+ context.prev_state_ids = dict(context.current_state_ids)
+ context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
+ else:
+ context.prev_state_ids = context.current_state_ids
+
+ app_service_id = input["app_service_id"]
+ if app_service_id:
+ context.app_service = store.get_app_service_by_id(app_service_id)
+
+ defer.returnValue(context)
+
+
+def _encode_state_dict(state_dict):
+ """Since dicts of (type, state_key) -> event_id cannot be serialized in
+ JSON we need to convert them to a form that can.
+ """
+ if state_dict is None:
+ return None
+
+ return [
+ (etype, state_key, v)
+ for (etype, state_key), v in state_dict.iteritems()
+ ]
+
+
+def _decode_state_dict(input):
+ """Decodes a state dict encoded using `_encode_state_dict` above
+ """
+ if input is None:
+ return None
+
+ return frozendict({(etype, state_key,): v for etype, state_key, v in input})
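For example, the helpers above turn the tuple-keyed state map into a JSON-friendly list of triples and back (a sketch, assuming the two functions are in scope):

state = {("m.room.member", "@alice:example.com"): "$event_a"}

encoded = _encode_state_dict(state)
assert encoded == [("m.room.member", "@alice:example.com", "$event_a")]

decoded = _decode_state_dict(encoded)   # a frozendict keyed by (type, state_key)
assert decoded[("m.room.member", "@alice:example.com")] == "$event_a"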
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
new file mode 100644
index 0000000000..633e068eb8
--- /dev/null
+++ b/synapse/events/spamcheck.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class SpamChecker(object):
+ def __init__(self, hs):
+ self.spam_checker = None
+
+ module = None
+ config = None
+ try:
+ module, config = hs.config.spam_checker
+ except Exception:
+ pass
+
+ if module is not None:
+ self.spam_checker = module(config=config)
+
+ def check_event_for_spam(self, event):
+ """Checks if a given event is considered "spammy" by this server.
+
+ If the server considers an event spammy, then it will be rejected if
+ sent by a local user. If it is sent by a user on another server, then
+ users receive a blank event.
+
+ Args:
+ event (synapse.events.EventBase): the event to be checked
+
+ Returns:
+ bool: True if the event is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ return self.spam_checker.check_event_for_spam(event)
+
+ def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ """Checks if a given user may send an invite
+
+ If this method returns false, the invite will be rejected.
+
+ Args:
+ inviter_userid (string): The user ID of the sender
+ invitee_userid (string): The user ID of the invitee
+ room_id (string): The ID of the room the invite is for
+
+ Returns:
+ bool: True if the user may send an invite, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+
+ def user_may_create_room(self, userid):
+ """Checks if a given user may create a room
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may create a room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room(userid)
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Checks if a given user may create a room alias
+
+ If this method returns false, the association request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_alias (string): The alias to be created
+
+ Returns:
+ bool: True if the user may create a room alias, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+
+ def user_may_publish_room(self, userid, room_id):
+ """Checks if a given user may publish a room to the directory
+
+ If this method returns false, the publish request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_id (string): The ID of the room that would be published
+
+ Returns:
+ bool: True if the user may publish the room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_publish_room(userid, room_id)
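A hypothetical provider module satisfying this interface (and loadable via the spam_checker config stanza earlier in this patch) might look like the following; the class name and config keys are invented for illustration:

class ExampleSpamChecker(object):
    def __init__(self, config):
        self._blocked_words = config.get("blocked_words", [])

    @classmethod
    def parse_config(cls, config):
        # Called at config-parse time via load_module().
        return config or {}

    def check_event_for_spam(self, event):
        body = event.content.get("body", "")
        return any(word in body for word in self._blocked_words)

    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
        return True

    def user_may_create_room(self, userid):
        return True

    def user_may_create_room_alias(self, userid, room_alias):
        return True

    def user_may_publish_room(self, userid, room_id):
        return True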
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 5bbaef8187..824f4a42e3 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -225,7 +225,22 @@ def format_event_for_client_v2_without_room_id(d):
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
- token_id=None, only_event_fields=None):
+ token_id=None, only_event_fields=None, is_invite=False):
+ """Serialize event for clients
+
+ Args:
+ e (EventBase)
+ time_now_ms (int)
+ as_client_event (bool)
+ event_format
+ token_id
+ only_event_fields
+ is_invite (bool): Whether this is an invite that is being sent to the
+ invitee
+
+ Returns:
+ dict
+ """
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@@ -251,6 +266,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
+ # If this is an invite for somebody else, then we don't care about the
+ # invite_room_state as that's meant solely for the invitee. Other clients
+ # will already have the state since they're in the room.
+ if not is_invite:
+ d["unsigned"].pop("invite_room_state", None)
+
if as_client_event:
d = event_format(d)
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
""" This package includes all the federation specific logic.
"""
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
- transport = hs.get_federation_transport_client()
-
- return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..4cc98a3fe8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import six
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import SynapseError, Codes
from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
-
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.hs = hs
+
+ self.server_name = hs.hostname
+ self.keyring = hs.get_keyring()
+ self.spam_checker = hs.get_spam_checker()
+ self.store = hs.get_datastore()
+ self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +60,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
-
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
+
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -114,15 +113,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
- """Throws a SynapseError if a PDU does not have the correct
- signatures.
+ """Checks that each of the received events is correctly signed by the
+ sending server.
+
+ Args:
+ pdus (list[FrozenEvent]): the events to be checked
Returns:
- FrozenEvent: Either the given event or it redacted if it failed the
- content hash check.
+ list[Deferred]: for each input event, a deferred which:
+ * returns the original event if the checks pass
+ * returns a redacted version of the event (if the signature
+ matched but the hash did not)
+ * throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,26 +138,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
+ ctx = logcontext.LoggingContext.current_context()
+
def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
- return pdu
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if self.spam_checker.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
@@ -160,3 +180,40 @@ class FederationBase(object):
)
return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+ """Construct a FrozenEvent from an event json received over federation
+
+ Args:
+ pdu_json (object): pdu as received over federation
+ outlier (bool): True to mark this event as an outlier
+
+ Returns:
+ FrozenEvent
+
+ Raises:
+ SynapseError: if the pdu is missing required fields or is otherwise
+ not a valid matrix event
+ """
+ # we could probably enforce a bunch of other fields here (room_id, sender,
+ # origin, etc etc)
+ assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+
+ depth = pdu_json['depth']
+ if not isinstance(depth, six.integer_types):
+        raise SynapseError(400, "Depth %r not an integer" % (depth, ),
+ Codes.BAD_JSON)
+
+ if depth < 0:
+ raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+ elif depth > MAX_DEPTH:
+ raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
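
For reference, the depth validation in event_from_pdu_json boils down to a bounds check; a rough standalone sketch, using a stand-in MAX_DEPTH and plain ValueError rather than Synapse's constants and error types:

    MAX_DEPTH = 2 ** 63 - 1  # stand-in value, not Synapse's actual constant

    def check_depth(pdu_json):
        # Depth must be present, an integer (bools rejected) and within range.
        depth = pdu_json.get("depth")
        if not isinstance(depth, int) or isinstance(depth, bool):
            raise ValueError("Depth %r not an integer" % (depth,))
        if depth < 0:
            raise ValueError("Depth too small")
        if depth > MAX_DEPTH:
            raise ValueError("Depth too large")

    check_depth({"depth": 12})  # passes; check_depth({"depth": -1}) would raise
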
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b5bcfd705a..6163f7c466 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,28 +14,30 @@
# limitations under the License.
+import copy
+import itertools
+import logging
+import random
+
+from six.moves import range
+
from twisted.internet import defer
-from .federation_base import FederationBase
from synapse.api.constants import Membership
-
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError,
+ CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
-from synapse.util import unwrapFirstError
+from synapse.events import builder
+from synapse.federation.federation_base import (
+ FederationBase,
+ event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
-
-import copy
-import itertools
-import logging
-import random
-
+from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -58,6 +60,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000,
)
self.state = hs.get_state_handler()
+ self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
@@ -88,7 +91,7 @@ class FederationClient(FederationBase):
@log_function
def make_query(self, destination, query_type, args,
- retry_on_dns_fail=False):
+ retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@@ -98,6 +101,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
a Deferred which will eventually yield a JSON object from the
@@ -106,7 +111,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type)
return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+ destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+ ignore_backoff=ignore_backoff,
)
@log_function
@@ -181,15 +187,15 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
- self.event_from_pdu_json(p, outlier=False)
+ event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+ pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(pdus)
@@ -206,8 +212,7 @@ class FederationClient(FederationBase):
Args:
destinations (list): Which home servers to query
- pdu_origin (str): The home server that originally sent the pdu.
- event_id (str)
+ event_id (str): event to fetch
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -235,31 +240,24 @@ class FederationClient(FederationBase):
continue
try:
- limiter = yield get_retry_limiter(
- destination,
- self._clock,
- self.store,
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id, timeout=timeout,
)
- with limiter:
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id, timeout=timeout,
- )
-
- logger.debug("transaction_data %r", transaction_data)
+ logger.debug("transaction_data %r", transaction_data)
- pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
- for p in transaction_data["pdus"]
- ]
+ pdu_list = [
+ event_from_pdu_json(p, outlier=outlier)
+ for p in transaction_data["pdus"]
+ ]
- if pdu_list and pdu_list[0]:
- pdu = pdu_list[0]
+ if pdu_list and pdu_list[0]:
+ pdu = pdu_list[0]
- # Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ # Check signatures are correct.
+ signed_pdu = yield self._check_sigs_and_hash(pdu)
- break
+ break
pdu_attempts[destination] = now
@@ -271,6 +269,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e.message)
+ continue
except Exception as e:
pdu_attempts[destination] = now
@@ -341,11 +342,11 @@ class FederationClient(FederationBase):
)
pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -395,7 +396,7 @@ class FederationClient(FederationBase):
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
- seen_events = yield self.store.have_events(event_ids)
+ seen_events = yield self.store.have_seen_events(event_ids)
signed_events = []
failed_to_fetch = set()
@@ -414,18 +415,19 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
- for i in xrange(0, len(missing_events), batch_size):
+ for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- preserve_fn(self.get_pdu)(
+ run_in_background(
+ self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
- res = yield preserve_context_over_deferred(
+ res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
@@ -446,7 +448,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
@@ -479,8 +481,13 @@ class FederationClient(FederationBase):
content (object): Any additional data to put into the content field
of the event.
Return:
- A tuple of (origin (str), event (object)) where origin is the remote
- homeserver which generated the event.
+ Deferred: resolves to a tuple of (origin (str), event (object))
+ where origin is the remote homeserver which generated the event.
+
+ Fails with a ``CodeMessageException`` if the chosen remote server
+ returns a 300/400 code.
+
+ Fails with a ``RuntimeError`` if no servers were reachable.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
@@ -533,6 +540,27 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
+ """Sends a join event to one of a list of homeservers.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ Args:
+            destinations (list[str]): Candidate homeservers which are probably
+ participating in the room.
+ pdu (BaseEvent): event to be sent
+
+ Return:
+ Deferred: resolves to a dict with members ``origin`` (a string
+            giving the server the event was sent to), ``state`` (?) and
+ ``auth_chain``.
+
+ Fails with a ``CodeMessageException`` if the chosen remote server
+ returns a 300/400 code.
+
+ Fails with a ``RuntimeError`` if no servers were reachable.
+ """
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -549,12 +577,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -629,7 +657,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict)
- pdu = self.event_from_pdu_json(pdu_dict)
+ pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
@@ -640,6 +668,26 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
+ """Sends a leave event to one of a list of homeservers.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ This is mostly useful to reject received invites.
+
+ Args:
+            destinations (list[str]): Candidate homeservers which are probably
+ participating in the room.
+ pdu (BaseEvent): event to be sent
+
+ Return:
+ Deferred: resolves to None.
+
+ Fails with a ``CodeMessageException`` if the chosen remote server
+ returns a non-200 code.
+
+ Fails with a ``RuntimeError`` if no servers were reachable.
+ """
for destination in destinations:
if destination == self.server_name:
continue
@@ -699,7 +747,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -747,7 +795,7 @@ class FederationClient(FederationBase):
)
events = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content.get("events", [])
]
@@ -764,15 +812,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
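
The send_join/send_leave docstrings above describe a shared pattern: try each candidate homeserver in turn, skip our own server, and raise a RuntimeError only once every destination has failed. A simplified synchronous sketch of that control flow (the names and the send callable are illustrative, not Synapse's API):

    def try_destinations(destinations, send, our_server_name):
        # Attempt send(destination) for each candidate until one succeeds.
        last_error = None
        for destination in destinations:
            if destination == our_server_name:
                continue
            try:
                return send(destination)
            except Exception as e:
                last_error = e
        raise RuntimeError("Failed to send to any server: %r" % (last_error,))
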
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e922b7ff4a..247ddc89d5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,27 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
-
+import simplejson as json
from twisted.internet import defer
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+ FederationBase,
+ event_from_pdu_json,
+)
-from synapse.util.async import Linearizer
-from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
import synapse.metrics
+from synapse.types import get_domain_from_id
+from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logutils import log_function
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
+from six import iteritems
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
@@ -51,49 +56,18 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
+ self.handler = hs.get_handlers().federation_handler
- self._room_pdu_linearizer = Linearizer("fed_room_pdu")
- self._server_linearizer = Linearizer("fed_server")
-
- # We cache responses to state queries, as they take a while and often
- # come in waves.
- self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
-
- def set_handler(self, handler):
- """Sets the handler that the replication layer will use to communicate
- receipt of new PDUs from other home servers. The required methods are
- documented on :py:class:`.ReplicationHandler`.
- """
- self.handler = handler
-
- def register_edu_handler(self, edu_type, handler):
- if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
- self.edu_handlers[edu_type] = handler
-
- def register_query_handler(self, query_type, handler):
- """Sets the handler callable that will be used to handle an incoming
- federation Query of the given type.
-
- Args:
- query_type (str): Category name of the query, which should match
- the string used by make_query.
- handler (callable): Invoked to handle incoming queries of this type
+ self._server_linearizer = async.Linearizer("fed_server")
+ self._transaction_linearizer = async.Linearizer("fed_txn_handler")
- handler is invoked as:
- result = handler(args)
+ self.transaction_actions = TransactionActions(self.store)
- where 'args' is a dict mapping strings to strings of the query
- arguments. It should return a Deferred that will eventually yield an
- object to encode as JSON.
- """
- if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
+ self.registry = hs.get_federation_registry()
- self.query_handlers[query_type] = handler
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
+ self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
@defer.inlineCallbacks
@log_function
@@ -110,25 +84,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
+ # keep this as early as possible to make the calculated origin ts as
+ # accurate as possible.
+ request_time = self._clock.time_msec()
+
transaction = Transaction(**transaction_data)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ if not transaction.transaction_id:
+ raise Exception("Transaction missing transaction_id")
+ if not transaction.origin:
+ raise Exception("Transaction missing origin")
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
+ # use a linearizer to ensure that we don't process the same transaction
+ # multiple times in parallel.
+ with (yield self._transaction_linearizer.queue(
+ (transaction.origin, transaction.transaction_id),
+ )):
+ result = yield self._handle_incoming_transaction(
+ transaction, request_time,
+ )
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _handle_incoming_transaction(self, transaction, request_time):
+ """ Process an incoming transaction and return the HTTP response
+ Args:
+ transaction (Transaction): incoming transaction
+ request_time (int): timestamp that the HTTP request arrived at
+
+ Returns:
+ Deferred[(int, object)]: http response code and body
+ """
response = yield self.transaction_actions.have_responded(transaction)
if response:
@@ -141,38 +131,49 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- results = []
-
- for pdu in pdu_list:
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if transaction.origin != get_domain_from_id(pdu.event_id):
- if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) == 'join' and
- self.hs.is_mine_id(pdu.state_key)
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, transaction.origin
- )
- continue
- else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, transaction.origin
- )
+ received_pdus_counter.inc_by(len(transaction.pdus))
+
+ pdus_by_room = {}
+
+ for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = request_time - int(p["age"])
+ del p["age"]
- try:
- yield self._handle_new_pdu(transaction.origin, pdu)
- results.append({})
- except FederationError as e:
- self.send_failure(e, transaction.origin)
- results.append({"error": str(e)})
- except Exception as e:
- results.append({"error": str(e)})
- logger.exception("Failed to handle PDU")
+ event = event_from_pdu_json(p)
+ room_id = event.room_id
+ pdus_by_room.setdefault(room_id, []).append(event)
+
+ pdu_results = {}
+
+ # we can process different rooms in parallel (which is useful if they
+ # require callouts to other servers to fetch missing events), but
+ # impose a limit to avoid going too crazy with ram/cpu.
+ @defer.inlineCallbacks
+ def process_pdus_for_room(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ try:
+ yield self._handle_received_pdu(
+ transaction.origin, pdu
+ )
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warn("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ pdu_results[event_id] = {"error": str(e)}
+ logger.exception("Failed to handle PDU %s", event_id)
+
+ yield async.concurrently_execute(
+ process_pdus_for_room, pdus_by_room.keys(),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
@@ -182,17 +183,16 @@ class FederationServer(FederationBase):
edu.content
)
- for failure in getattr(transaction, "pdu_failures", []):
- logger.info("Got failure %r", failure)
-
- logger.debug("Returning: %s", str(results))
+ pdu_failures = getattr(transaction, "pdu_failures", [])
+ for failure in pdu_failures:
+ logger.info("Got failure %r", failure)
response = {
- "pdus": dict(zip(
- (p.event_id for p in pdu_list), results
- )),
+ "pdus": pdu_results,
}
+ logger.debug("Returning: %s", str(response))
+
yield self.transaction_actions.set_response(
transaction,
200, response
@@ -202,16 +202,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
-
- if edu_type in self.edu_handlers:
- try:
- yield self.edu_handlers[edu_type](origin, content)
- except SynapseError as e:
- logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception as e:
- logger.exception("Failed to handle edu %r", edu_type)
- else:
- logger.warn("Received EDU of type %s with no handler", edu_type)
+ yield self.registry.on_edu(edu_type, origin, content)
@defer.inlineCallbacks
@log_function
@@ -223,15 +214,17 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
- result = self._state_resp_cache.get((room_id, event_id))
- if not result:
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.set(
- (room_id, event_id),
- self._on_context_state_request_compute(room_id, event_id)
- )
- else:
- resp = yield result
+ # we grab the linearizer to protect ourselves from servers which hammer
+ # us. In theory we might already have the response to this query
+ # in the cache so we could return it without waiting for the linearizer
+ # - but that's non-trivial to get right, and anyway somewhat defeats
+ # the point of the linearizer.
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id, event_id,
+ )
defer.returnValue((200, resp))
@@ -300,14 +293,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type)
-
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue(
- (404, "No handler for Query type '%s'" % (query_type,))
- )
+ resp = yield self.registry.on_query(query_type, args)
+ defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
@@ -317,7 +304,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -325,7 +312,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -345,7 +332,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -382,7 +369,7 @@ class FederationServer(FederationBase):
"""
with (yield self._server_linearizer.queue((origin, room_id))):
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -437,6 +424,16 @@ class FederationServer(FederationBase):
key_id: json.loads(json_bytes)
}
+ logger.info(
+ "Claimed one-time-keys: %s",
+ ",".join((
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )),
+ )
+
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
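
The one-time-key logging added above flattens a user -> device -> key mapping into a single comma-separated summary; the same generator-expression trick in isolation (sample data is illustrative):

    json_result = {
        "@alice:example.org": {"DEVICEID": {"signed_curve25519:AAAAHQ": "key"}},
    }
    summary = ",".join(
        "%s for %s:%s" % (key_id, user_id, device_id)
        for user_id, user_keys in json_result.items()
        for device_id, device_keys in user_keys.items()
        for key_id in device_keys
    )
    # summary == "signed_curve25519:AAAAHQ for @alice:example.org:DEVICEID"
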
@@ -497,26 +494,59 @@ class FederationServer(FederationBase):
)
@defer.inlineCallbacks
- @log_function
- def _handle_new_pdu(self, origin, pdu, get_missing=True):
+ def _handle_received_pdu(self, origin, pdu):
+ """ Process a PDU received in a federation /send/ transaction.
+
+ If the event is invalid, then this method throws a FederationError.
+ (The error will then be logged and sent back to the sender (which
+ probably won't do anything with it), and other events in the
+ transaction will be processed as normal).
+
+ It is likely that we'll then receive other events which refer to
+ this rejected_event in their prev_events, etc. When that happens,
+ we'll attempt to fetch the rejected event again, which will presumably
+ fail, so those second-generation events will also get rejected.
+
+ Eventually, we get to the point where there are more than 10 events
+ between any new events and the original rejected event. Since we
+        only try to backfill 10 events deep on a received pdu, we then accept the
+        new event, possibly introducing a discontinuity in the DAG, with new
+        forward extremities, so normal service is approximately restored,
+ until we try to backfill across the discontinuity.
- # We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
- )
+ Args:
+ origin (str): server which sent the pdu
+ pdu (FrozenEvent): received pdu
- # FIXME: Currently we fetch an event again when we already have it
- # if it has been marked as an outlier.
+ Returns (Deferred): completes with None
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
- )
- if already_seen:
- logger.debug("Already seen pdu %s", pdu.event_id)
- return
+ Raises: FederationError if the signatures / hash do not match, or
+ if the event was unacceptable for any other reason (eg, too large,
+ too many prev_events, couldn't find the prev_events)
+ """
+ # check that it's actually being sent from a valid destination to
+ # workaround bug #1753 in 0.18.5 and 0.18.6
+ if origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
+ if not (
+ pdu.type == 'm.room.member' and
+ pdu.content and
+ pdu.content.get("membership", None) == 'join'
+ ):
+ logger.info(
+ "Discarding PDU %s from invalid origin %s",
+ pdu.event_id, origin
+ )
+ return
+ else:
+ logger.info(
+ "Accepting join PDU %s from %s",
+ pdu.event_id, origin
+ )
# Check signature.
try:
@@ -529,156 +559,11 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- state = None
-
- auth_chain = []
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
-
- fetch_state = False
-
- # Get missing pdus if necessary.
- if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = yield self.handler.get_min_depth_for_context(
- pdu.room_id
- )
-
- logger.debug(
- "_handle_new_pdu min_depth for %s: %d",
- pdu.room_id, min_depth
- )
-
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
-
- if min_depth and pdu.depth < min_depth:
- # This is so that we don't notify the user about this
- # message, to work around the fact that some events will
- # reference really really old events we really don't want to
- # send to the clients.
- pdu.internal_metadata.outlier = True
- elif min_depth and pdu.depth > min_depth:
- if get_missing and prevs - seen:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
- logger.info(
- "Acquiring lock for room %r to fetch %d missing events: %r...",
- pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
- )
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired lock for room %r to fetch %d missing events",
- pdu.room_id, len(prevs - seen),
- )
-
- # We recalculate seen, since it may have changed.
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.keys())
-
- if prevs - seen:
- latest = yield self.store.get_latest_event_ids_in_room(
- pdu.room_id
- )
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest)
- latest |= seen
-
- logger.info(
- "Missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
- )
-
- # XXX: we set timeout to 10s to help workaround
- # https://github.com/matrix-org/synapse/issues/1733.
- # The reason is to avoid holding the linearizer lock
- # whilst processing inbound /send transactions, causing
- # FDs to stack up and block other inbound transactions
- # which empirically can currently take up to 30 minutes.
- #
- # N.B. this explicitly disables retry attempts.
- #
- # N.B. this also increases our chances of falling back to
- # fetching fresh state for the room if the missing event
- # can't be found, which slightly reduces our security.
- # it may also increase our DAG extremity count for the room,
- # causing additional state resolution? See #1760.
- # However, fetching state doesn't hold the linearizer lock
- # apparently.
- #
- # see https://github.com/matrix-org/synapse/pull/1744
-
- missing_events = yield self.get_missing_events(
- origin,
- pdu.room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=10000,
- )
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for e in missing_events:
- yield self._handle_new_pdu(
- origin,
- e,
- get_missing=False
- )
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
-
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
- if prevs - seen:
- logger.info(
- "Still missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
- )
- fetch_state = True
-
- if fetch_state:
- # We need to get the state at this event, since we haven't
- # processed all the prev events.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- try:
- state, auth_chain = yield self.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
- except:
- logger.exception("Failed to get state for event: %s", pdu.event_id)
-
- yield self.handler.on_receive_pdu(
- origin,
- pdu,
- state=state,
- auth_chain=auth_chain,
- )
+ yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def exchange_third_party_invite(
self,
@@ -701,3 +586,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict
)
defer.returnValue(ret)
+
+
+class FederationHandlerRegistry(object):
+ """Allows classes to register themselves as handlers for a given EDU or
+ query type for incoming federation traffic.
+ """
+ def __init__(self):
+ self.edu_handlers = {}
+ self.query_handlers = {}
+
+ def register_edu_handler(self, edu_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation EDU of the given type.
+
+ Args:
+ edu_type (str): The type of the incoming EDU to register handler for
+ handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+ of the given type. The arguments are the origin server name and
+ the EDU contents.
+ """
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+ self.edu_handlers[edu_type] = handler
+
+ def register_query_handler(self, query_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation query of the given type.
+
+ Args:
+ query_type (str): Category name of the query, which should match
+ the string used by make_query.
+ handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+ incoming queries of this type. The return will be yielded
+ on and the result used as the response to the query request.
+ """
+ if query_type in self.query_handlers:
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
+
+ self.query_handlers[query_type] = handler
+
+ @defer.inlineCallbacks
+ def on_edu(self, edu_type, origin, content):
+ handler = self.edu_handlers.get(edu_type)
+ if not handler:
+ logger.warn("No handler registered for EDU type %s", edu_type)
+
+ try:
+ yield handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception as e:
+ logger.exception("Failed to handle edu %r", edu_type)
+
+ def on_query(self, query_type, args):
+ handler = self.query_handlers.get(query_type)
+ if not handler:
+ logger.warn("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+ return handler(args)
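
Conceptually, the new FederationHandlerRegistry is a dictionary mapping a type string to a callable, with duplicate registration rejected and unknown types failing at dispatch time. A minimal synchronous sketch of that shape (no Twisted deferreds, illustrative names):

    class HandlerRegistry(object):
        def __init__(self):
            self.handlers = {}

        def register(self, kind, handler):
            # Mirror the registry above: never silently replace a handler.
            if kind in self.handlers:
                raise KeyError("Already have a handler for %s" % (kind,))
            self.handlers[kind] = handler

        def dispatch(self, kind, *args):
            handler = self.handlers.get(kind)
            if handler is None:
                raise LookupError("No handler for type %r" % (kind,))
            return handler(*args)

    registry = HandlerRegistry()
    registry.register("m.typing", lambda origin, content: content)
    registry.dispatch("m.typing", "remote.example", {"typing": True})
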
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
- """This layer is responsible for replicating with remote home servers over
- the given transport. I.e., does the sending and receiving of PDUs to
- remote home servers.
-
- The layer communicates with the rest of the server via a registered
- ReplicationHandler.
-
- In more detail, the layer:
- * Receives incoming data and processes it into transactions and pdus.
- * Fetches any PDUs it thinks it might have missed.
- * Keeps the current state for contexts up to date by applying the
- suitable conflict resolution.
- * Sends outgoing pdus wrapped in transactions.
- * Fills out the references to previous pdus/transactions appropriately
- for outgoing data.
- """
-
- def __init__(self, hs, transport_layer):
- self.server_name = hs.hostname
-
- self.keyring = hs.get_keyring()
-
- self.transport_layer = transport_layer
-
- self.federation_client = self
-
- self.store = hs.get_datastore()
-
- self.handler = None
- self.edu_handlers = {}
- self.query_handlers = {}
-
- self._clock = hs.get_clock()
-
- self.transaction_actions = TransactionActions(self.store)
-
- self.hs = hs
-
- super(ReplicationLayer, self).__init__(hs)
-
- def __str__(self):
- return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86f0..0f0c687b37 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,21 +31,21 @@ Events are replicated via a separate events stream.
from .units import Edu
+from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
import synapse.metrics
from blist import sorteddict
-import ujson
+from collections import namedtuple
+import logging
-metrics = synapse.metrics.get_metrics_for(__name__)
+from six import itervalues, iteritems
+
+logger = logging.getLogger(__name__)
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
+metrics = synapse.metrics.get_metrics_for(__name__)
class FederationRemoteSendQueue(object):
@@ -54,18 +54,20 @@ class FederationRemoteSendQueue(object):
def __init__(self, hs):
self.server_name = hs.hostname
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
- self.presence_map = {}
- self.presence_changed = sorteddict()
+ self.presence_map = {} # Pending presence map user_id -> UserPresenceState
+ self.presence_changed = sorteddict() # Stream position -> user_id
- self.keyed_edu = {}
- self.keyed_edu_changed = sorteddict()
+ self.keyed_edu = {} # (destination, key) -> EDU
+ self.keyed_edu_changed = sorteddict() # stream position -> (destination, key)
- self.edus = sorteddict()
+ self.edus = sorteddict() # stream position -> Edu
- self.failures = sorteddict()
+ self.failures = sorteddict() # stream position -> (destination, Failure)
- self.device_messages = sorteddict()
+ self.device_messages = sorteddict() # stream position -> destination
self.pos = 1
self.pos_time = sorteddict()
@@ -121,7 +123,9 @@ class FederationRemoteSendQueue(object):
del self.presence_changed[key]
user_ids = set(
- user_id for uids in self.presence_changed.values() for _, user_id in uids
+ user_id
+ for uids in itervalues(self.presence_changed)
+ for user_id in uids
)
to_del = [
@@ -186,37 +190,50 @@ class FederationRemoteSendQueue(object):
else:
self.edus[pos] = edu
- def send_presence(self, destination, states):
- """As per TransactionQueue"""
+ self.notifier.on_new_replication_data()
+
+ def send_presence(self, states):
+ """As per TransactionQueue
+
+ Args:
+ states (list(UserPresenceState))
+ """
pos = self._next_pos()
- self.presence_map.update({
- state.user_id: state
- for state in states
- })
+        # We only want to send presence for our own users, so let's always just
+ # filter here just in case.
+ local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- self.presence_changed[pos] = [
- (destination, state.user_id) for state in states
- ]
+ self.presence_map.update({state.user_id: state for state in local_states})
+ self.presence_changed[pos] = [state.user_id for state in local_states]
+
+ self.notifier.on_new_replication_data()
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
+ self.notifier.on_new_replication_data()
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.device_messages[pos] = destination
+ self.notifier.on_new_replication_data()
def get_current_token(self):
return self.pos - 1
- def get_replication_rows(self, token, limit, federation_ack=None):
- """
+ def federation_ack(self, token):
+ self._clear_queue_before_pos(token)
+
+ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ """Get rows to be sent over federation between the two tokens
+
Args:
- token (int)
+ from_token (int)
+            to_token (int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
@@ -225,9 +242,11 @@ class FederationRemoteSendQueue(object):
# TODO: Handle limit.
# To handle restarts where we wrap around
- if token > self.pos:
- token = -1
+ if from_token > self.pos:
+ from_token = -1
+ # list of tuple(int, BaseFederationRow), where the first is the position
+ # of the federation stream.
rows = []
# There should be only one reader, so lets delete everything its
@@ -237,62 +256,295 @@ class FederationRemoteSendQueue(object):
# Fetch changed presence
keys = self.presence_changed.keys()
- i = keys.bisect_right(token)
- dest_user_ids = set(
- (pos, dest_user_id)
- for pos in keys[i:]
- for dest_user_id in self.presence_changed[pos]
- )
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ dest_user_ids = [
+ (pos, user_id)
+ for pos in keys[i:j]
+ for user_id in self.presence_changed[pos]
+ ]
- for (key, (dest, user_id)) in dest_user_ids:
- rows.append((key, PRESENCE_TYPE, ujson.dumps({
- "destination": dest,
- "state": self.presence_map[user_id].as_dict(),
- })))
+ for (key, user_id) in dest_user_ids:
+ rows.append((key, PresenceRow(
+ state=self.presence_map[user_id],
+ )))
# Fetch changes keyed edus
keys = self.keyed_edu_changed.keys()
- i = keys.bisect_right(token)
- keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
-
- for (pos, (destination, edu_key)) in keyed_edus:
- rows.append(
- (pos, KEYED_EDU_TYPE, ujson.dumps({
- "key": edu_key,
- "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
- }))
- )
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+            # We purposefully clobber based on the key here: Python dict
+            # comprehensions always keep the last value, so this will correctly
+            # point to the last stream position.
+ keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+
+ for ((destination, edu_key), pos) in iteritems(keyed_edus):
+ rows.append((pos, KeyedEduRow(
+ key=edu_key,
+ edu=self.keyed_edu[(destination, edu_key)],
+ )))
# Fetch changed edus
keys = self.edus.keys()
- i = keys.bisect_right(token)
- edus = set((k, self.edus[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ edus = ((k, self.edus[k]) for k in keys[i:j])
for (pos, edu) in edus:
- rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+ rows.append((pos, EduRow(edu)))
# Fetch changed failures
keys = self.failures.keys()
- i = keys.bisect_right(token)
- failures = set((k, self.failures[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ failures = ((k, self.failures[k]) for k in keys[i:j])
for (pos, (destination, failure)) in failures:
- rows.append((pos, FAILURE_TYPE, ujson.dumps({
- "destination": destination,
- "failure": failure,
- })))
+ rows.append((pos, FailureRow(
+ destination=destination,
+ failure=failure,
+ )))
# Fetch changed device messages
keys = self.device_messages.keys()
- i = keys.bisect_right(token)
- device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ device_messages = {self.device_messages[k]: k for k in keys[i:j]}
- for (pos, destination) in device_messages:
- rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
- "destination": destination,
- })))
+ for (destination, pos) in iteritems(device_messages):
+ rows.append((pos, DeviceRow(
+ destination=destination,
+ )))
# Sort rows based on pos
rows.sort()
- return rows
+ return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+ """Base class for rows to be sent in the federation stream.
+
+ Specifies how to identify, serialize and deserialize the different types.
+ """
+
+    TypeId = None  # Unique string that ids the type. Must be overridden in subclasses.
+
+ @staticmethod
+ def from_data(data):
+ """Parse the data from the federation stream into a row.
+
+ Args:
+ data: The value of ``data`` from FederationStreamRow.data, type
+ depends on the type of stream
+ """
+ raise NotImplementedError()
+
+ def to_data(self):
+ """Serialize this row to be sent over the federation stream.
+
+ Returns:
+ The value to be sent in FederationStreamRow.data. The type depends
+ on the type of stream.
+ """
+ raise NotImplementedError()
+
+ def add_to_buffer(self, buff):
+ """Add this row to the appropriate field in the buffer ready for this
+ to be sent over federation.
+
+ We use a buffer so that we can batch up events that have come in at
+ the same time and send them all at once.
+
+ Args:
+            buff (ParsedFederationStreamData)
+ """
+ raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+ "state", # UserPresenceState
+))):
+ TypeId = "p"
+
+ @staticmethod
+ def from_data(data):
+ return PresenceRow(
+ state=UserPresenceState.from_dict(data)
+ )
+
+ def to_data(self):
+ return self.state.as_dict()
+
+ def add_to_buffer(self, buff):
+ buff.presence.append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+ "key", # tuple(str) - the edu key passed to send_edu
+ "edu", # Edu
+))):
+ """Streams EDUs that have an associated key that is ued to clobber. For example,
+ typing EDUs clobber based on room_id.
+ """
+
+ TypeId = "k"
+
+ @staticmethod
+ def from_data(data):
+ return KeyedEduRow(
+ key=tuple(data["key"]),
+ edu=Edu(**data["edu"]),
+ )
+
+ def to_data(self):
+ return {
+ "key": self.key,
+ "edu": self.edu.get_internal_dict(),
+ }
+
+ def add_to_buffer(self, buff):
+ buff.keyed_edus.setdefault(
+ self.edu.destination, {}
+ )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+ "edu", # Edu
+))):
+ """Streams EDUs that don't have keys. See KeyedEduRow
+ """
+ TypeId = "e"
+
+ @staticmethod
+ def from_data(data):
+ return EduRow(Edu(**data))
+
+ def to_data(self):
+ return self.edu.get_internal_dict()
+
+ def add_to_buffer(self, buff):
+ buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+ "destination", # str
+ "failure",
+))):
+ """Streams failures to a remote server. Failures are issued when there was
+ something wrong with a transaction the remote sent us, e.g. it included
+ an event that was invalid.
+ """
+
+ TypeId = "f"
+
+ @staticmethod
+ def from_data(data):
+ return FailureRow(
+ destination=data["destination"],
+ failure=data["failure"],
+ )
+
+ def to_data(self):
+ return {
+ "destination": self.destination,
+ "failure": self.failure,
+ }
+
+ def add_to_buffer(self, buff):
+ buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+ "destination", # str
+))):
+ """Streams the fact that either a) there is pending to device messages for
+ users on the remote, or b) a local users device has changed and needs to
+ be sent to the remote.
+ """
+ TypeId = "d"
+
+ @staticmethod
+ def from_data(data):
+ return DeviceRow(destination=data["destination"])
+
+ def to_data(self):
+ return {"destination": self.destination}
+
+ def add_to_buffer(self, buff):
+ buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+ Row.TypeId: Row
+ for Row in (
+ PresenceRow,
+ KeyedEduRow,
+ EduRow,
+ FailureRow,
+ DeviceRow,
+ )
+}
+
+
+ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
+ "presence", # list(UserPresenceState)
+ "keyed_edus", # dict of destination -> { key -> Edu }
+ "edus", # dict of destination -> [Edu]
+ "failures", # dict of destination -> [failures]
+ "device_destinations", # set of destinations
+))
+
+
+def process_rows_for_federation(transaction_queue, rows):
+ """Parse a list of rows from the federation stream and put them in the
+ transaction queue ready for sending to the relevant homeservers.
+
+ Args:
+ transaction_queue (TransactionQueue)
+ rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+ """
+
+ # The federation stream contains a bunch of different types of
+ # rows that need to be handled differently. We parse the rows, put
+ # them into the appropriate collection and then send them off.
+
+ buff = ParsedFederationStreamData(
+ presence=[],
+ keyed_edus={},
+ edus={},
+ failures={},
+ device_destinations=set(),
+ )
+
+ # Parse the rows in the stream and add to the buffer
+ for row in rows:
+ if row.type not in TypeToRow:
+ logger.error("Unrecognized federation row type %r", row.type)
+ continue
+
+ RowType = TypeToRow[row.type]
+ parsed_row = RowType.from_data(row.data)
+ parsed_row.add_to_buffer(buff)
+
+ if buff.presence:
+ transaction_queue.send_presence(buff.presence)
+
+ for destination, edu_map in iteritems(buff.keyed_edus):
+ for key, edu in edu_map.items():
+ transaction_queue.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=key,
+ )
+
+ for destination, edu_list in iteritems(buff.edus):
+ for edu in edu_list:
+ transaction_queue.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=None,
+ )
+
+ for destination, failure_list in iteritems(buff.failures):
+ for failure in failure_list:
+ transaction_queue.send_failure(destination, failure)
+
+ for destination in buff.device_destinations:
+ transaction_queue.send_device_messages(destination)
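
Each row class above follows the same round-trip contract: to_data() produces the payload carried as a (stream position, TypeId, data) tuple, and from_data() reconstructs the row on the receiving side before add_to_buffer() files it for sending. A small self-contained sketch of that round trip with a simplified row type (not Synapse's):

    from collections import namedtuple

    class DemoDeviceRow(namedtuple("DemoDeviceRow", ("destination",))):
        TypeId = "d"  # unique string identifying the row type on the wire

        @staticmethod
        def from_data(data):
            return DemoDeviceRow(destination=data["destination"])

        def to_data(self):
            return {"destination": self.destination}

    TYPE_TO_ROW = {DemoDeviceRow.TypeId: DemoDeviceRow}

    # Serialize for the stream, then parse it back on the other side.
    pos, type_id, data = (5, DemoDeviceRow.TypeId, DemoDeviceRow("remote.example").to_data())
    parsed = TYPE_TO_ROW[type_id].from_data(data)
    assert parsed.destination == "remote.example"
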
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index bb3d9258a6..ded2b1871a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,22 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import datetime
from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
+from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
- get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
-from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
import synapse.metrics
import logging
@@ -43,6 +40,10 @@ sent_pdus_destination_dist = client_metrics.register_distribution(
)
sent_edus_counter = client_metrics.register_counter("sent_edus")
+sent_transactions_counter = client_metrics.register_counter("sent_transactions")
+
+events_processed_counter = client_metrics.register_counter("events_processed")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -79,8 +80,18 @@ class TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = edus = {}
- # Presence needs to be separate as we send single aggragate EDUs
+ # Map of user_id -> UserPresenceState for all the pending presence
+ # to be sent out by user_id. Entries here get processed and put in
+ # pending_presence_by_dest
+ self.pending_presence = {}
+
+ # Map of destination -> user_id -> UserPresenceState of pending presence
+        # to be sent to each destination
self.pending_presence_by_dest = presence = {}
+
+ # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
+ # based on their key (e.g. typing events by room_id)
+ # Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
metrics.register_callback(
@@ -99,7 +110,12 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
+ # destination -> stream_id of last successfully sent to-device message.
+ # NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
+
+ # destination -> stream_id of last successfully sent device list
+ # update.
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
@@ -110,6 +126,8 @@ class TransactionQueue(object):
self._is_processing = False
self._last_poked_id = -1
+ self._processing_pending_presence = False
+
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -130,7 +148,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- @defer.inlineCallbacks
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
@@ -140,12 +157,19 @@ class TransactionQueue(object):
if self._is_processing:
return
+ # fire off a processing loop in the background. It's likely it will
+ # outlast the current request, so run it in the sentinel logcontext.
+ with PreserveLoggingContext():
+ self._process_event_queue_loop()
+
+ @defer.inlineCallbacks
+ def _process_event_queue_loop(self):
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=20,
+ last_token, self._last_poked_id, limit=100,
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -153,28 +177,35 @@ class TransactionQueue(object):
if not events and next_token >= self._last_poked_id:
break
- for event in events:
+ @defer.inlineCallbacks
+ def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
if not is_mine and send_on_behalf_of is None:
- continue
-
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- users_in_room = yield self.state.get_current_user_in_room(
- event.room_id, latest_event_ids=[
- prev_id for prev_id, _ in event.prev_events
- ],
- )
+ return
+
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = yield self.state.get_current_hosts_in_room(
+ event.room_id, latest_event_ids=[
+ prev_id for prev_id, _ in event.prev_events
+ ],
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
+
+ destinations = set(destinations)
- destinations = set(
- get_domain_from_id(user_id) for user_id in users_in_room
- )
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
@@ -185,10 +216,44 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ events_by_room = {}
+ for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.run_in_background(handle_room_events, evs)
+ for evs in events_by_room.itervalues()
+ ],
+ consumeErrors=True
+ ))
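A condensed sketch of the batching above using stock Twisted only (the Synapse logcontext helpers are omitted): events are grouped by room so that each room's events are handled strictly in order, while different rooms proceed concurrently.

from twisted.internet import defer

@defer.inlineCallbacks
def _handle_rooms_concurrently(events, handle_event):
    # handle_event(event) is assumed to return a Deferred.
    events_by_room = {}
    for event in events:
        events_by_room.setdefault(event.room_id, []).append(event)

    @defer.inlineCallbacks
    def handle_room_events(evs):
        # Within a room, events are handled strictly in order.
        for event in evs:
            yield handle_event(event)

    # Different rooms are processed concurrently; consumeErrors stops a
    # failure in one room from becoming an unhandled error in the gather.
    yield defer.gatherResults(
        [handle_room_events(evs) for evs in events_by_room.values()],
        consumeErrors=True,
    )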
+
yield self.store.update_federation_out_pos(
"events", next_token
)
+ if events:
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_lag.set(
+ now - ts, "federation_sender",
+ )
+ synapse.metrics.event_processing_last_ts.set(
+ ts, "federation_sender",
+ )
+
+ events_processed_counter.inc_by(len(events))
+
+ synapse.metrics.event_processing_positions.set(
+ next_token, "federation_sender",
+ )
+
finally:
self._is_processing = False
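The _is_processing flag gives a single-consumer drain loop: callers merely poke the queue, and at most one background loop drains it until it has caught up. A simplified sketch of that pattern, with the storage and federation specifics replaced by injected callables and the check folded into the same method:

from twisted.internet import defer

class DrainLoop(object):
    """Sketch: many producers poke; at most one consumer loop drains."""

    def __init__(self, fetch_batch, process_batch):
        # fetch_batch() -> Deferred[list]; process_batch(batch) -> Deferred.
        self._fetch_batch = fetch_batch
        self._process_batch = process_batch
        self._is_processing = False

    @defer.inlineCallbacks
    def poke(self):
        if self._is_processing:
            # The running loop will pick up whatever was just queued.
            return
        self._is_processing = True
        try:
            while True:
                batch = yield self._fetch_batch()
                if not batch:
                    break
                yield self._process_batch(batch)
        finally:
            self._is_processing = False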
@@ -217,21 +282,75 @@ class TransactionQueue(object):
(pdu, order)
)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
- def send_presence(self, destination, states):
- if not self.can_send_to(destination):
- return
+ @logcontext.preserve_fn # the caller should not yield on this
+ @defer.inlineCallbacks
+ def send_presence(self, states):
+ """Send the new presence states to the appropriate destinations.
+
+ This actually queues up the presence states ready for sending and
+ triggers a background task to process them and send out the transactions.
+
+ Args:
+ states (list(UserPresenceState))
+ """
- self.pending_presence_by_dest.setdefault(destination, {}).update({
+ # First we queue up the new presence by user ID, so multiple presence
+ # updates in quick succession are correctly handled.
+ # We only want to send presence for our own users, so let's always just
+ # filter here just in case.
+ self.pending_presence.update({
state.user_id: state for state in states
+ if self.is_mine_id(state.user_id)
})
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ # We then handle the new pending presence in batches, first figuring
+ # out the destinations we need to send each state to and then poking it
+ # to attempt a new transaction. We linearize this so that we don't
+ # accidentally mess up the ordering and send multiple presence updates
+ # in the wrong order
+ if self._processing_pending_presence:
+ return
+
+ self._processing_pending_presence = True
+ try:
+ while True:
+ states_map = self.pending_presence
+ self.pending_presence = {}
+
+ if not states_map:
+ break
+
+ yield self._process_presence_inner(states_map.values())
+ except Exception:
+ logger.exception("Error sending presence states to servers")
+ finally:
+ self._processing_pending_presence = False
+
+ @measure_func("txnqueue._process_presence")
+ @defer.inlineCallbacks
+ def _process_presence_inner(self, states):
+ """Given a list of states populate self.pending_presence_by_dest and
+ poke to send a new transaction to each destination
+
+ Args:
+ states (list(UserPresenceState))
+ """
+ hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
+
+ for destinations, states in hosts_and_states:
+ for destination in destinations:
+ if not self.can_send_to(destination):
+ continue
+
+ self.pending_presence_by_dest.setdefault(
+ destination, {}
+ ).update({
+ state.user_id: state for state in states
+ })
+
+ self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
@@ -253,9 +372,7 @@ class TransactionQueue(object):
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
@@ -268,9 +385,7 @@ class TransactionQueue(object):
destination, []
).append(failure)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
@@ -279,15 +394,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination):
return
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def get_current_token(self):
return 0
- @defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
+ """Try to start a new transaction to this destination
+
+ If there is already a transaction in progress to this destination,
+ returns immediately. Otherwise kicks off the process of sending a
+ transaction in the background.
+
+ Args:
+ destination (str):
+
+ Returns:
+ None
+ """
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
@@ -300,12 +424,46 @@ class TransactionQueue(object):
)
return
+ logger.debug("TX [%s] Starting transaction loop", destination)
+
+ # Drop the logcontext before starting the transaction. It doesn't
+ # really make sense to log all the outbound transactions against
+ # whatever path led us to this point: that's pretty arbitrary really.
+ #
+ # (this also means we can fire off _perform_transaction without
+ # yielding)
+ with logcontext.PreserveLoggingContext():
+ self._transaction_transmission_loop(destination)
+
+ @defer.inlineCallbacks
+ def _transaction_transmission_loop(self, destination):
+ pending_pdus = []
try:
self.pending_transactions[destination] = 1
+ # This will throw if we wouldn't retry. We do this here so we fail
+ # quickly, but we will later check this again in the http client,
+ # hence why we throw the result away.
+ yield get_retry_limiter(destination, self.clock, self.store)
+
+ # XXX: what's this for?
yield run_on_reactor()
+ pending_pdus = []
while True:
+ device_message_edus, device_stream_id, dev_list_id = (
+ yield self._get_new_device_messages(destination)
+ )
+
+ # BEGIN CRITICAL SECTION
+ #
+ # In order to avoid a race condition, we need to make sure that
+ # the following code (from popping the queues up to the point
+ # where we decide if we actually have any pending messages) is
+ # atomic - otherwise new PDUs or EDUs might arrive in the
+ # meantime, but not get sent because we hold the
+ # pending_transactions flag.
+
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
@@ -315,17 +473,6 @@ class TransactionQueue(object):
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
- limiter = yield get_retry_limiter(
- destination,
- self.clock,
- self.store,
- backoff_on_404=True, # If we get a 404 the other side has gone
- )
-
- device_message_edus, device_stream_id, dev_list_id = (
- yield self._get_new_device_messages(destination)
- )
-
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
@@ -355,11 +502,13 @@ class TransactionQueue(object):
)
return
+ # END CRITICAL SECTION
+
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
- limiter=limiter,
)
if success:
+ sent_transactions_counter.inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
@@ -375,12 +524,26 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
- except NotRetryingDestination:
+ except NotRetryingDestination as e:
logger.debug(
- "TX [%s] not ready for retry yet - "
+ "TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now",
destination,
+ datetime.datetime.fromtimestamp(
+ (e.retry_last_ts + e.retry_interval) / 1000.0
+ ),
)
+ except FederationDeniedError as e:
+ logger.info(e)
+ except Exception as e:
+ logger.warn(
+ "TX [%s] Failed to send transaction: %s",
+ destination,
+ e,
+ )
+ for p, _ in pending_pdus:
+ logger.info("Failed to send event %s to %s", p.event_id,
+ destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
@@ -420,7 +583,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures, limiter):
+ pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@@ -430,132 +593,104 @@ class TransactionQueue(object):
success = True
- try:
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- txn_id = str(self._next_txn_id)
+ txn_id = str(self._next_txn_id)
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pdus),
- len(edus),
- len(failures)
- )
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pdus),
+ len(edus),
+ len(failures)
+ )
- logger.debug("TX [%s] Persisting transaction...", destination)
+ logger.debug("TX [%s] Persisting transaction...", destination)
- transaction = Transaction.create_new(
- origin_server_ts=int(self.clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
- self._next_txn_id += 1
+ self._next_txn_id += 1
- yield self.transaction_actions.prepare_to_send(transaction)
+ yield self.transaction_actions.prepare_to_send(transaction)
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pdus),
- len(edus),
- len(failures),
- )
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pdus),
+ len(edus),
+ len(failures),
+ )
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self.clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
-
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
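A standalone sketch of what json_data_cb does to each PDU dict: the absolute age_ts timestamp is rewritten at send time into a relative age under unsigned, so the receiving server learns how old the event is:

import time

def rewrite_age(pdu_dict, now_ms=None):
    # Replace an absolute "age_ts" (ms since epoch) with a relative "age".
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    if "age_ts" in pdu_dict:
        unsigned = pdu_dict.setdefault("unsigned", {})
        unsigned["age"] = now_ms - int(pdu_dict["age_ts"])
        del pdu_dict["age_ts"]
    return pdu_dict

pdu = {"event_id": "$abc:example.org", "age_ts": 1000}
assert rewrite_age(pdu, now_ms=4000)["unsigned"]["age"] == 3000
assert "age_ts" not in pdu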
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
)
- raise e
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+ if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
+ raise e
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
-
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
- logger.debug("TX [%s] Marked as delivered", destination)
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- if code != 200:
- for p in pdus:
- logger.info(
- "Failed to send event %s to %s", p.event_id, destination
- )
- success = False
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- success = False
+ logger.debug("TX [%s] Marked as delivered", destination)
+ if code != 200:
for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
-
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
success = False
- for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
-
defer.returnValue(success)
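The HttpResponseException handling above re-raises only for response codes that should abort the whole transaction; a small predicate sketch (hypothetical helper name) mirroring that condition, with other error codes falling through to the normal delivery bookkeeping:

def _is_fatal_response(code):
    # Mirrors the condition above: 401, 404, 429 and all 5xx responses
    # abort the transaction; other error codes fall through to the normal
    # delivery bookkeeping (a non-200 still marks the PDUs as failed).
    return code in (401, 404, 429) or 500 <= code

assert _is_fatal_response(502)
assert _is_fatal_response(404)
assert not _is_fatal_response(403)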
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2cc4..6db8efa6dd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
import logging
+import urllib
logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state/%s/" % room_id
+ path = _create_path(PREFIX, "/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -71,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state_ids/%s/" % room_id
+ path = _create_path(PREFIX, "/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -93,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = PREFIX + "/event/%s/" % (event_id, )
+ path = _create_path(PREFIX, "/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -119,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = PREFIX + "/backfill/%s/" % (room_id,)
+ path = _create_path(PREFIX, "/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -157,12 +159,15 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
+ path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
response = yield self.client.put_json(
transaction.destination,
- path=PREFIX + "/send/%s/" % transaction.transaction_id,
+ path=path,
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
+ backoff_on_404=True, # If we get a 404 the other side has gone
)
logger.debug(
@@ -174,8 +179,9 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def make_query(self, destination, query_type, args, retry_on_dns_fail):
- path = PREFIX + "/query/%s" % query_type
+ def make_query(self, destination, query_type, args, retry_on_dns_fail,
+ ignore_backoff=False):
+ path = _create_path(PREFIX, "/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -183,6 +189,7 @@ class TransportLayerClient(object):
args=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
+ ignore_backoff=ignore_backoff,
)
defer.returnValue(content)
@@ -190,19 +197,54 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def make_membership_event(self, destination, room_id, user_id, membership):
+ """Asks a remote server to build and sign us a membership event
+
+ Note that this does not append any events to any graphs.
+
+ Args:
+ destination (str): address of remote homeserver
+ room_id (str): room to join/leave
+ user_id (str): user to be joined/left
+ membership (str): one of join/leave
+
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body (ie, the new event).
+
+ Fails with ``HTTPRequestException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if the remote destination
+ is not in our federation whitelist
+ """
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+ path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
+
+ ignore_backoff = False
+ retry_on_dns_fail = False
+
+ if membership == Membership.LEAVE:
+ # we particularly want to do our best to send leave events. The
+ # problem is that if it fails, we won't retry it later, so if the
+ # remote server was just having a momentary blip, the room will be
+ # out of sync.
+ ignore_backoff = True
+ retry_on_dns_fail = True
content = yield self.client.get_json(
destination=destination,
path=path,
- retry_on_dns_fail=False,
+ retry_on_dns_fail=retry_on_dns_fail,
timeout=20000,
+ ignore_backoff=ignore_backoff,
)
defer.returnValue(content)
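A hedged usage sketch of the call above from another inlineCallbacks method; the server, room and user IDs are purely illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def leave_room_via(transport_layer_client):
    # Hypothetical caller; membership must be one of "join" / "leave".
    # For "leave" the method above retries on DNS failure and ignores
    # backoff, since a dropped leave can leave the room out of sync.
    event = yield transport_layer_client.make_membership_event(
        destination="remote.example.org",
        room_id="!abcdefgh:remote.example.org",
        user_id="@alice:our.example.org",
        membership="leave",
    )
    defer.returnValue(event)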
@@ -210,7 +252,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -223,12 +265,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
+
+ # we want to do our best to send this through. The problem is
+ # that if it fails, we won't retry it later, so if the remote
+ # server was just having a momentary blip, the room will be out of
+ # sync.
+ ignore_backoff=True,
)
defer.returnValue(response)
@@ -236,12 +284,13 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
- path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
+ ignore_backoff=True,
)
defer.returnValue(response)
@@ -269,6 +318,7 @@ class TransportLayerClient(object):
destination=remote_server,
path=path,
args=args,
+ ignore_backoff=True,
)
defer.returnValue(response)
@@ -276,7 +326,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+ path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -289,7 +339,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -301,7 +351,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
- path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -363,7 +413,7 @@ class TransportLayerClient(object):
Returns:
A dict containing the device keys.
"""
- path = PREFIX + "/user/devices/" + user_id
+ path = _create_path(PREFIX, "/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -413,7 +463,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
- path = PREFIX + "/get_missing_events/%s" % (room_id,)
+ path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -428,3 +478,475 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ @log_function
+ def get_group_profile(self, destination, group_id, requester_user_id):
+ """Get a group profile
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_profile(self, destination, group_id, requester_user_id, content):
+ """Update a remote group profile
+
+ Args:
+ destination (str)
+ group_id (str)
+ requester_user_id (str)
+ content (dict): The new profile of the group
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_summary(self, destination, group_id, requester_user_id):
+ """Get a group summary
+ """
+ path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ """Get all rooms in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+ content):
+ """Add a room to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
+ config_key, content):
+ """Update room in group
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/room/%s/config/%s",
+ group_id, room_id, config_key,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ """Remove a room from a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users that have been invited to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def accept_group_invite(self, destination, group_id, user_id, content):
+ """Accept a group invite
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/users/%s/accept_invite",
+ group_id, user_id,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def join_group(self, destination, group_id, user_id, content):
+ """Attempts to join a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ """Invite a user to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group_notification(self, destination, group_id, user_id, content):
+ """Sent by group server to inform a user's server that they have been
+ invited.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group(self, destination, group_id, requester_user_id,
+ user_id, content):
+ """Remove a user from a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group_notification(self, destination, group_id, user_id,
+ content):
+ """Sent by group server to inform a user's server that they have been
+ kicked from the group.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def renew_group_attestation(self, destination, group_id, user_id, content):
+ """Sent by either a group server or a user's server to periodically update
+ the attestations
+ """
+
+ path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id, content):
+ """Update a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id):
+ """Delete a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_categories(self, destination, group_id, requester_user_id):
+ """Get all categories in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ """Get category info in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_category(self, destination, group_id, requester_user_id, category_id,
+ content):
+ """Update a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_category(self, destination, group_id, requester_user_id,
+ category_id):
+ """Delete a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_roles(self, destination, group_id, requester_user_id):
+ """Get all roles in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Get a role's info
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_role(self, destination, group_id, requester_user_id, role_id,
+ content):
+ """Update a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Delete a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id, content):
+ """Update a user's entry in a group summary
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def set_group_join_policy(self, destination, group_id, requester_user_id,
+ content):
+ """Sets the join policy for a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+ return self.client.put_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id):
+ """Delete a user's entry in a group summary
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def bulk_get_publicised_groups(self, destination, user_ids):
+ """Get the groups a list of users are publicising
+ """
+
+ path = PREFIX + "/get_groups_publicised"
+
+ content = {"user_ids": user_ids}
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+
+def _create_path(prefix, path, *args):
+ """Creates a path from the prefix, path template and args. Ensures that
+ all args are url encoded.
+
+ Example:
+
+ _create_path(PREFIX, "/event/%s/", event_id)
+
+ Args:
+ prefix (str)
+ path (str): String template for the path
+ args: ([str]): Args to insert into path. Each arg will be url encoded
+
+ Returns:
+ str
+ """
+ return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
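For illustration, assuming Python 2's urllib.quote as used above and the usual /_matrix/federation/v1 prefix, reserved characters in IDs are percent-encoded before substitution:

import urllib

def _create_path(prefix, path, *args):
    # Copy of the helper above so this example is self-contained.
    return prefix + path % tuple(urllib.quote(arg, "") for arg in args)

# "!" and ":" in the room ID are percent-encoded before substitution.
assert _create_path(
    "/_matrix/federation/v1", "/state/%s/", "!roomid:example.org"
) == "/_matrix/federation/v1/state/%21roomid%3Aexample.org/"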
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index c840da834c..19d09f5422 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +17,7 @@
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -24,7 +25,8 @@ from synapse.http.servlet import (
)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.types import ThirdPartyInstanceID
+from synapse.util.logcontext import run_in_background
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
import functools
import logging
@@ -79,6 +81,8 @@ class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
+ self.store = hs.get_datastore()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -110,7 +114,7 @@ class Authenticator(object):
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
- except:
+ except Exception:
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
@@ -128,6 +132,12 @@ class Authenticator(object):
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
+ if (
+ self.federation_domain_whitelist is not None and
+ origin not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(origin)
+
if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -138,18 +148,30 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
+ # If we get a valid signed request from the other side, it's probably
+ # alive
+ retry_timings = yield self.store.get_destination_retry_timings(origin)
+ if retry_timings and retry_timings["retry_last_ts"]:
+ run_in_background(self._reset_retry_timings, origin)
+
defer.returnValue(origin)
+ @defer.inlineCallbacks
+ def _reset_retry_timings(self, origin):
+ try:
+ logger.info("Marking origin %r as up", origin)
+ yield self.store.set_destination_retry_timings(origin, 0, 0)
+ except Exception:
+ logger.exception("Error resetting retry timings on %s", origin)
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
- def __init__(self, handler, authenticator, ratelimiter, server_name,
- room_list_handler):
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
- self.room_list_handler = room_list_handler
def _wrap(self, func):
authenticator = self.authenticator
@@ -170,7 +192,7 @@ class BaseFederationServlet(object):
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
- except:
+ except Exception:
logger.exception("authenticate_request failed")
raise
@@ -263,7 +285,7 @@ class FederationSendServlet(BaseFederationServlet):
code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
- except:
+ except Exception:
logger.exception("on_incoming_transaction failed")
raise
@@ -581,7 +603,7 @@ class PublicRoomList(BaseFederationServlet):
else:
network_tuple = ThirdPartyInstanceID(None, None)
- data = yield self.room_list_handler.get_local_public_room_list(
+ data = yield self.handler.get_local_public_room_list(
limit, since_token,
network_tuple=network_tuple
)
@@ -602,7 +624,550 @@ class FederationVersionServlet(BaseFederationServlet):
}))
-SERVLET_CLASSES = (
+class FederationGroupsProfileServlet(BaseFederationServlet):
+ """Get/set the basic profile of a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_profile(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.update_group_profile(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+ PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_summary(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+ """Get the rooms in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+ """Add/remove room from group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+ """Update room config in group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ result = yield self.handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+ """Get the users in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+ """Get the users that have been invited to a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_invited_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+ """Ask a group server to invite someone to the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+ """Accept an invitation from the group server
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.accept_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsJoinServlet(BaseFederationServlet):
+ """Attempt to join a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.join_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+ """Leave or kick a user from the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+ """A group server has invited a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "group_id doesn't match origin")
+
+ new_content = yield self.handler.on_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+ """A group server has removed a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "group_id doesn't match origin")
+
+ new_content = yield self.handler.user_removed_from_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+ """A group or user's server renews their attestation
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ # We don't need to check auth here as we check the attestation signatures
+
+ new_content = yield self.handler.on_renew_attestation(
+ group_id, user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+ """Add/remove a room from the group summary, with optional category.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
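A small sketch, using only the standard re module and illustrative IDs, of how the optional category group in PATH behaves: category_id is captured when the /categories/... segment is present and comes through as None otherwise.

import re

PATH = (
    "/groups/(?P<group_id>[^/]*)/summary"
    "(/categories/(?P<category_id>[^/]+))?"
    "/rooms/(?P<room_id>[^/]*)$"
)

with_cat = re.match(PATH, "/groups/+g:hs/summary/categories/cat1/rooms/!r:hs")
without_cat = re.match(PATH, "/groups/+g:hs/summary/rooms/!r:hs")

assert with_cat.group("category_id") == "cat1"
assert without_cat.group("category_id") is None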
+
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+ """Get all categories for a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+ """Add/remove/get a category in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_category(
+ group_id, requester_user_id, category_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.upsert_group_category(
+ group_id, requester_user_id, category_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_category(
+ group_id, requester_user_id, category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+ """Add/remove/get a role in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_role(
+ group_id, requester_user_id, role_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_role(
+ group_id, requester_user_id, role_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_role(
+ group_id, requester_user_id, role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+ """Add/remove a user from the group summary, with optional role.
+
+ Matches both:
+ - /groups/:group/summary/users/:user_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+ """Get the groups a list of users are publicising
+ """
+ PATH = (
+ "/get_groups_publicised$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ resp = yield self.handler.bulk_get_publicised_groups(
+ content["user_ids"], proxy=False,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+ """Sets whether a group is joinable without an invite or knock
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.set_group_join_policy(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
@@ -625,17 +1190,85 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
OpenIdUserInfo,
- PublicRoomList,
FederationVersionServlet,
)
+ROOM_LIST_CLASSES = (
+ PublicRoomList,
+)
+
+GROUP_SERVER_SERVLET_CLASSES = (
+ FederationGroupsProfileServlet,
+ FederationGroupsSummaryServlet,
+ FederationGroupsRoomsServlet,
+ FederationGroupsUsersServlet,
+ FederationGroupsInvitedUsersServlet,
+ FederationGroupsInviteServlet,
+ FederationGroupsAcceptInviteServlet,
+ FederationGroupsJoinServlet,
+ FederationGroupsRemoveUserServlet,
+ FederationGroupsSummaryRoomsServlet,
+ FederationGroupsCategoriesServlet,
+ FederationGroupsCategoryServlet,
+ FederationGroupsRolesServlet,
+ FederationGroupsRoleServlet,
+ FederationGroupsSummaryUsersServlet,
+ FederationGroupsAddRoomsServlet,
+ FederationGroupsAddRoomsConfigServlet,
+ FederationGroupsSettingJoinPolicyServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+ FederationGroupsLocalInviteServlet,
+ FederationGroupsRemoveLocalUserServlet,
+ FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+ FederationGroupsRenewAttestaionServlet,
+)
+
+
def register_servlets(hs, resource, authenticator, ratelimiter):
- for servletclass in SERVLET_CLASSES:
+ for servletclass in FEDERATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_federation_server(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in ROOM_LIST_CLASSES:
+ servletclass(
+ handler=hs.get_room_list_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_replication_layer(),
+ handler=hs.get_groups_attestation_renewer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
- room_list_handler=hs.get_room_list_handler(),
).register(resource)
diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/groups/__init__.py
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
new file mode 100644
index 0000000000..6f11fa374b
--- /dev/null
+++ b/synapse/groups/attestations.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Attestations ensure that users and groups can't lie about their memberships.
+
+When a user joins a group the HS and GS swap attestations, which allow them
+both to independently prove their membership to third parties. These
+attestations have a validity period, so they need to be periodically renewed.
+
+If a user leaves (or gets kicked out of) a group, either side can still use
+their attestation to "prove" their membership, until the attestation expires.
+Therefore attestations shouldn't be relied on to prove membership in important
+cases, but can be used for less important situations, e.g. showing a user's
+membership of groups on their profile, showing flairs, etc.
+
+An attestation is a signed blob of JSON that looks like:
+
+ {
+ "user_id": "@foo:a.example.com",
+ "group_id": "+bar:b.example.com",
+ "valid_until_ms": 1507994728530,
+ "signatures":{"matrix.org":{"ed25519:auto":"..."}}
+ }
+"""
+
+import logging
+import random
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import run_in_background
+
+from signedjson.sign import sign_json
+
+
+logger = logging.getLogger(__name__)
+
+
+# Default validity duration for new attestations we create
+DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
+
+# We add some jitter to the validity duration of attestations so that if we
+# add lots of users at once we don't need to renew them all at once.
+# The jitter is a multiplier picked randomly between the first and second number
+DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
+
+# Start trying to update our attestations when they come this close to expiring
+UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
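As a rough, self-contained sketch (not how the homeserver manages its keys; the key below is freshly generated purely for demonstration), an attestation of the shape documented above can be produced and checked with the signedjson library like so:

# Illustrative sketch only: signs and verifies an attestation-shaped blob
# with signedjson. A real homeserver signs with its configured key and
# verifies remote signatures via the Keyring.
import time

from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import SignatureVerifyException, sign_json, verify_signed_json

signing_key = generate_signing_key("auto")   # throwaway key, version "auto"
verify_key = get_verify_key(signing_key)

attestation = sign_json(
    {
        "user_id": "@foo:a.example.com",
        "group_id": "+bar:b.example.com",
        # three days from now, i.e. DEFAULT_ATTESTATION_LENGTH_MS before jitter
        "valid_until_ms": int(time.time() * 1000) + 3 * 24 * 60 * 60 * 1000,
    },
    "a.example.com",
    signing_key,
)

try:
    verify_signed_json(attestation, "a.example.com", verify_key)
    print("attestation signature OK")
except SignatureVerifyException:
    print("attestation signature invalid")
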
+
+
+class GroupAttestationSigning(object):
+ """Creates and verifies group attestations.
+ """
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.clock = hs.get_clock()
+ self.server_name = hs.hostname
+ self.signing_key = hs.config.signing_key[0]
+
+ @defer.inlineCallbacks
+ def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+ """Verifies that the given attestation matches the given parameters.
+
+ An optional server_name can be supplied to explicitly set which server's
+ signature is expected. Otherwise assumes that either the group_id or user_id
+ is local and uses the other's server as the one to check.
+ """
+
+ if not server_name:
+ if get_domain_from_id(group_id) == self.server_name:
+ server_name = get_domain_from_id(user_id)
+ elif get_domain_from_id(user_id) == self.server_name:
+ server_name = get_domain_from_id(group_id)
+ else:
+ raise Exception("Expected either group_id or user_id to be local")
+
+ if user_id != attestation["user_id"]:
+ raise SynapseError(400, "Attestation has incorrect user_id")
+
+ if group_id != attestation["group_id"]:
+ raise SynapseError(400, "Attestation has incorrect group_id")
+ valid_until_ms = attestation["valid_until_ms"]
+
+ # TODO: We also want to check that *new* attestations that people give
+ # us to store are valid for at least a little while.
+ if valid_until_ms < self.clock.time_msec():
+ raise SynapseError(400, "Attestation expired")
+
+ yield self.keyring.verify_json_for_server(server_name, attestation)
+
+ def create_attestation(self, group_id, user_id):
+ """Create an attestation for the group_id and user_id with default
+ validity length.
+ """
+ validity_period = DEFAULT_ATTESTATION_LENGTH_MS
+ validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+ valid_until_ms = int(self.clock.time_msec() + validity_period)
+
+ return sign_json({
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": valid_until_ms,
+ }, self.server_name, self.signing_key)
+
+
+class GroupAttestionRenewer(object):
+ """Responsible for sending and receiving attestation updates.
+ """
+
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.attestations = hs.get_groups_attestation_signing()
+ self.transport_client = hs.get_federation_transport_client()
+ self.is_mine_id = hs.is_mine_id
+
+ self._renew_attestations_loop = self.clock.looping_call(
+ self._renew_attestations, 30 * 60 * 1000,
+ )
+
+ @defer.inlineCallbacks
+ def on_renew_attestation(self, group_id, user_id, content):
+ """When a remote updates an attestation
+ """
+ attestation = content["attestation"]
+
+ if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
+ raise SynapseError(400, "Neither user not group are on this server")
+
+ yield self.attestations.verify_attestation(
+ attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+
+ yield self.store.update_remote_attestion(group_id, user_id, attestation)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def _renew_attestations(self):
+ """Called periodically to check if we need to update any of our attestations
+ """
+
+ now = self.clock.time_msec()
+
+ rows = yield self.store.get_attestations_need_renewals(
+ now + UPDATE_ATTESTATION_TIME_MS
+ )
+
+ @defer.inlineCallbacks
+ def _renew_attestation(group_id, user_id):
+ try:
+ if not self.is_mine_id(group_id):
+ destination = get_domain_from_id(group_id)
+ elif not self.is_mine_id(user_id):
+ destination = get_domain_from_id(user_id)
+ else:
+ logger.warn(
+ "Incorrectly trying to do attestations for user: %r in %r",
+ user_id, group_id,
+ )
+ yield self.store.remove_attestation_renewal(group_id, user_id)
+ return
+
+ attestation = self.attestations.create_attestation(group_id, user_id)
+
+ yield self.transport_client.renew_group_attestation(
+ destination, group_id, user_id,
+ content={"attestation": attestation},
+ )
+
+ yield self.store.update_attestation_renewal(
+ group_id, user_id, attestation
+ )
+ except Exception:
+ logger.exception("Error renewing attestation of %r in %r",
+ user_id, group_id)
+
+ for row in rows:
+ group_id = row["group_id"]
+ user_id = row["user_id"]
+
+ run_in_background(_renew_attestation, group_id, user_id)
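The scheduling pattern above (a periodic check that fans each renewal out to its own background deferred, so one slow destination doesn't block the rest of the batch) can be sketched in isolation roughly as follows; the names and data here are stand-ins, not Synapse APIs:

# Rough standalone sketch of the renewal scheduling pattern; illustrative only.
import logging

from twisted.internet import defer, task

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("renewal-sketch")


def renew_one(group_id, user_id):
    # stand-in for _renew_attestation: re-sign and notify the remote server
    log.info("renewing attestation for %s in %s", user_id, group_id)
    return defer.succeed(None)


def renew_all():
    # stand-in for store.get_attestations_need_renewals(now + window)
    rows = [{"group_id": "+bar:b.example.com", "user_id": "@foo:a.example.com"}]
    for row in rows:
        d = renew_one(row["group_id"], row["user_id"])
        d.addErrback(log.error)  # log per-row failures, as the code above does


loop = task.LoopingCall(renew_all)
loop.start(30 * 60, now=True)  # every 30 minutes, matching the looping_call above
# (a running Twisted reactor is needed for iterations after the first)
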
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
new file mode 100644
index 0000000000..2d95b04e0c
--- /dev/null
+++ b/synapse/groups/groups_server.py
@@ -0,0 +1,950 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Allow users to "knock" or simply join depending on rules
+# TODO: Federation admin APIs
+# TODO: is_privileged flag to users and is_public to users and rooms
+# TODO: Audit log for admins (profile updates, membership changes, users who tried
+# to join but were rejected, etc)
+# TODO: Flairs
+
+
+class GroupsServerHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.attestations = hs.get_groups_attestation_signing()
+ self.transport_client = hs.get_federation_transport_client()
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ @defer.inlineCallbacks
+ def check_group_is_ours(self, group_id, requester_user_id,
+ and_exists=False, and_is_admin=None):
+ """Check that the group is ours, and optionally if it exists.
+
+ If the group does exist, then return it.
+
+ Args:
+ group_id (str)
+ requester_user_id (str): the user making the request
+ and_exists (bool): whether to also check if the group exists
+ and_is_admin (str|None): if set, a user_id which must be an admin
+ of the group
+ """
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Group not on this server")
+
+ group = yield self.store.get_group(group_id)
+ if and_exists and not group:
+ raise SynapseError(404, "Unknown group")
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ if group and not is_user_in_group and not group["is_public"]:
+ raise SynapseError(404, "Unknown group")
+
+ if and_is_admin:
+ is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ defer.returnValue(group)
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the summary for a group as seen by requester_user_id.
+
+ The group summary consists of the profile of the group, and a curated
+ list of users and rooms. These lists *may* be organised by role/category.
+ The roles/categories are ordered, and so are the users/rooms within them.
+
+ A user/room may appear in multiple roles/categories.
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ profile = yield self.get_group_profile(group_id, requester_user_id)
+
+ users, roles = yield self.store.get_users_for_summary_by_role(
+ group_id, include_private=is_user_in_group,
+ )
+
+ # TODO: Add profiles to users
+
+ rooms, categories = yield self.store.get_rooms_for_summary_by_category(
+ group_id, include_private=is_user_in_group,
+ )
+
+ for room_entry in rooms:
+ room_id = room_entry["room_id"]
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+ entry = dict(entry) # copy, so we don't change what's cached
+ entry.pop("room_id", None)
+
+ room_entry["profile"] = entry
+
+ rooms.sort(key=lambda e: e.get("order", 0))
+
+ for entry in users:
+ user_id = entry["user_id"]
+
+ if not self.is_mine_id(requester_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, user_id,
+ )
+
+ user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ entry.update(user_profile)
+
+ users.sort(key=lambda e: e.get("order", 0))
+
+ membership_info = yield self.store.get_users_membership_info_in_group(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue({
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ })
+
+ @defer.inlineCallbacks
+ def update_group_summary_room(self, group_id, requester_user_id,
+ room_id, category_id, content):
+ """Add/update a room to the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_room(self, group_id, requester_user_id,
+ room_id, category_id):
+ """Remove a room from the summary
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_room_from_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def set_group_join_policy(self, group_id, requester_user_id, content):
+ """Sets the group join policy.
+
+ Currently supported policies are:
+ - "invite": an invite must be received and accepted in order to join.
+ - "open": anyone can join.
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ join_policy = _parse_join_policy_from_contents(content)
+ if join_policy is None:
+ raise SynapseError(
+ 400, "No value specified for 'm.join_policy'"
+ )
+
+ yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id, requester_user_id):
+ """Get all categories in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ categories = yield self.store.get_group_categories(
+ group_id=group_id,
+ )
+ defer.returnValue({"categories": categories})
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, requester_user_id, category_id):
+ """Get a specific category in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_category(self, group_id, requester_user_id, category_id, content):
+ """Add/Update a group category
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_category(self, group_id, requester_user_id, category_id):
+ """Delete a group category
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id
+ )
+
+ yield self.store.remove_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id, requester_user_id):
+ """Get all roles in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ roles = yield self.store.get_group_roles(
+ group_id=group_id,
+ )
+ defer.returnValue({"roles": roles})
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, requester_user_id, role_id):
+ """Get a specific role in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_role(self, group_id, requester_user_id, role_id, content):
+ """Add/update a role in a group
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_role(self, group_id, requester_user_id, role_id):
+ """Remove role from group
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
+ content):
+ """Add/update a users entry in the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
+ """Remove a user from the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_user_from_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_profile(self, group_id, requester_user_id):
+ """Get the group profile as seen by requester_user_id
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id)
+
+ group = yield self.store.get_group(group_id)
+
+ if group:
+ cols = [
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public",
+ ]
+ group_description = {key: group[key] for key in cols}
+ group_description["is_openly_joinable"] = group["join_policy"] == "open"
+
+ defer.returnValue(group_description)
+ else:
+ raise SynapseError(404, "Unknown group")
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, requester_user_id, content):
+ """Update the group profile
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ profile = {}
+ for keyname in ("name", "avatar_url", "short_description",
+ "long_description"):
+ if keyname in content:
+ value = content[keyname]
+ if not isinstance(value, basestring):
+ raise SynapseError(400, "%r value is not a string" % (keyname,))
+ profile[keyname] = value
+
+ yield self.store.update_group_profile(group_id, profile)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get the users in group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for user_result in user_results:
+ g_user_id = user_result["user_id"]
+ is_public = user_result["is_public"]
+ is_privileged = user_result["is_admin"]
+
+ entry = {"user_id": g_user_id}
+
+ profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+ entry.update(profile)
+
+ entry["is_public"] = bool(is_public)
+ entry["is_privileged"] = bool(is_privileged)
+
+ if not self.is_mine_id(g_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, g_user_id,
+ )
+
+ chunk.append(entry)
+
+ # TODO: If admin add lists of users whose attestations have timed out
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_user_count_estimate": len(user_results),
+ })
+
+ @defer.inlineCallbacks
+ def get_invited_users_in_group(self, group_id, requester_user_id):
+ """Get the users that have been invited to a group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ if not is_user_in_group:
+ raise SynapseError(403, "User not in group")
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+
+ user_profiles = []
+
+ for user_id in invited_users:
+ user_profile = {
+ "user_id": user_id
+ }
+ try:
+ profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ user_profile.update(profile)
+ except Exception as e:
+ logger.warn("Error getting profile for %s: %s", user_id, e)
+ user_profiles.append(user_profile)
+
+ defer.returnValue({
+ "chunk": user_profiles,
+ "total_user_count_estimate": len(invited_users),
+ })
+
+ @defer.inlineCallbacks
+ def get_rooms_in_group(self, group_id, requester_user_id):
+ """Get the rooms in group as seen by requester_user_id
+
+ This returns rooms in order of decreasing number of joined users
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ room_results = yield self.store.get_rooms_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for room_result in room_results:
+ room_id = room_result["room_id"]
+
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+
+ if not entry:
+ continue
+
+ entry["is_public"] = bool(room_result["is_public"])
+
+ chunk.append(entry)
+
+ chunk.sort(key=lambda e: -e["num_joined_members"])
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_room_count_estimate": len(room_results),
+ })
+
+ @defer.inlineCallbacks
+ def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+ """Add room to group
+ """
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
+ content):
+ """Update room in group
+ """
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ if config_key == "m.visibility":
+ is_public = _parse_visibility_dict(content)
+
+ yield self.store.update_room_in_group_visibility(
+ group_id, room_id,
+ is_public=is_public,
+ )
+ else:
+ raise SynapseError(400, "Uknown config option")
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def remove_room_from_group(self, group_id, requester_user_id, room_id):
+ """Remove room from group
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ yield self.store.remove_room_from_group(group_id, room_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite_to_group(self, group_id, user_id, requester_user_id, content):
+ """Invite user to group
+ """
+
+ group = yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ # TODO: Check if user knocked
+ # TODO: Check if user is already invited
+
+ content = {
+ "profile": {
+ "name": group["name"],
+ "avatar_url": group["avatar_url"],
+ },
+ "inviter": requester_user_id,
+ }
+
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ res = yield groups_local.on_invite(group_id, user_id, content)
+ local_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content.update({
+ "attestation": local_attestation,
+ })
+
+ res = yield self.transport_client.invite_to_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, content
+ )
+
+ user_profile = res.get("user_profile", {})
+ yield self.store.add_remote_profile_cache(
+ user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ if res["state"] == "join":
+ if not self.hs.is_mine_id(user_id):
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=False, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+ elif res["state"] == "invite":
+ yield self.store.add_group_invite(
+ group_id, user_id,
+ )
+ defer.returnValue({
+ "state": "invite"
+ })
+ elif res["state"] == "reject":
+ defer.returnValue({
+ "state": "reject"
+ })
+ else:
+ raise SynapseError(502, "Unknown state returned by HS")
+
+ @defer.inlineCallbacks
+ def _add_user(self, group_id, user_id, content):
+ """Add a user to a group based on a content dict.
+
+ See accept_invite, join_group.
+ """
+ if not self.hs.is_mine_id(user_id):
+ local_attestation = self.attestations.create_attestation(
+ group_id, user_id,
+ )
+
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ local_attestation = None
+ remote_attestation = None
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=is_public,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ defer.returnValue(local_attestation)
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, requester_user_id, content):
+ """User tries to accept an invite to the group.
+
+ This is different from them asking to join, and so should error if no
+ invite exists (and they're not a member of the group)
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_invited = yield self.store.is_user_invited_to_local_group(
+ group_id, requester_user_id,
+ )
+ if not is_invited:
+ raise SynapseError(403, "User not invited to group")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, requester_user_id, content):
+ """User tries to join the group.
+
+ This will error if the group requires an invite/knock to join
+ """
+
+ group_info = yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True
+ )
+ if group_info['join_policy'] != "open":
+ raise SynapseError(403, "Group is not publicly joinable")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def knock(self, group_id, requester_user_id, content):
+ """A user requests becoming a member of the group
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def accept_knock(self, group_id, requester_user_id, content):
+ """Accept a users knock to the room.
+
+ Errors if the user hasn't knocked, rather than inviting them.
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from the group; either a user is leaving or an admin
+ kicked them.
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_kick = False
+ if requester_user_id != user_id:
+ is_admin = yield self.store.is_user_admin_in_group(
+ group_id, requester_user_id
+ )
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ is_kick = True
+
+ yield self.store.remove_user_from_group(
+ group_id, user_id,
+ )
+
+ if is_kick:
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ yield groups_local.user_removed_from_group(group_id, user_id, {})
+ else:
+ yield self.transport_client.remove_user_from_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, {}
+ )
+
+ if not self.hs.is_mine_id(user_id):
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, requester_user_id, content):
+ group = yield self.check_group_is_ours(group_id, requester_user_id)
+
+ logger.info("Attempting to create group with ID: %r", group_id)
+
+ # parsing the id into a GroupID validates it.
+ group_id_obj = GroupID.from_string(group_id)
+
+ if group:
+ raise SynapseError(400, "Group already exists")
+
+ is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
+ if not is_admin:
+ if not self.hs.config.enable_group_creation:
+ raise SynapseError(
+ 403, "Only a server admin can create groups on this server",
+ )
+ localpart = group_id_obj.localpart
+ if not localpart.startswith(self.hs.config.group_creation_prefix):
+ raise SynapseError(
+ 400,
+ "Can only create groups with prefix %r on this server" % (
+ self.hs.config.group_creation_prefix,
+ ),
+ )
+
+ profile = content.get("profile", {})
+ name = profile.get("name")
+ avatar_url = profile.get("avatar_url")
+ short_description = profile.get("short_description")
+ long_description = profile.get("long_description")
+ user_profile = content.get("user_profile", {})
+
+ yield self.store.create_group(
+ group_id,
+ requester_user_id,
+ name=name,
+ avatar_url=avatar_url,
+ short_description=short_description,
+ long_description=long_description,
+ )
+
+ if not self.hs.is_mine_id(requester_user_id):
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=requester_user_id,
+ group_id=group_id,
+ )
+
+ local_attestation = self.attestations.create_attestation(
+ group_id,
+ requester_user_id,
+ )
+ else:
+ local_attestation = None
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, requester_user_id,
+ is_admin=True,
+ is_public=True, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ if not self.hs.is_mine_id(requester_user_id):
+ yield self.store.add_remote_profile_cache(
+ requester_user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ defer.returnValue({
+ "group_id": group_id,
+ })
+
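For reference, the two options consulted via self.hs.config above correspond to homeserver.yaml settings along these lines (values illustrative):

# As they would appear in homeserver.yaml (values here are illustrative):
#
#   enable_group_creation: true
#   group_creation_prefix: "unofficial/"
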
+
+def _parse_join_policy_from_contents(content):
+ """Given a content for a request, return the specified join policy or None
+ """
+
+ join_policy_dict = content.get("m.join_policy")
+ if join_policy_dict:
+ return _parse_join_policy_dict(join_policy_dict)
+ else:
+ return None
+
+
+def _parse_join_policy_dict(join_policy_dict):
+ """Given a dict for the "m.join_policy" config return the join policy specified
+ """
+ join_policy_type = join_policy_dict.get("type")
+ if not join_policy_type:
+ return "invite"
+
+ if join_policy_type not in ("invite", "open"):
+ raise SynapseError(
+ 400, "Synapse only supports 'invite'/'open' join rule"
+ )
+ return join_policy_type
+
+
+def _parse_visibility_from_contents(content):
+ """Given a content for a request parse out whether the entity should be
+ public or not
+ """
+
+ visibility = content.get("m.visibility")
+ if visibility:
+ return _parse_visibility_dict(visibility)
+ else:
+ is_public = True
+
+ return is_public
+
+
+def _parse_visibility_dict(visibility):
+ """Given a dict for the "m.visibility" config return if the entity should
+ be public or not
+ """
+ vis_type = visibility.get("type")
+ if not vis_type:
+ return True
+
+ if vis_type not in ("public", "private"):
+ raise SynapseError(
+ 400, "Synapse only supports 'public'/'private' visibility"
+ )
+ return vis_type == "public"
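A quick illustration of how the helpers above interpret request contents (these calls exercise only the non-error paths):

# Usage illustration for the parsing helpers above (non-error paths only).
assert _parse_visibility_from_contents({}) is True          # public by default
assert _parse_visibility_from_contents({"m.visibility": {"type": "private"}}) is False
assert _parse_visibility_from_contents({"m.visibility": {"type": "public"}}) is True
assert _parse_join_policy_from_contents({}) is None          # caller falls back to its own default
assert _parse_join_policy_from_contents({"m.join_policy": {"type": "open"}}) == "open"
assert _parse_join_policy_from_contents({"m.join_policy": {"type": "invite"}}) == "invite"
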
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 5ad408f549..8f8fd82eb0 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -17,10 +17,8 @@ from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomContextHandler,
)
-from .room_member import RoomMemberHandler
from .message import MessageHandler
from .federation import FederationHandler
-from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .identity import IdentityHandler
@@ -50,9 +48,7 @@ class Handlers(object):
self.registration_handler = RegistrationHandler(hs)
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
- self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs)
- self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index e83adc8339..e089e66fde 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,7 +53,20 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
- def ratelimit(self, requester):
+ @defer.inlineCallbacks
+ def ratelimit(self, requester, update=True):
+ """Ratelimits requests.
+
+ Args:
+ requester (Requester)
+ update (bool): Whether to record that a request is being processed.
+ Set to False when doing multiple checks for one request (e.g.
+ to check up front if we would reject the request), and set to
+ True for the last call for a given request.
+
+ Raises:
+ LimitExceededError if the request should be ratelimited
+ """
time_now = self.clock.time()
user_id = requester.user.to_string()
@@ -67,10 +80,25 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
+ # Check if there is a per user override in the DB.
+ override = yield self.store.get_ratelimit_for_user(user_id)
+ if override:
+ # If overridden with a null Hz then ratelimiting has been entirely
+ # disabled for the user
+ if not override.messages_per_second:
+ return
+
+ messages_per_second = override.messages_per_second
+ burst_count = override.burst_count
+ else:
+ messages_per_second = self.hs.config.rc_messages_per_second
+ burst_count = self.hs.config.rc_message_burst_count
+
allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now,
- msg_rate_hz=self.hs.config.rc_messages_per_second,
- burst_count=self.hs.config.rc_message_burst_count,
+ msg_rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
)
if not allowed:
raise LimitExceededError(
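A loose sketch of what msg_rate_hz and burst_count mean for the send_message() check above; this is not Synapse's Ratelimiter, just a toy token bucket with the same knobs:

# Toy token bucket, illustrative only: actions refill at rate_hz per second
# and at most burst_count can be saved up; update=False peeks without
# recording, mirroring ratelimit(requester, update=False) above.
import time


class TokenBucket(object):
    def __init__(self, rate_hz, burst_count):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self.tokens = float(burst_count)
        self.last = time.time()

    def allow(self, update=True):
        now = time.time()
        tokens = min(self.burst_count, self.tokens + (now - self.last) * self.rate_hz)
        allowed = tokens >= 1.0
        if update:
            self.tokens = tokens - 1.0 if allowed else tokens
            self.last = now
        return allowed


bucket = TokenBucket(rate_hz=0.2, burst_count=10.0)  # cf. rc_messages_per_second etc.
print(bucket.allow(update=False))  # peek: would this request be allowed?
print(bucket.allow())              # actually record the request
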
@@ -130,7 +158,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
- handler = self.hs.get_handlers().room_member_handler
+ handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
target_user,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..b596f098fd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,14 +15,21 @@
from twisted.internet import defer
+import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import (
+ make_deferred_yieldable, run_in_background,
+)
import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
def log_failure(failure):
logger.error(
@@ -70,21 +77,25 @@ class ApplicationServicesHandler(object):
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
- upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
- upper_bound, limit
+ self.current_max, limit
)
if not events:
break
+ events_by_room = {}
for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ @defer.inlineCallbacks
+ def handle_event(event):
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
- continue # no services need notifying
+ return # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
@@ -100,14 +111,35 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services
for service in services:
- preserve_fn(self.scheduler.submit_event_for_as)(
- service, event
- )
+ self.scheduler.submit_event_for_as(service, event)
+
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(handle_room_events, evs)
+ for evs in events_by_room.itervalues()
+ ], consumeErrors=True))
yield self.store.set_appservice_last_pos(upper_bound)
- if len(events) < limit:
- break
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_positions.set(
+ upper_bound, "appservice_sender",
+ )
+
+ events_processed_counter.inc_by(len(events))
+
+ synapse.metrics.event_processing_lag.set(
+ now - ts, "appservice_sender",
+ )
+ synapse.metrics.event_processing_last_ts.set(
+ ts, "appservice_sender",
+ )
finally:
self.is_processing = False
@@ -163,8 +195,11 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield preserve_context_over_deferred(defer.DeferredList([
- preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ results = yield make_deferred_yieldable(defer.DeferredList([
+ run_in_background(
+ self.appservice_api.query_3pe,
+ service, kind, protocol, fields,
+ )
for service in services
], consumeErrors=True))
@@ -225,11 +260,15 @@ class ApplicationServicesHandler(object):
event based on the service regex.
"""
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- yield s.is_interested(event, self.store)
- )
- ]
+
+ # we can't use a list comprehension here: since Python 3, list
+ # comprehensions run in their own (generator-like) scope, which means
+ # you can't yield inside one anymore.
+ interested_list = []
+ for s in services:
+ if (yield s.is_interested(event, self.store)):
+ interested_list.append(s)
+
defer.returnValue(interested_list)
def _get_services_for_user(self, user_id):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fffba34383..a5365c4fe4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,14 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
+from twisted.internet import defer, threads
from ._base import BaseHandler
from synapse.api.constants import LoginType
+from synapse.api.errors import (
+ AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+ SynapseError,
+)
+from synapse.module_api import ModuleApi
from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
from twisted.web.client import PartialDownloadError
@@ -44,18 +50,23 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
- LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
+ LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
- self.sessions = {}
- account_handler = _AccountHandler(
- hs, check_user_exists=self.check_user_exists
+ # This is not a cache per se, but a store of all current sessions that
+ # expire after N hours
+ self.sessions = ExpiringCache(
+ cache_name="register_sessions",
+ clock=hs.get_clock(),
+ expiry_ms=self.SESSION_EXPIRE_MS,
+ reset_expiry_on_get=True,
)
+ account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -64,39 +75,120 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._password_enabled = hs.config.password_enabled
+
+ # we keep this as a list despite the O(N^2) implication so that we can
+ # keep PASSWORD first and avoid confusing clients which pick the first
+ # type in the list. (NB that the spec doesn't require us to do so and
+ # clients which favour types that they don't understand over those that
+ # they do are technically broken)
+ login_types = []
+ if self._password_enabled:
+ login_types.append(LoginType.PASSWORD)
+ for provider in self.password_providers:
+ if hasattr(provider, "get_supported_login_types"):
+ for t in provider.get_supported_login_types().keys():
+ if t not in login_types:
+ login_types.append(t)
+ self._supported_login_types = login_types
+
+ @defer.inlineCallbacks
+ def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ """
+ Checks that the user is who they claim to be, via a UI auth.
+
+ This is used for things like device deletion and password reset where
+ the user already has a valid access token, but we want to double-check
+ that it isn't stolen by re-authenticating them.
+
+ Args:
+ requester (Requester): The user, as given by the access token
+
+ request_body (dict): The body of the request sent by the client
+
+ clientip (str): The IP address of the client.
+
+ Returns:
+ defer.Deferred[dict]: the parameters for this request (which may
+ have been given only in a previous call).
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ any of the permitted login flows
+
+ AuthError if the client has completed a login flow, and it gives
+ a different user to `requester`
+ """
+
+ # build a list of supported flows
+ flows = [
+ [login_type] for login_type in self._supported_login_types
+ ]
+
+ result, params, _ = yield self.check_auth(
+ flows, request_body, clientip,
+ )
+
+ # find the completed login type
+ for login_type in self._supported_login_types:
+ if login_type not in result:
+ continue
+
+ user_id = result[login_type]
+ break
+ else:
+ # this can't happen
+ raise Exception(
+ "check_auth returned True but no successful login type",
+ )
+
+ # check that the UI auth matched the access token
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Invalid auth")
+
+ defer.returnValue(params)
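To make the flow concrete, the kind of body a client might send on its second attempt at a UI-auth-protected endpoint (after a 401 carrying the available flows and a session id) could look like this; all values are hypothetical:

# Hypothetical request body for a UI-auth-protected endpoint (e.g. device
# deletion). The 'auth' dict is what check_auth() consumes; anything else
# is whatever parameters the endpoint itself takes.
example_request_body = {
    "auth": {
        "type": "m.login.password",
        "session": "sessionidfromthe401",   # echoed back from the 401 response
        "user": "@alice:example.com",       # hypothetical user
        "password": "correct-horse",        # hypothetical password
    },
}
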
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
- protocol and handles the login flow.
+ protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
+ If no auth flows have been completed successfully, raises an
+ InteractiveAuthIncompleteError. To handle this, you can use
+ synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+ decorator.
+
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
+
clientip (str): The IP address of the client.
+
Returns:
- A tuple of (authed, dict, dict, session_id) where authed is true if
- the client has successfully completed an auth flow. If it is true
- the first dict contains the authenticated credentials of each stage.
+ defer.Deferred[dict, dict, str]: a deferred tuple of
+ (creds, params, session_id).
+
+ 'creds' contains the authenticated credentials of each stage.
- If authed is false, the first dictionary is the server response to
- the login request and should be passed back to the client.
+ 'params' contains the parameters for this request (which may
+ have been given only in a previous call).
- In either case, the second dict contains the parameters for this
- request (which may have been given only in a previous call).
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
- session_id is the ID of this session, either passed in by the client
- or assigned by the call to check_auth
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ all the stages in any of the permitted flows.
"""
authdict = None
@@ -124,11 +216,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
- defer.returnValue(
- (
- False, self._auth_dict_for_flows(flows, session),
- clientdict, session['id']
- )
+ raise InteractiveAuthIncompleteError(
+ self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@@ -139,14 +228,12 @@ class AuthHandler(BaseHandler):
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
- if login_type not in self.checkers:
- raise LoginError(400, "", Codes.UNRECOGNIZED)
try:
- result = yield self.checkers[login_type](authdict, clientip)
+ result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
- except LoginError, e:
+ except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
@@ -155,7 +242,7 @@ class AuthHandler(BaseHandler):
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
- raise e
+ raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
@@ -172,12 +259,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
- defer.returnValue((True, creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
ret.update(errordict)
- defer.returnValue((False, ret, clientdict, session['id']))
+ raise InteractiveAuthIncompleteError(
+ ret,
+ )
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -249,16 +338,37 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- def _check_password_auth(self, authdict, _):
- if "user" not in authdict or "password" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ @defer.inlineCallbacks
+ def _check_auth_dict(self, authdict, clientip):
+ """Attempt to validate the auth dict provided by a client
- user_id = authdict["user"]
- password = authdict["password"]
- if not user_id.startswith('@'):
- user_id = UserID.create(user_id, self.hs.hostname).to_string()
+ Args:
+ authdict (object): auth dict provided by the client
+ clientip (str): IP address of the client
- return self._check_password(user_id, password)
+ Returns:
+ Deferred: result of the stage verification.
+
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ login_type = authdict['type']
+ checker = self.checkers.get(login_type)
+ if checker is not None:
+ res = yield checker(authdict, clientip)
+ defer.returnValue(res)
+
+ # build a v1-login-style dict out of the authdict and fall back to the
+ # v1 code
+ user_id = authdict.get("user")
+
+ if user_id is None:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+ (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+ defer.returnValue(canonical_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -307,31 +417,47 @@ class AuthHandler(BaseHandler):
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
- @defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
+ return self._check_threepid('email', authdict)
+
+ def _check_msisdn(self, authdict, _):
+ return self._check_threepid('msisdn', authdict)
+
+ @defer.inlineCallbacks
+ def _check_dummy_auth(self, authdict, _):
+ yield run_on_reactor()
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ def _check_threepid(self, medium, authdict):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
+
identity_handler = self.hs.get_handlers().identity_handler
- logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
+ logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ if threepid['medium'] != medium:
+ raise LoginError(
+ 401,
+ "Expecting threepid of type '%s', got '%s'" % (
+ medium, threepid['medium'],
+ ),
+ errcode=Codes.UNAUTHORIZED
+ )
+
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
- @defer.inlineCallbacks
- def _check_dummy_auth(self, authdict, _):
- yield run_on_reactor()
- defer.returnValue(True)
-
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -371,26 +497,8 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- def validate_password_login(self, user_id, password):
- """
- Authenticates the user with their username and password.
-
- Used only by the v1 login API.
-
- Args:
- user_id (str): complete @user:id
- password (str): Password
- Returns:
- defer.Deferred: (str) canonical user id
- Raises:
- StoreError if there was a problem accessing the database
- LoginError if there was an authentication problem.
- """
- return self._check_password(user_id, password)
-
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@@ -404,13 +512,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- initial_display_name (str): display name to associate with the
- device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
- LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@@ -420,9 +525,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ try:
+ yield self.store.get_device(user_id, device_id)
+ except StoreError:
+ yield self.store.delete_access_token(access_token)
+ raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@@ -474,29 +581,115 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
+ def get_supported_login_types(self):
+ """Get a the login types supported for the /login API
+
+ By default this is just 'm.login.password' (unless password_enabled is
+ False in the config file), but password auth providers can provide
+ other login types.
+
+ Returns:
+ Iterable[str]: login types
+ """
+ return self._supported_login_types
+
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Authenticate a user against the LDAP and local databases.
+ def validate_login(self, username, login_submission):
+ """Authenticates the user for the /login API
- user_id is checked case insensitively against the local database, but
- will throw if there are multiple inexact matches.
+ Also used by the user-interactive auth flow to validate
+ m.login.password auth types.
Args:
- user_id (str): complete @user:id
+ username (str): username supplied by the user
+ login_submission (dict): the whole of the login submission
+ (including 'type' and other relevant fields)
Returns:
- (str) the canonical_user_id
+ Deferred[str, func]: canonical user id, and optional callback
+ to be called once the access token and device id are issued
Raises:
- LoginError if login fails
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
"""
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(
+ username, self.hs.hostname
+ ).to_string()
+
+ login_type = login_submission.get("type")
+ known_login_type = False
+
+ # special case to check for "password" for the check_password interface
+ # for the auth providers
+ password = login_submission.get("password")
+ if login_type == LoginType.PASSWORD:
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if not password:
+ raise SynapseError(400, "Missing parameter: password")
+
for provider in self.password_providers:
- is_valid = yield provider.check_password(user_id, password)
- if is_valid:
- defer.returnValue(user_id)
+ if (hasattr(provider, "check_password")
+ and login_type == LoginType.PASSWORD):
+ known_login_type = True
+ is_valid = yield provider.check_password(
+ qualified_user_id, password,
+ )
+ if is_valid:
+ defer.returnValue((qualified_user_id, None))
+
+ if (not hasattr(provider, "get_supported_login_types")
+ or not hasattr(provider, "check_auth")):
+ # this password provider doesn't understand custom login types
+ continue
+
+ supported_login_types = provider.get_supported_login_types()
+ if login_type not in supported_login_types:
+ # this password provider doesn't understand this login type
+ continue
+
+ known_login_type = True
+ login_fields = supported_login_types[login_type]
+
+ missing_fields = []
+ login_dict = {}
+ for f in login_fields:
+ if f not in login_submission:
+ missing_fields.append(f)
+ else:
+ login_dict[f] = login_submission[f]
+ if missing_fields:
+ raise SynapseError(
+ 400, "Missing parameters for login type %s: %s" % (
+ login_type,
+ missing_fields,
+ ),
+ )
- canonical_user_id = yield self._check_local_password(user_id, password)
+ result = yield provider.check_auth(
+ username, login_type, login_dict,
+ )
+ if result:
+ if isinstance(result, str):
+ result = (result, None)
+ defer.returnValue(result)
+
+ if login_type == LoginType.PASSWORD:
+ known_login_type = True
+
+ canonical_user_id = yield self._check_local_password(
+ qualified_user_id, password,
+ )
- if canonical_user_id:
- defer.returnValue(canonical_user_id)
+ if canonical_user_id:
+ defer.returnValue((canonical_user_id, None))
+
+ if not known_login_type:
+ raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@@ -522,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
- result = self.validate_hash(password, password_hash)
+ result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
@@ -546,22 +739,65 @@ class AuthHandler(BaseHandler):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword, requester=None):
- password_hash = self.hash(newpassword)
+ def delete_access_token(self, access_token):
+ """Invalidate a single access token
- except_access_token_id = requester.access_token_id if requester else None
+ Args:
+ access_token (str): access token to be deleted
- try:
- yield self.store.user_set_password_hash(user_id, password_hash)
- except StoreError as e:
- if e.code == 404:
- raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
- raise e
- yield self.store.user_delete_access_tokens(
- user_id, except_access_token_id
+ Returns:
+ Deferred
+ """
+ user_info = yield self.auth.get_user_by_access_token(access_token)
+ yield self.store.delete_access_token(access_token)
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ yield provider.on_logged_out(
+ user_id=str(user_info["user"]),
+ device_id=user_info["device_id"],
+ access_token=access_token,
+ )
+
+ # delete pushers associated with this access token
+ if user_info["token_id"] is not None:
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ str(user_info["user"]), (user_info["token_id"], )
+ )
+
+ @defer.inlineCallbacks
+ def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+ device_id=None):
+ """Invalidate access tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str|None): access_token ID which should *not* be
+ deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ Returns:
+ Deferred
+ """
+ tokens_and_devices = yield self.store.user_delete_access_tokens(
+ user_id, except_token_id=except_token_id, device_id=device_id,
)
- yield self.hs.get_pusherpool().remove_pushers_by_user(
- user_id, except_access_token_id
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ for token, token_id, device_id in tokens_and_devices:
+ yield provider.on_logged_out(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=token,
+ )
+
+ # delete pushers associated with the access tokens
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
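A companion sketch of the optional on_logged_out hook which both of the loops above probe for with hasattr(); everything except the hook name and its keyword arguments is illustrative.

import logging

logger = logging.getLogger(__name__)


class AuditingProvider(object):
    """Hypothetical provider that only audits logouts."""

    def on_logged_out(self, user_id, device_id, access_token):
        # called once for each access token synapse has just invalidated;
        # may return a Deferred if it needs to do asynchronous work
        logger.info(
            "access token for %s (device %s) was invalidated",
            user_id, device_id,
        )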
@defer.inlineCallbacks
@@ -599,16 +835,6 @@ class AuthHandler(BaseHandler):
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- self._prune_sessions()
-
- def _prune_sessions(self):
- for sid, sess in self.sessions.items():
- last_used = 0
- if 'last_used' in sess:
- last_used = sess['last_used']
- now = self.hs.get_clock().time_msec()
- if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
- del self.sessions[sid]
def hash(self, password):
"""Computes a secure hash of password.
@@ -617,10 +843,13 @@ class AuthHandler(BaseHandler):
password (str): Password to hash.
Returns:
- Hashed password (str).
+ Deferred(str): Hashed password.
"""
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- bcrypt.gensalt(self.bcrypt_rounds))
+ def _do_hash():
+ return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
+
+ return make_deferred_yieldable(threads.deferToThread(_do_hash))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -630,13 +859,19 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value.
Returns:
- Whether self.hash(password) == stored_hash (bool).
+ Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
+ def _do_validate_hash():
+ return bcrypt.checkpw(
+ password.encode('utf8') + self.hs.config.password_pepper,
+ stored_hash.encode('utf8')
+ )
+
if stored_hash:
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- stored_hash.encode('utf8')) == stored_hash
+ return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
else:
- return False
+ return defer.succeed(False)
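The same "run bcrypt on a worker thread" pattern in isolation, for readers unfamiliar with make_deferred_yieldable. The helper name below is made up; the calls are the ones used above and assume this version's synapse.util.logcontext.

import bcrypt
from twisted.internet import threads

from synapse.util.logcontext import make_deferred_yieldable


def check_password_off_reactor(password, stored_hash, pepper=b""):
    """Illustrative helper: verify a bcrypt hash without blocking the reactor."""
    def _check():
        # bcrypt.checkpw is CPU-bound, so it runs in the reactor's thread pool
        return bcrypt.checkpw(
            password.encode("utf8") + pepper,
            stored_hash.encode("utf8"),
        )

    # deferToThread returns a Deferred that fires from the thread pool;
    # make_deferred_yieldable reattaches the caller's logcontext to it.
    return make_deferred_yieldable(threads.deferToThread(_check))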
class MacaroonGeneartor(object):
@@ -679,30 +914,3 @@ class MacaroonGeneartor(object):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
-
-
-class _AccountHandler(object):
- """A proxy object that gets passed to password auth providers so they
- can register new users etc if necessary.
- """
- def __init__(self, hs, check_user_exists):
- self.hs = hs
-
- self._check_user_exists = check_user_exists
-
- def check_user_exists(self, user_id):
- """Check if user exissts.
-
- Returns:
- Deferred(bool)
- """
- return self._check_user_exists(user_id)
-
- def register(self, localpart):
- """Registers a new user with given localpart
-
- Returns:
- Deferred: a 2-tuple of (user_id, access_token)
- """
- reg = self.hs.get_handlers().registration_handler
- return reg.register(localpart=localpart)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
new file mode 100644
index 0000000000..b1d3814909
--- /dev/null
+++ b/synapse/handlers/deactivate_account.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DeactivateAccountHandler(BaseHandler):
+ """Handler which deals with deactivating user accounts."""
+ def __init__(self, hs):
+ super(DeactivateAccountHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def deactivate_account(self, user_id):
+ """Deactivate a user's account
+
+ Args:
+ user_id (str): ID of user to be deactivated
+
+ Returns:
+ Deferred
+ """
+ # FIXME: Theoretically there is a race here wherein user resets
+ # password using threepid.
+
+ # first delete any devices belonging to the user, which will also
+ # delete corresponding access tokens.
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+ # then delete any remaining access tokens which weren't associated with
+ # a device.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
+
+ yield self.store.user_delete_threepids(user_id)
+ yield self.store.user_set_password_hash(user_id, None)
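A rough sketch of how this handler might be driven; the wrapper function is hypothetical, and only the deactivate_account call and the ordering it relies on come from the code above.

from twisted.internet import defer


@defer.inlineCallbacks
def deactivate_user(hs, user_id):
    # devices go first (taking their access tokens and e2e keys with them),
    # then any device-less tokens, then threepids and the password hash
    handler = DeactivateAccountHandler(hs)
    yield handler.deactivate_account(user_id)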
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ca7137f315..f7457a7082 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,8 +14,11 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
from synapse.util import stringutils
from synapse.util.async import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer
@@ -32,14 +35,17 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
- self.federation = hs.get_replication_layer()
- self._remote_edue_linearizer = Linearizer(name="remote_device_list")
- self.federation.register_edu_handler(
- "m.device_list_update", self._incoming_device_list_update,
+ self._edu_updater = DeviceListEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
+ "m.device_list_update", self._edu_updater.incoming_device_list_update,
)
- self.federation.register_query_handler(
+ federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@@ -103,7 +109,7 @@ class DeviceHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id) for device_id in device_map.keys())
+ user_id, device_id=None
)
devices = device_map.values()
@@ -130,7 +136,7 @@ class DeviceHandler(BaseHandler):
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id),)
+ user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@@ -149,16 +155,15 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_device(user_id, device_id)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@@ -168,6 +173,57 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
+ def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ """Delete all of the user's devices
+
+ Args:
+ user_id (str):
+ except_device_id (str|None): optional device id which should not
+ be deleted
+
+ Returns:
+ defer.Deferred:
+ """
+ device_map = yield self.store.get_devices_by_user(user_id)
+ device_ids = device_map.keys()
+ if except_device_id is not None:
+ device_ids = [d for d in device_ids if d != except_device_id]
+ yield self.delete_devices(user_id, device_ids)
+
+ @defer.inlineCallbacks
+ def delete_devices(self, user_id, device_ids):
+ """ Delete several devices
+
+ Args:
+ user_id (str):
+ device_ids (List[str]): The list of device IDs to delete
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.delete_devices(user_id, device_ids)
+ except errors.StoreError as e:
+ if e.code == 404:
+ # no match
+ pass
+ else:
+ raise
+
+ # Delete access tokens and e2e keys for each device. Not optimised as it is not
+ # considered as part of a critical path.
+ for device_id in device_ids:
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id,
+ )
+ yield self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=device_id
+ )
+
+ yield self.notify_device_update(user_id, device_ids)
+
+ @defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@@ -187,7 +243,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
@@ -212,8 +268,7 @@ class DeviceHandler(BaseHandler):
user_id, device_ids, list(hosts)
)
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = [r.room_id for r in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user_id)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
@@ -234,8 +289,9 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = set(r.room_id for r in rooms)
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
@@ -245,11 +301,30 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ member_events = yield self.store.get_membership_changes_for_user(
+ user_id, from_token.room_key, now_token.room_key
+ )
+ rooms_changed.update(event.room_id for event in member_events)
+
stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key).stream
+ from_token.room_key
+ ).stream
possibly_changed = set(changed)
+ possibly_left = set()
for room_id in rooms_changed:
+ current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+ # The user may have left the room
+ # TODO: Check if they actually did or if we were just invited.
+ if room_id not in room_ids:
+ for key, event_id in current_state_ids.iteritems():
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_left.add(state_key)
+ continue
+
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
@@ -260,8 +335,6 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
- current_state_ids = yield self.state.get_current_state_ids(room_id)
-
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
@@ -272,9 +345,25 @@ class DeviceHandler(BaseHandler):
possibly_changed.add(state_key)
continue
+ current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+ if not current_member_id:
+ continue
+
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ # Check if we've joined the room; if so we just blindly add all the users
+ # to the "possibly changed" set.
+ for state_dict in prev_state_ids.itervalues():
+ member_event = state_dict.get((EventTypes.Member, user_id), None)
+ if not member_event or member_event != current_member_id:
+ for key, event_id in current_state_ids.iteritems():
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_changed.add(state_key)
+ break
+
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
@@ -285,94 +374,208 @@ class DeviceHandler(BaseHandler):
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.values():
+ for state_dict in prev_state_ids.itervalues():
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
- possibly_changed.add(state_key)
+ if state_key != user_id:
+ possibly_changed.add(state_key)
break
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
+ if possibly_changed or possibly_left:
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
- # Take the intersection of the users whose devices may have changed
- # and those that actually still share a room with the user
- defer.returnValue(users_who_share_room & possibly_changed)
+ # Take the intersection of the users whose devices may have changed
+ # and those that actually still share a room with the user
+ possibly_joined = possibly_changed & users_who_share_room
+ possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+ else:
+ possibly_joined = []
+ possibly_left = []
+
+ defer.returnValue({
+ "changed": list(possibly_joined),
+ "left": list(possibly_left),
+ })
- @measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
- def _incoming_device_list_update(self, origin, edu_content):
- user_id = edu_content["user_id"]
- device_id = edu_content["device_id"]
- stream_id = edu_content["stream_id"]
- prev_ids = edu_content.get("prev_id", [])
+ def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+ defer.returnValue({
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ })
+
+ @defer.inlineCallbacks
+ def user_left_room(self, user, room_id):
+ user_id = user.to_string()
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We no longer share rooms with this user, so we'll no longer
+ # receive device updates. Mark this in DB.
+ yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+
+def _update_device_from_client_ips(device, client_ips):
+ ip = client_ips.get((device["user_id"], device["device_id"]), {})
+ device.update({
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ })
+
+
+class DeviceListEduUpdater(object):
+ "Handles incoming device list updates from federation and updates the DB"
+
+ def __init__(self, hs, device_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.device_handler = device_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have around to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="device_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_device_list_update(self, origin, edu_content):
+ """Called on incoming device list update from federation. Responsible
+ for parsing the EDU and adding to pending updates list.
+ """
+
+ user_id = edu_content.pop("user_id")
+ device_id = edu_content.pop("device_id")
+ stream_id = str(edu_content.pop("stream_id")) # They may come as ints
+ prev_ids = edu_content.pop("prev_id", [])
+ prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
- rooms = yield self.store.get_rooms_for_user(user_id)
- if not rooms:
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
- with (yield self._remote_edue_linearizer.queue(user_id)):
- # If the prev id matches whats in our cache table, then we don't need
- # to resync the users device list, otherwise we do.
- resync = True
- if len(prev_ids) == 1:
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(
- user_id
- )
- logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
- if str(extremity) == str(prev_ids[0]):
- resync = False
+ self._pending_updates.setdefault(user_id, []).append(
+ (device_id, stream_id, prev_ids, edu_content)
+ )
+
+ yield self._handle_device_updates(user_id)
+
+ @measure_func("_incoming_device_list_update")
+ @defer.inlineCallbacks
+ def _handle_device_updates(self, user_id):
+ "Actually handle pending updates."
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ # Given a list of updates we check if we need to resync. This
+ # happens if we've missed updates.
+ resync = yield self._need_to_do_resync(user_id, pending_updates)
if resync:
# Fetch all devices for the user.
- result = yield self.federation.query_user_devices(origin, user_id)
+ origin = get_domain_from_id(user_id)
+ try:
+ result = yield self.federation.query_user_devices(origin, user_id)
+ except NotRetryingDestination:
+ # TODO: Remember that we are now out of sync and try again
+ # later
+ logger.warn(
+ "Failed to handle device list update for %s,"
+ " we're not retrying the remote",
+ user_id,
+ )
+ # We abort on exceptions rather than accepting the update
+ # as otherwise synapse will 'forget' that its device list
+ # is out of date. If we bail then we will retry the resync
+ # next time we get a device list update for this user_id.
+ # This makes it more likely that the device lists will
+ # eventually become consistent.
+ return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
+ except Exception:
+ # TODO: Remember that we are now out of sync and try again
+ # later
+ logger.exception(
+ "Failed to handle device list update for %s", user_id
+ )
+ return
+
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
- yield self.notify_device_update(user_id, device_ids)
+ yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
- content = dict(edu_content)
- for key in ("user_id", "device_id", "stream_id", "prev_ids"):
- content.pop(key, None)
- yield self.store.update_remote_device_list_cache_entry(
- user_id, device_id, content, stream_id,
+ for device_id, stream_id, prev_ids, content in pending_updates:
+ yield self.store.update_remote_device_list_cache_entry(
+ user_id, device_id, content, stream_id,
+ )
+
+ yield self.device_handler.notify_device_update(
+ user_id, [device_id for device_id, _, _, _ in pending_updates]
)
- yield self.notify_device_update(user_id, [device_id])
- @defer.inlineCallbacks
- def on_federation_query_user_devices(self, user_id):
- stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue({
- "user_id": user_id,
- "stream_id": stream_id,
- "devices": devices,
- })
+ self._seen_updates.setdefault(user_id, set()).update(
+ stream_id for _, stream_id, _, _ in pending_updates
+ )
@defer.inlineCallbacks
- def user_left_room(self, user, room_id):
- user_id = user.to_string()
- rooms = yield self.store.get_rooms_for_user(user_id)
- if not rooms:
- # We no longer share rooms with this user, so we'll no longer
- # receive device updates. Mark this in DB.
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ def _need_to_do_resync(self, user_id, updates):
+ """Given a list of updates for a user figure out if we need to do a full
+ resync, or whether we have enough data that we can just apply the delta.
+ """
+ seen_updates = self._seen_updates.get(user_id, set())
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+ user_id
+ )
-def _update_device_from_client_ips(device, client_ips):
- ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
- })
+ stream_id_in_updates = set() # stream_ids in updates list
+ for _, stream_id, prev_ids, _ in updates:
+ if not prev_ids:
+ # We always do a resync if there are no previous IDs
+ defer.returnValue(True)
+
+ for prev_id in prev_ids:
+ if prev_id == extremity:
+ continue
+ elif prev_id in seen_updates:
+ continue
+ elif prev_id in stream_id_in_updates:
+ continue
+ else:
+ defer.returnValue(True)
+
+ stream_id_in_updates.add(stream_id)
+
+ defer.returnValue(False)
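The resync rule above restated as a self-contained function, purely for clarity; it is a paraphrase of _need_to_do_resync, not additional behaviour.

def need_resync(extremity, seen_updates, pending_updates):
    """pending_updates: iterable of (device_id, stream_id, prev_ids, content).

    Returns True if any update's prev_ids cannot be accounted for by the
    stored extremity, the recently-seen stream ids, or earlier updates in
    the same batch, in which case the full device list must be refetched.
    """
    stream_ids_in_batch = set()
    for _device_id, stream_id, prev_ids, _content in pending_updates:
        if not prev_ids:
            return True  # no prev_ids: always resync
        for prev_id in prev_ids:
            if (prev_id != extremity
                    and prev_id not in seen_updates
                    and prev_id not in stream_ids_in_batch):
                return True
        stream_ids_in_batch.add(stream_id)
    return False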
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..f147a20b73 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string
@@ -33,10 +34,10 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1b5317edf5..c5b6e75e03 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -34,12 +34,15 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query
)
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services
@@ -73,6 +76,11 @@ class DirectoryHandler(BaseHandler):
# association creation for human users
# TODO(erikj): Do user auth.
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ raise SynapseError(
+ 403, "This user is not permitted to create this alias",
+ )
+
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
@@ -175,6 +183,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(),
},
retry_on_dns_fail=False,
+ ignore_backoff=True,
)
except CodeMessageException as e:
logging.warn("Error retrieving alias")
@@ -241,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases,
@@ -264,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -326,6 +333,14 @@ class DirectoryHandler(BaseHandler):
room_id (str)
visibility (str): "public" or "private"
"""
+ if not self.spam_checker.user_may_publish_room(
+ requester.user.to_string(), room_id
+ ):
+ raise AuthError(
+ 403,
+ "This user is not permitted to publish rooms to the room list"
+ )
+
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index e40495d1ab..25aec624af 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import ujson as json
+import simplejson as json
import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.api.errors import (
+ SynapseError, CodeMessageException, FederationDeniedError,
+)
+from synapse.types import get_domain_from_id, UserID
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -30,15 +33,15 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- self.federation.register_query_handler(
+ hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -70,7 +73,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
@@ -121,38 +125,23 @@ class E2eKeysHandler(object):
def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination]
try:
- limiter = yield get_retry_limiter(
- destination, self.clock, self.store
+ remote_result = yield self.federation.query_client_keys(
+ destination,
+ {"device_keys": destination_query},
+ timeout=timeout
)
- with limiter:
- remote_result = yield self.federation.query_client_keys(
- destination,
- {"device_keys": destination_query},
- timeout=timeout
- )
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
- yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(do_remote_query)(destination)
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
- ]))
+ ], consumeErrors=True))
defer.returnValue({
"device_keys": results, "failures": failures,
@@ -174,7 +163,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -217,7 +207,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
@@ -239,36 +230,31 @@ class E2eKeysHandler(object):
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
- limiter = yield get_retry_limiter(
- destination, self.clock, self.store
+ remote_result = yield self.federation.claim_client_keys(
+ destination,
+ {"one_time_keys": device_keys},
+ timeout=timeout
)
- with limiter:
- remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
- )
- for user_id, keys in remote_result["one_time_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
- yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(claim_client_keys, destination)
for destination in remote_queries
- ]))
+ ], consumeErrors=True))
+
+ logger.info(
+ "Claimed one-time-keys: %s",
+ ",".join((
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in json_result.iteritems()
+ for device_id, device_keys in user_keys.iteritems()
+ for key_id, _ in device_keys.iteritems()
+ )),
+ )
defer.returnValue({
"one_time_keys": json_result,
@@ -296,19 +282,8 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
- logger.info(
- "Adding %d one_time_keys for device %r for user %r at %d",
- len(one_time_keys), device_id, user_id, time_now
- )
- key_list = []
- for key_id, key_json in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, encode_canonical_json(key_json)
- ))
-
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, key_list
+ yield self._upload_one_time_keys_for_user(
+ user_id, device_id, time_now, one_time_keys,
)
# the device should have been registered already, but it may have been
@@ -316,8 +291,88 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
- self.device_handler.check_device_registered(user_id, device_id)
+ yield self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})
+
+ @defer.inlineCallbacks
+ def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+ one_time_keys):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(), device_id, user_id, time_now,
+ )
+
+ # make a list of (alg, id, key) tuples
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, key_obj
+ ))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = yield self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ ("One time key %s:%s already exists. "
+ "Old key: %s; new key: %r") %
+ (algorithm, key_id, ex_json, key)
+ )
+ else:
+ new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
+
+
+def _exception_to_failure(e):
+ if isinstance(e, CodeMessageException):
+ return {
+ "status": e.code, "message": e.message,
+ }
+
+ if isinstance(e, NotRetryingDestination):
+ return {
+ "status": 503, "message": "Not ready for retry",
+ }
+
+ if isinstance(e, FederationDeniedError):
+ return {
+ "status": 403, "message": "Federation Denied",
+ }
+
+ # include ConnectionRefused and other errors
+ #
+ # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
+ # give a string for e.message, which simplejson then fails to serialize.
+ return {
+ "status": 503, "message": str(e.message),
+ }
+
+
+def _one_time_keys_match(old_key_json, new_key):
+ old_key = json.loads(old_key_json)
+
+ # if either is a string rather than an object, they must match exactly
+ if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+ return old_key == new_key
+
+ # otherwise, we strip off the 'signatures' if any, because it's legitimate
+ # for different upload attempts to have different signatures.
+ old_key.pop("signatures", None)
+ new_key_copy = dict(new_key)
+ new_key_copy.pop("signatures", None)
+
+ return old_key == new_key_copy
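A worked example of the comparison above (assuming _one_time_keys_match and a json module are in scope): re-uploading the same key is accepted even if the client re-signed it, while a different key under the same identifier is not, which is what makes the caller raise the 400.

import json

stored_json = json.dumps({
    "key": "abc123",
    "signatures": {"@alice:example.org": {"ed25519:DEVICE": "sig-one"}},
})
resigned = {
    "key": "abc123",
    "signatures": {"@alice:example.org": {"ed25519:DEVICE": "sig-two"}},
}
different = {"key": "zzz999"}

assert _one_time_keys_match(stored_json, resigned)       # signatures ignored
assert not _one_time_keys_match(stored_json, different)  # caller raises 400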
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 996bfd0e23..f39233d846 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,24 +15,30 @@
# limitations under the License.
"""Contains handlers for federation events."""
+
+import itertools
+import logging
+import sys
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
+import six
+from six.moves import http_client
+from twisted.internet import defer
from unpaddedbase64 import decode_base64
from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+ FederationDeniedError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
-)
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, Linearizer
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
@@ -42,13 +49,8 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
-from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room
-from twisted.internet import defer
-
-import itertools
-import logging
logger = logging.getLogger(__name__)
@@ -70,38 +72,268 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
- self.replication_layer = hs.get_replication_layer()
+ self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
-
- self.replication_layer.set_handler(self)
+ self.action_generator = hs.get_action_generator()
+ self.is_mine_id = hs.is_mine_id
+ self.pusher_pool = hs.get_pusherpool()
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
# When joining a room we need to queue any events for that room up
self.room_queues = {}
+ self._room_pdu_linearizer = Linearizer("fed_room_pdu")
- @log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
- """ Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it through the StateHandler.
+ @log_function
+ def on_receive_pdu(self, origin, pdu, get_missing=True):
+ """ Process a PDU received via a federation /send/ transaction, or
+ via backfill of missing prev_events
+
+ Args:
+ origin (str): server which initiated the /send/ transaction. Will
+ be used to fetch missing events or state.
+ pdu (FrozenEvent): received PDU
+ get_missing (bool): True if we should fetch missing prev_events
- auth_chain and state are None if we already have the necessary state
- and prev_events in the db
+ Returns (Deferred): completes with None
"""
- event = pdu
- logger.debug("Got event: %s", event.event_id)
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self.get_persisted_pdu(
+ origin, pdu.event_id, do_auth=False
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+
+ already_seen = (
+ existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
+ )
+ )
+ if already_seen:
+ logger.debug("Already seen pdu %s", pdu.event_id)
+ return
+
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ raise FederationError(
+ "ERROR",
+ err.code,
+ err.msg,
+ affected=pdu.event_id,
+ )
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
- if event.room_id in self.room_queues:
- self.room_queues[event.room_id].append((pdu, origin))
+ if pdu.room_id in self.room_queues:
+ logger.info("Ignoring PDU %s for room %s from %s for now; join "
+ "in progress", pdu.event_id, pdu.room_id, origin)
+ self.room_queues[pdu.room_id].append((pdu, origin))
return
- logger.debug("Processing event: %s", event.event_id)
+ # If we're no longer in the room just ditch the event entirely. This
+ # is probably an old server that has come back and thinks we're still
+ # in the room (or we've been rejoined to the room by a state reset).
+ #
+ # If we were never in the room then maybe our database got vaped and
+ # we should check if we *are* in fact in the room. If we are then we
+ # can magically rejoin the room.
+ is_in_room = yield self.auth.check_host_in_room(
+ pdu.room_id,
+ self.server_name
+ )
+ if not is_in_room:
+ was_in_room = yield self.store.was_host_joined(
+ pdu.room_id, self.server_name,
+ )
+ if was_in_room:
+ logger.info(
+ "Ignoring PDU %s for room %s from %s as we've left the room!",
+ pdu.event_id, pdu.room_id, origin,
+ )
+ return
+
+ state = None
+
+ auth_chain = []
+
+ fetch_state = False
+
+ # Get missing pdus if necessary.
+ if not pdu.internal_metadata.is_outlier():
+ # We only backfill backwards to the min depth.
+ min_depth = yield self.get_min_depth_for_context(
+ pdu.room_id
+ )
+
+ logger.debug(
+ "_handle_new_pdu min_depth for %s: %d",
+ pdu.room_id, min_depth
+ )
+
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = yield self.store.have_seen_events(prevs)
+
+ if min_depth and pdu.depth < min_depth:
+ # This is so that we don't notify the user about this
+ # message, to work around the fact that some events will
+ # reference really really old events we really don't want to
+ # send to the clients.
+ pdu.internal_metadata.outlier = True
+ elif min_depth and pdu.depth > min_depth:
+ if get_missing and prevs - seen:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring lock for room %r to fetch %d missing events: %r...",
+ pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
+ )
+ with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ logger.info(
+ "Acquired lock for room %r to fetch %d missing events",
+ pdu.room_id, len(prevs - seen),
+ )
+
+ yield self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = yield self.store.have_seen_events(prevs)
+
+ if not prevs - seen:
+ logger.info(
+ "Found all missing prev events for %s", pdu.event_id
+ )
+ elif prevs - seen:
+ logger.info(
+ "Not fetching %d missing events for room %r,event %s: %r...",
+ len(prevs - seen), pdu.room_id, pdu.event_id,
+ list(prevs - seen)[:5],
+ )
+
+ if prevs - seen:
+ logger.info(
+ "Still missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+ fetch_state = True
- logger.debug("Event: %s", event)
+ if fetch_state:
+ # We need to get the state at this event, since we haven't
+ # processed all the prev events.
+ logger.debug(
+ "_handle_new_pdu getting state for %s",
+ pdu.room_id
+ )
+ try:
+ state, auth_chain = yield self.replication_layer.get_state_for_room(
+ origin, pdu.room_id, pdu.event_id,
+ )
+ except Exception:
+ logger.exception("Failed to get state for event: %s", pdu.event_id)
+
+ yield self._process_received_pdu(
+ origin,
+ pdu,
+ state=state,
+ auth_chain=auth_chain,
+ )
+
+ @defer.inlineCallbacks
+ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ """
+ Args:
+ origin (str): Origin of the pdu. Will be called to get the missing events
+ pdu: received pdu
+ prevs (set(str)): List of event ids which we are missing
+ min_depth (int): Minimum depth of events to return.
+ """
+ # We recalculate seen, since it may have changed.
+ seen = yield self.store.have_seen_events(prevs)
+
+ if not prevs - seen:
+ return
+
+ latest = yield self.store.get_latest_event_ids_in_room(
+ pdu.room_id
+ )
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest)
+ latest |= seen
+
+ logger.info(
+ "Missing %d events for room %r pdu %s: %r...",
+ len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5]
+ )
+
+ # XXX: we set timeout to 10s to help workaround
+ # https://github.com/matrix-org/synapse/issues/1733.
+ # The reason is to avoid holding the linearizer lock
+ # whilst processing inbound /send transactions, causing
+ # FDs to stack up and block other inbound transactions
+ # which empirically can currently take up to 30 minutes.
+ #
+ # N.B. this explicitly disables retry attempts.
+ #
+ # N.B. this also increases our chances of falling back to
+ # fetching fresh state for the room if the missing event
+ # can't be found, which slightly reduces our security.
+ # it may also increase our DAG extremity count for the room,
+ # causing additional state resolution? See #1760.
+ # However, fetching state doesn't hold the linearizer lock
+ # apparently.
+ #
+ # see https://github.com/matrix-org/synapse/pull/1744
+
+ missing_events = yield self.replication_layer.get_missing_events(
+ origin,
+ pdu.room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=10000,
+ )
+
+ logger.info(
+ "Got %d events: %r...",
+ len(missing_events), [e.event_id for e in missing_events[:5]]
+ )
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ missing_events.sort(key=lambda x: x.depth)
+
+ for e in missing_events:
+ logger.info("Handling found event %s", e.event_id)
+ yield self.on_receive_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+
+ @log_function
+ @defer.inlineCallbacks
+ def _process_received_pdu(self, origin, pdu, state, auth_chain):
+ """ Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+ """
+ event = pdu
+
+ logger.debug("Processing event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
@@ -140,9 +372,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -181,13 +411,6 @@ class FederationHandler(BaseHandler):
affected=event.event_id,
)
- # if we're receiving valid events from an origin,
- # it's probably a good idea to mark it as not in retry-state
- # for sending (although this is a bit of a leap)
- retry_timings = yield self.store.get_destination_retry_timings(origin)
- if retry_timings and retry_timings["retry_last_ts"]:
- self.store.set_destination_retry_timings(origin, 0, 0)
-
room = yield self.store.get_room(event.room_id)
if not room:
@@ -206,11 +429,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -249,7 +471,7 @@ class FederationHandler(BaseHandler):
def check_match(id):
try:
return server_name == get_domain_from_id(id)
- except:
+ except Exception:
return False
# Parses mapping `event_id -> (type, state_key) -> state event_id`
@@ -287,7 +509,7 @@ class FederationHandler(BaseHandler):
continue
try:
domain = get_domain_from_id(ev.state_key)
- except:
+ except Exception:
continue
if domain != server_name:
@@ -314,9 +536,16 @@ class FederationHandler(BaseHandler):
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
- This will attempt to get more events from the remote. This may return
- be successfull and still return no events if the other side has no new
- events to offer.
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
@@ -328,6 +557,16 @@ class FederationHandler(BaseHandler):
extremities=extremities,
)
+ # ideally we'd sanity check the events here for excess prev_events etc,
+ # but it's hard to reject events at this point without completely
+ # breaking backfill in the same way that it is currently broken by
+ # events whose signature we cannot verify (#3121).
+ #
+ # So for now we accept the events anyway. #3124 tracks this.
+ #
+ # for ev in events:
+ # self._sanity_check_event(ev)
+
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
@@ -398,9 +637,10 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.run_in_background(
+ self.replication_layer.get_pdu,
[dest],
event_id,
outlier=True,
@@ -420,7 +660,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_events(
+ seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -526,7 +766,7 @@ class FederationHandler(BaseHandler):
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
- except:
+ except Exception:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
@@ -570,6 +810,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e)
+ continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
@@ -592,10 +835,13 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- states = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
- for e in event_ids
- ]))
+ resolve = logcontext.preserve_fn(
+ self.state_handler.resolve_state_groups_for_events
+ )
+ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids],
+ consumeErrors=True,
+ ))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
@@ -624,6 +870,38 @@ class FederationHandler(BaseHandler):
defer.returnValue(False)
+ def _sanity_check_event(self, ev):
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev (synapse.events.EventBase): event to be checked
+
+ Returns: None
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_events) > 20:
+ logger.warn("Rejecting event %s which has %i prev_events",
+ ev.event_id, len(ev.prev_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many prev_events",
+ )
+
+ if len(ev.auth_events) > 10:
+ logger.warn("Rejecting event %s which has %i auth_events",
+ ev.event_id, len(ev.auth_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many auth_events",
+ )
+
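# A toy illustration (not part of the patch) of the limits _sanity_check_event
# enforces: more than 20 prev_events or 10 auth_events gets an event rejected
# before it can trigger a huge state resolution. FakeEvent is a hypothetical
# stand-in for synapse.events.EventBase, and ValueError stands in for
# SynapseError.
import collections

FakeEvent = collections.namedtuple(
    "FakeEvent", ["event_id", "prev_events", "auth_events"]
)


def sanity_check(ev, max_prev=20, max_auth=10):
    if len(ev.prev_events) > max_prev:
        raise ValueError("Too many prev_events in %s" % (ev.event_id,))
    if len(ev.auth_events) > max_auth:
        raise ValueError("Too many auth_events in %s" % (ev.event_id,))


spammy = FakeEvent(
    event_id="$spammy:example.com",
    prev_events=[("$e%d:example.com" % i, {}) for i in range(25)],
    auth_events=[],
)
try:
    sanity_check(spammy)
except ValueError as e:
    print(e)  # rejected: 25 prev_events exceeds the limit of 20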
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -641,7 +919,11 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_event_auth(self, event_id):
- auth = yield self.store.get_auth_chain([event_id])
+ event = yield self.store.get_event(event_id)
+ auth = yield self.store.get_auth_chain(
+ [auth_id for auth_id, _ in event.auth_events],
+ include_given=True
+ )
for event in auth:
event.signatures.update(
@@ -670,8 +952,6 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
- yield self.store.clean_room_for_join(room_id)
-
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
@@ -680,7 +960,15 @@ class FederationHandler(BaseHandler):
content,
)
+ # This shouldn't happen, because the RoomMemberHandler has a
+ # linearizer lock which only allows one operation per user per room
+ # at a time - so this is just paranoia.
+ assert (room_id not in self.room_queues)
+
self.room_queues[room_id] = []
+
+ yield self.store.clean_room_for_join(room_id)
+
handled_events = set()
try:
@@ -714,7 +1002,7 @@ class FederationHandler(BaseHandler):
room_creator_user_id="",
is_public=False
)
- except:
+ except Exception:
# FIXME
pass
@@ -722,29 +1010,45 @@ class FederationHandler(BaseHandler):
origin, auth_chain, state, event
)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=[joinee]
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=[joinee]
+ )
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
- for p, origin in room_queue:
- if p.event_id in handled_events:
- continue
+ # we don't need to wait for the queued events to be processed -
+ # it's just a best-effort thing at this point. We do want to do
+ # them roughly in order, though, otherwise we'll end up making
+ # lots of requests for missing prev_events which we do actually
+ # have. Hence we fire off the deferred, but don't wait for it.
- try:
- self.on_receive_pdu(origin, p)
- except:
- logger.exception("Couldn't handle pdu")
+ logcontext.run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@defer.inlineCallbacks
+ def _handle_queued_pdus(self, room_queue):
+ """Process PDUs which got queued up while we were busy sending the join.
+
+ Args:
+ room_queue (list[(FrozenEvent, str)]): list of PDUs to be processed
+ and the servers that sent them
+ """
+ for p, origin in room_queue:
+ try:
+ logger.info("Processing queued PDU %s which was received "
+ "while we were joining %s", p.event_id, p.room_id)
+ yield self.on_receive_pdu(origin, p)
+ except Exception as e:
+ logger.warn(
+ "Error handling queued PDU %s from %s: %s",
+ p.event_id, origin, e)
+
+ @defer.inlineCallbacks
@log_function
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
@@ -762,8 +1066,7 @@ class FederationHandler(BaseHandler):
})
try:
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -791,9 +1094,19 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = False
- # Send this event on behalf of the origin server since they may not
- # have an up to data view of the state of the room at this event so
- # will not know which servers to send the event to.
+ # Send this event on behalf of the origin server.
+ #
+ # The reasons we have the destination server rather than the origin
+ # server send it are slightly mysterious: the origin server should have
+ # all the necessary state once it gets the response to the send_join,
+ # so it could send the event itself if it wanted to. It may be that
+ # doing it this way reduces failure modes, or avoids certain attacks
+ # where a new server selectively tells a subset of the federation that
+ # it has joined.
+ #
+ # The fact is that, as of this writing, Synapse doesn't send out
+ # the join event over federation after joining, and changing it now
+ # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -812,10 +1125,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id, extra_users=extra_users
+ )
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
@@ -823,9 +1135,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id)
state_ids = context.prev_state_ids.values()
- auth_chain = yield self.store.get_auth_chain(set(
- [event.event_id] + state_ids
- ))
+ auth_chain = yield self.store.get_auth_chain(state_ids)
state = yield self.store.get_events(context.prev_state_ids.values())
@@ -842,6 +1152,34 @@ class FederationHandler(BaseHandler):
"""
event = pdu
+ if event.state_key is None:
+ raise SynapseError(400, "The invite event did not have a state key")
+
+ is_blocked = yield self.store.is_room_blocked(event.room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
+ if self.hs.config.block_non_admin_invites:
+ raise SynapseError(403, "This server does not accept room invites")
+
+ if not self.spam_checker.user_may_invite(
+ event.sender, event.state_key, event.room_id,
+ ):
+ raise SynapseError(
+ 403, "This user is not permitted to send invites to this server/user"
+ )
+
+ membership = event.content.get("membership")
+ if event.type != EventTypes.Member or membership != Membership.INVITE:
+ raise SynapseError(400, "The event was not an m.room.member invite event")
+
+ sender_domain = get_domain_from_id(event.sender)
+ if sender_domain != origin:
+ raise SynapseError(400, "The invite event was not from the server sending it")
+
+ if not self.is_mine_id(event.state_key):
+ raise SynapseError(400, "The invite event must be for this server")
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
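# For reference, a simplified sketch (not from the patch) of the helper behind
# the sender/state_key checks above: synapse.types.get_domain_from_id returns
# the server part of a Matrix identifier. domain_from_id below is an
# illustrative equivalent, not the real implementation.
def domain_from_id(string):
    idx = string.find(":")
    if idx == -1:
        raise ValueError("Invalid ID: %r" % (string,))
    return string[idx + 1:]


assert domain_from_id("@alice:example.com") == "example.com"
assert domain_from_id("!room:hs.example.org") == "hs.example.org"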
@@ -861,48 +1199,38 @@ class FederationHandler(BaseHandler):
)
target_user = UserID.from_string(event.state_key)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=[target_user],
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=[target_user],
+ )
defer.returnValue(event)
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
- try:
- origin, event = yield self._make_and_verify_event(
- target_hosts,
- room_id,
- user_id,
- "leave"
- )
- signed_event = self._sign_event(event)
- except SynapseError:
- raise
- except CodeMessageException as e:
- logger.warn("Failed to reject invite: %s", e)
- raise SynapseError(500, "Failed to reject invite")
-
- # Try the host we successfully got a response to /make_join/
- # request first.
+ origin, event = yield self._make_and_verify_event(
+ target_hosts,
+ room_id,
+ user_id,
+ "leave"
+ )
+ # Mark as outlier as we don't have any state for this event; we're not
+ # even in the room.
+ event.internal_metadata.outlier = True
+ event = self._sign_event(event)
+
+ # Try the host that we successfully called /make_leave/ on first for
+ # the /send_leave/ request.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
- try:
- yield self.replication_layer.send_leave(
- target_hosts,
- signed_event
- )
- except SynapseError:
- raise
- except CodeMessageException as e:
- logger.warn("Failed to reject invite: %s", e)
- raise SynapseError(500, "Failed to reject invite")
+ yield self.replication_layer.send_leave(
+ target_hosts,
+ event
+ )
context = yield self.state_handler.compute_event_context(event)
@@ -978,8 +1306,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1023,10 +1350,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id, extra_users=extra_users
+ )
defer.returnValue(None)
@@ -1061,7 +1387,7 @@ class FederationHandler(BaseHandler):
for event in res:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
+ if self.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
@@ -1096,7 +1422,7 @@ class FederationHandler(BaseHandler):
if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id
else:
- del results[(event.type, event.state_key)]
+ results.pop((event.type, event.state_key), None)
defer.returnValue(results.values())
else:
@@ -1134,7 +1460,7 @@ class FederationHandler(BaseHandler):
)
if event:
- if self.hs.is_mine_id(event.event_id):
+ if self.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
@@ -1178,23 +1504,33 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
- if not event.internal_metadata.is_outlier():
- action_generator = ActionGenerator(self.hs)
- yield action_generator.handle_push_actions_for_event(
- event, context
+ try:
+ if not event.internal_metadata.is_outlier() and not backfilled:
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ tp, value, tb = sys.exc_info()
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- )
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
+ logcontext.run_in_background(
+ self.pusher_pool.on_new_notifications,
+ event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1206,16 +1542,17 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
- contexts = yield preserve_context_over_deferred(defer.gatherResults(
+ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._prep_event)(
+ logcontext.run_in_background(
+ self._prep_event,
origin,
ev_info["event"],
state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
- ]
+ ], consumeErrors=True,
))
yield self.store.persist_events(
@@ -1325,7 +1662,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
+ """
+
+ Args:
+ origin:
+ event:
+ state:
+ auth_events:
+ Returns:
+ Deferred, which resolves to synapse.events.snapshot.EventContext
+ """
context = yield self.state_handler.compute_event_context(
event, old_state=state,
)
@@ -1362,7 +1709,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
- if event.type == EventTypes.GuestAccess:
+ if event.type == EventTypes.GuestAccess and not context.rejected:
yield self.maybe_kick_guest_users(event)
defer.returnValue(context)
@@ -1379,7 +1726,11 @@ class FederationHandler(BaseHandler):
pass
# Now get the current auth_chain for the event.
- local_auth_chain = yield self.store.get_auth_chain([event_id])
+ event = yield self.store.get_event(event_id)
+ local_auth_chain = yield self.store.get_auth_chain(
+ [auth_id for auth_id, _ in event.auth_events],
+ include_given=True
+ )
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
@@ -1427,6 +1778,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
+ """
+
+ Args:
+ origin (str):
+ event (synapse.events.FrozenEvent):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->synapse.events.FrozenEvent]):
+
+ Returns:
+ defer.Deferred[None]
+ """
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1437,7 +1799,8 @@ class FederationHandler(BaseHandler):
event_key = None
if event_auth_events - current_state:
- have_events = yield self.store.have_events(
+ # TODO: can we use store.have_seen_events here instead?
+ have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state
)
else:
@@ -1460,12 +1823,12 @@ class FederationHandler(BaseHandler):
origin, event.room_id, event.event_id
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
- if e.event_id in seen_remotes.keys():
+ if e.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
@@ -1492,11 +1855,11 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.have_events(
+ have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
@@ -1509,18 +1872,18 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
- different_events = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self.store.get_event)(
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
@@ -1539,16 +1902,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@@ -1572,7 +1928,9 @@ class FederationHandler(BaseHandler):
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
)
- local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+ local_auth_chain = yield self.store.get_auth_chain(
+ auth_ids, include_given=True
+ )
try:
# 2. Get remote difference.
@@ -1583,13 +1941,13 @@ class FederationHandler(BaseHandler):
local_auth_chain,
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes.keys():
+ if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
@@ -1619,23 +1977,16 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
try:
self.auth.check(event, auth_events=auth_events)
@@ -1644,6 +1995,45 @@ class FederationHandler(BaseHandler):
raise e
@defer.inlineCallbacks
+ def _update_context_for_auth_events(self, event, context, auth_events,
+ event_key):
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event (Event): The event we're handling the context for
+
+ context (synapse.events.snapshot.EventContext): event context
+ to be updated
+
+ auth_events (dict[(str, str)->synapse.events.FrozenEvent]): Events to update in the event
+ context.
+
+ event_key ((str, str)): (type, state_key) for the current event.
+ this will not be included in the current_state in the context.
+ """
+ state_updates = {
+ k: a.event_id for k, a in auth_events.iteritems()
+ if k != event_key
+ }
+ context.current_state_ids = dict(context.current_state_ids)
+ context.current_state_ids.update(state_updates)
+ if context.delta_ids is not None:
+ context.delta_ids = dict(context.delta_ids)
+ context.delta_ids.update(state_updates)
+ context.prev_state_ids = dict(context.prev_state_ids)
+ context.prev_state_ids.update({
+ k: a.event_id for k, a in auth_events.iteritems()
+ })
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=context.prev_group,
+ delta_ids=context.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
+
+ @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
@@ -1686,7 +2076,7 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
return it.next()
- except:
+ except Exception:
return opt
current_local = get_next(local_iter)
@@ -1811,8 +2201,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1827,7 +2216,7 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
- member_handler = self.hs.get_handlers().room_member_handler
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@@ -1840,10 +2229,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ """Handle an exchange_third_party_invite request from a remote server
+
+ The remote server will call this when it wants to turn a 3pid invite
+ into a normal m.room.member invite.
+
+ Returns:
+ Deferred: resolves (to None)
+ """
builder = self.event_builder_factory.new(event_dict)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1858,10 +2254,13 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+ # XXX we send the invite here, but send_membership_event also sends it,
+ # so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
- member_handler = self.hs.get_handlers().room_member_handler
+
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
@@ -1890,8 +2289,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(builder=builder)
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ )
defer.returnValue((event, context))
@defer.inlineCallbacks
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
new file mode 100644
index 0000000000..977993e7d4
--- /dev/null
+++ b/synapse/handlers/groups_local.py
@@ -0,0 +1,471 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _create_rerouter(func_name):
+ """Returns a function that looks at the group id and calls the function
+ on federation or the local group server if the group is local
+ """
+ def f(self, group_id, *args, **kwargs):
+ if self.is_mine_id(group_id):
+ return getattr(self.groups_server_handler, func_name)(
+ group_id, *args, **kwargs
+ )
+ else:
+ destination = get_domain_from_id(group_id)
+ return getattr(self.transport_client, func_name)(
+ destination, group_id, *args, **kwargs
+ )
+ return f
+
+
+class GroupsLocalHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.groups_server_handler = hs.get_groups_server_handler()
+ self.transport_client = hs.get_federation_transport_client()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.notifier = hs.get_notifier()
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ # The following functions merely route the query to the local groups server
+ # or federation depending on if the group is local or remote
+
+ get_group_profile = _create_rerouter("get_group_profile")
+ update_group_profile = _create_rerouter("update_group_profile")
+ get_rooms_in_group = _create_rerouter("get_rooms_in_group")
+
+ get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
+
+ add_room_to_group = _create_rerouter("add_room_to_group")
+ update_room_in_group = _create_rerouter("update_room_in_group")
+ remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+ update_group_summary_room = _create_rerouter("update_group_summary_room")
+ delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+ update_group_category = _create_rerouter("update_group_category")
+ delete_group_category = _create_rerouter("delete_group_category")
+ get_group_category = _create_rerouter("get_group_category")
+ get_group_categories = _create_rerouter("get_group_categories")
+
+ update_group_summary_user = _create_rerouter("update_group_summary_user")
+ delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+ update_group_role = _create_rerouter("update_group_role")
+ delete_group_role = _create_rerouter("delete_group_role")
+ get_group_role = _create_rerouter("get_group_role")
+ get_group_roles = _create_rerouter("get_group_roles")
+
+ set_group_join_policy = _create_rerouter("set_group_join_policy")
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the group summary for a group.
+
+ If the group is remote we check that the users have valid attestations.
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_group_summary(
+ group_id, requester_user_id
+ )
+ else:
+ res = yield self.transport_client.get_group_summary(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ group_server_name = get_domain_from_id(group_id)
+
+ # Loop through the users and validate the attestations.
+ chunk = res["users_section"]["users"]
+ valid_users = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_users.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["users_section"]["users"] = valid_users
+
+ res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
+ res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
+
+ # Add `is_publicised` flag to indicate whether the user has publicised their
+ # membership of the group on their profile
+ result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ is_publicised = group_id in result
+
+ res.setdefault("user", {})["is_publicised"] = is_publicised
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, content):
+ """Create a group
+ """
+
+ logger.info("Asking to create group with ID: %r", group_id)
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.create_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+
+ res = yield self.transport_client.create_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ is_publicised = content.get("publicise", False)
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=True,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get users in a group
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+ defer.returnValue(res)
+
+ group_server_name = get_domain_from_id(group_id)
+
+ res = yield self.transport_client.get_users_in_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ chunk = res["chunk"]
+ valid_entries = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_entries.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["chunk"] = valid_entries
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, user_id, content):
+ """Request to join a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.join_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.join_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publicly
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, user_id, content):
+ """Accept an invite to a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.accept_invite(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.accept_group_invite(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publicly
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite(self, group_id, user_id, requester_user_id, config):
+ """Invite a user to a group
+ """
+ content = {
+ "requester_user_id": requester_user_id,
+ "config": config,
+ }
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ res = yield self.transport_client.invite_to_group(
+ get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def on_invite(self, group_id, user_id, content):
+ """One of our users was invited to a group
+ """
+ # TODO: Support auto join and rejection
+
+ if not self.is_mine_id(user_id):
+ raise SynapseError(400, "User not on this server")
+
+ local_profile = {}
+ if "profile" in content:
+ if "name" in content["profile"]:
+ local_profile["name"] = content["profile"]["name"]
+ if "avatar_url" in content["profile"]:
+ local_profile["avatar_url"] = content["profile"]["avatar_url"]
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="invite",
+ content={"profile": local_profile, "inviter": content["inviter"]},
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+ try:
+ user_profile = yield self.profile_handler.get_profile(user_id)
+ except Exception as e:
+ logger.warn("No profile for user %s: %s", user_id, e)
+ user_profile = {}
+
+ defer.returnValue({"state": "invite", "user_profile": user_profile})
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from a group
+ """
+ if user_id == requester_user_id:
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ # TODO: Should probably remember that we tried to leave so that we can
+ # retry if the group server is currently down.
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ content["requester_user_id"] = requester_user_id
+ res = yield self.transport_client.remove_user_from_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ user_id, content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def user_removed_from_group(self, group_id, user_id, content):
+ """One of our users was removed/kicked from a group
+ """
+ # TODO: Check if user in group
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_groups(self, user_id):
+ group_ids = yield self.store.get_joined_groups(user_id)
+ defer.returnValue({"groups": group_ids})
+
+ @defer.inlineCallbacks
+ def get_publicised_groups_for_user(self, user_id):
+ if self.hs.is_mine_id(user_id):
+ result = yield self.store.get_publicised_groups_for_user(user_id)
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ result.extend(app_service.get_groups_for_user(user_id))
+
+ defer.returnValue({"groups": result})
+ else:
+ bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ get_domain_from_id(user_id), [user_id],
+ )
+ result = bulk_result.get("users", {}).get(user_id)
+ # TODO: Verify attestations
+ defer.returnValue({"groups": result})
+
+ @defer.inlineCallbacks
+ def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ destinations = {}
+ local_users = set()
+
+ for user_id in user_ids:
+ if self.hs.is_mine_id(user_id):
+ local_users.add(user_id)
+ else:
+ destinations.setdefault(
+ get_domain_from_id(user_id), set()
+ ).add(user_id)
+
+ if not proxy and destinations:
+ raise SynapseError(400, "Some user_ids are not local")
+
+ results = {}
+ failed_results = []
+ for destination, dest_user_ids in destinations.iteritems():
+ try:
+ r = yield self.transport_client.bulk_get_publicised_groups(
+ destination, list(dest_user_ids),
+ )
+ results.update(r["users"])
+ except Exception:
+ failed_results.extend(dest_user_ids)
+
+ for uid in local_users:
+ results[uid] = yield self.store.get_publicised_groups_for_user(
+ uid
+ )
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ results[uid].extend(app_service.get_groups_for_user(uid))
+
+ defer.returnValue({"users": results})
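# A standalone sketch of the _create_rerouter idea used in the new handler
# above: class attributes become generated dispatchers that call the local
# groups server when the group lives on this server, and the federation
# transport client otherwise. GroupsRouter, local_handler and remote_client
# are hypothetical stand-ins, not Synapse classes.
def make_rerouter(func_name):
    def f(self, group_id, *args, **kwargs):
        if self.is_mine_id(group_id):
            return getattr(self.local_handler, func_name)(group_id, *args, **kwargs)
        destination = group_id.split(":", 1)[1]
        return getattr(self.remote_client, func_name)(
            destination, group_id, *args, **kwargs
        )
    return f


class GroupsRouter(object):
    def __init__(self, server_name, local_handler, remote_client):
        self.server_name = server_name
        self.local_handler = local_handler   # plays the groups_server_handler role
        self.remote_client = remote_client   # plays the transport_client role

    def is_mine_id(self, group_id):
        return group_id.split(":", 1)[1] == self.server_name

    # each attribute becomes a method with the local/remote dispatch baked in
    get_group_profile = make_rerouter("get_group_profile")
    get_rooms_in_group = make_rerouter("get_rooms_in_group")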
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 559e5d5a71..91a0898860 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +15,20 @@
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
+
+import logging
+
+import simplejson as json
+
from twisted.internet import defer
from synapse.api.errors import (
- CodeMessageException
+ MatrixCodeMessageException, CodeMessageException
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError, Codes
-import json
-import logging
-
logger = logging.getLogger(__name__)
@@ -89,6 +92,9 @@ class IdentityHandler(BaseHandler):
),
{'sid': creds['sid'], 'client_secret': client_secret}
)
+ except MatrixCodeMessageException as e:
+ logger.info("getValidated3pid failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
data = json.loads(e.msg)
@@ -150,7 +156,7 @@ class IdentityHandler(BaseHandler):
params.update(kwargs)
try:
- data = yield self.http_client.post_urlencoded_get_json(
+ data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
@@ -158,6 +164,46 @@ class IdentityHandler(BaseHandler):
params
)
defer.returnValue(data)
+ except MatrixCodeMessageException as e:
+ logger.info("Proxied requestToken failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ except CodeMessageException as e:
+ logger.info("Proxied requestToken failed: %r", e)
+ raise e
+
+ @defer.inlineCallbacks
+ def requestMsisdnToken(
+ self, id_server, country, phone_number,
+ client_secret, send_attempt, **kwargs
+ ):
+ yield run_on_reactor()
+
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server,
+ Codes.SERVER_NOT_TRUSTED
+ )
+
+ params = {
+ 'country': country,
+ 'phone_number': phone_number,
+ 'client_secret': client_secret,
+ 'send_attempt': send_attempt,
+ }
+ params.update(kwargs)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ ),
+ params
+ )
+ defer.returnValue(data)
+ except MatrixCodeMessageException as e:
+ logger.info("Proxied requestToken failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e
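# A rough sketch (not from the patch) of the request the new requestMsisdnToken
# proxy forwards to the identity server. The endpoint path and field names come
# from the hunk above; using the `requests` library here is purely for
# illustration, since the handler itself goes through Synapse's HTTP client.
import requests


def request_msisdn_token(id_server, country, phone_number, client_secret,
                         send_attempt=1):
    url = "https://%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (
        id_server,
    )
    body = {
        "country": country,            # e.g. "GB"
        "phone_number": phone_number,  # e.g. "07700900000"
        "client_secret": client_secret,
        "send_attempt": send_attempt,
    }
    resp = requests.post(url, json=body)
    resp.raise_for_status()
    return resp.json()  # expected to include the session id ("sid")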
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e0ade4c164..cd33a86599 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
+from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
@@ -26,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -162,10 +163,11 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield preserve_context_over_deferred(
+ (messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(
+ self.store.get_recent_events_for_room,
event.room_id,
limit=limit,
end_token=room_end_token,
@@ -213,7 +215,7 @@ class InitialSyncHandler(BaseHandler):
})
d["account_data"] = account_data_events
- except:
+ except Exception:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
@@ -225,9 +227,17 @@ class InitialSyncHandler(BaseHandler):
"content": content,
})
+ now = self.clock.time_msec()
+
ret = {
"rooms": rooms_ret,
- "presence": presence,
+ "presence": [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(event, now),
+ }
+ for event in presence
+ ],
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
@@ -382,9 +392,10 @@ class InitialSyncHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(get_presence),
+ run_in_background(get_receipts),
+ run_in_background(
+ self.store.get_recent_events_for_room,
room_id,
limit=limit,
end_token=now_token.room_key,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a498af5a2..b793fc4df7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,31 +13,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import simplejson
+import sys
-from twisted.internet import defer
+from canonicaljson import encode_canonical_json
+import six
+from twisted.internet import defer, reactor
+from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
+from synapse.api.constants import EventTypes, Membership, MAX_DEPTH
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
-from synapse.push.action_generator import ActionGenerator
from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
+from synapse.replication.http.send_event import send_event_to_master
from ._base import BaseHandler
-from canonicaljson import encode_canonical_json
+logger = logging.getLogger(__name__)
-import logging
-import random
-logger = logging.getLogger(__name__)
+class PurgeStatus(object):
+ """Object tracking the status of a purge request
+
+ This class contains information on the progress of a purge request, for
+ return by get_purge_status.
+
+ Attributes:
+ status (int): Tracks whether this request has completed. One of
+ STATUS_{ACTIVE,COMPLETE,FAILED}
+ """
+
+ STATUS_ACTIVE = 0
+ STATUS_COMPLETE = 1
+ STATUS_FAILED = 2
+
+ STATUS_TEXT = {
+ STATUS_ACTIVE: "active",
+ STATUS_COMPLETE: "complete",
+ STATUS_FAILED: "failed",
+ }
+
+ def __init__(self):
+ self.status = PurgeStatus.STATUS_ACTIVE
+
+ def asdict(self):
+ return {
+ "status": PurgeStatus.STATUS_TEXT[self.status]
+ }
class MessageHandler(BaseHandler):
@@ -46,25 +80,89 @@ class MessageHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
- self.validator = EventValidator()
self.pagination_lock = ReadWriteLock()
+ self._purges_in_progress_by_room = set()
+ # map from purge id to PurgeStatus
+ self._purges_by_id = {}
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Limiter(max_count=5)
+ def start_purge_history(self, room_id, topological_ordering,
+ delete_local_events=False):
+ """Start off a history purge on a room.
+
+ Args:
+ room_id (str): The room to purge from
+
+ topological_ordering (int): minimum topo ordering to preserve
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
+
+ Returns:
+ str: unique ID for this purge transaction.
+ """
+ if room_id in self._purges_in_progress_by_room:
+ raise SynapseError(
+ 400,
+ "History purge already in progress for %s" % (room_id, ),
+ )
+
+ purge_id = random_string(16)
+
+ # we log the purge_id here so that it can be tied back to the
+ # request id in the log lines.
+ logger.info("[purge] starting purge_id %s", purge_id)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+ run_in_background(
+ self._purge_history,
+ purge_id, room_id, topological_ordering, delete_local_events,
+ )
+ return purge_id
@defer.inlineCallbacks
- def purge_history(self, room_id, event_id):
- event = yield self.store.get_event(event_id)
+ def _purge_history(self, purge_id, room_id, topological_ordering,
+ delete_local_events):
+ """Carry out a history purge on a room.
- if event.room_id != room_id:
- raise SynapseError(400, "Event is for wrong room.")
+ Args:
+ purge_id (str): The id for this purge
+ room_id (str): The room to purge from
+ topological_ordering (int): minimum topo ordering to preserve
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
- depth = event.depth
+ Returns:
+ Deferred
+ """
+ self._purges_in_progress_by_room.add(room_id)
+ try:
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.purge_history(
+ room_id, topological_ordering, delete_local_events,
+ )
+ logger.info("[purge] complete")
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
+ except Exception:
+ logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
+ finally:
+ self._purges_in_progress_by_room.discard(room_id)
+
+ # remove the purge from the list 24 hours after it completes
+ def clear_purge():
+ del self._purges_by_id[purge_id]
+ reactor.callLater(24 * 3600, clear_purge)
+
+ def get_purge_status(self, purge_id):
+ """Get the current status of an active purge
- with (yield self.pagination_lock.write(room_id)):
- yield self.store.delete_old_state(room_id, depth)
+ Args:
+ purge_id (str): purge_id returned by start_purge_history
+
+ Returns:
+ PurgeStatus|None
+ """
+ return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
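# A hypothetical usage sketch of the purge API introduced above:
# start_purge_history returns a purge_id immediately and the purge runs in the
# background, so callers poll get_purge_status. `handler` stands in for a
# MessageHandler instance obtained from the homeserver.
def show_purge_flow(handler, room_id, topological_ordering):
    purge_id = handler.start_purge_history(
        room_id, topological_ordering, delete_local_events=False,
    )
    status = handler.get_purge_status(purge_id)
    if status is not None:
        # {"status": "active"} at first, then "complete" or "failed"
        print(status.asdict())
    return purge_id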
@@ -175,7 +273,167 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
+ def get_room_data(self, user_id=None, room_id=None,
+ event_type=None, state_key="", is_guest=False):
+ """ Get data from a room.
+
+ Args:
+ event : The room path event
+ Returns:
+ The path data content.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ membership, membership_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if membership == Membership.JOIN:
+ data = yield self.state_handler.get_current_state(
+ room_id, event_type, state_key
+ )
+ elif membership == Membership.LEAVE:
+ key = (event_type, state_key)
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], [key]
+ )
+ data = room_state[membership_event_id].get(key)
+
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def _check_in_room_or_world_readable(self, room_id, user_id):
+ try:
+ # check_user_was_in_room will return the most recent membership
+ # event for the user if:
+ # * The user is a non-guest user, and was ever in the room
+ # * The user is a guest user, and has joined the room
+ # else it will throw.
+ member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ defer.returnValue((member_event.membership, member_event.event_id))
+ return
+ except AuthError:
+ visibility = yield self.state_handler.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
+ if (
+ visibility and
+ visibility.content["history_visibility"] == "world_readable"
+ ):
+ defer.returnValue((Membership.JOIN, None))
+ return
+ raise AuthError(
+ 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ )
+
+ @defer.inlineCallbacks
+ def get_state_events(self, user_id, room_id, is_guest=False):
+ """Retrieve all state events for a given room. If the user is
+ joined to the room then return the current state. If the user has
+ left the room return the state events from when they left.
+
+ Args:
+ user_id(str): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A list of dicts representing state events. [{}, {}, {}]
+ """
+ membership, membership_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if membership == Membership.JOIN:
+ room_state = yield self.state_handler.get_current_state(room_id)
+ elif membership == Membership.LEAVE:
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], None
+ )
+ room_state = room_state[membership_event_id]
+
+ now = self.clock.time_msec()
+ defer.returnValue(
+ [serialize_event(c, now) for c in room_state.values()]
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+ If the user has left the room return the state events from when they left.
+
+ Args:
+ requester(Requester): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ # If this is an AS, double check that they are allowed to see the members.
+ # This can either be because the AS user is in the room or because there
+ # is a user in the room that the AS is "interested in"
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
+
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ })
+
+
+class EventCreationHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.validator = EventValidator()
+ self.profile_handler = hs.get_profile_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.server_name = hs.hostname
+ self.ratelimiter = hs.get_ratelimiter()
+ self.notifier = hs.get_notifier()
+ self.config = hs.config
+
+ self.http_client = hs.get_simple_http_client()
+
+ # This is only used to get at the ratelimit function and maybe_kick_guest_users
+ self.base_handler = BaseHandler(hs)
+
+ self.pusher_pool = hs.get_pusherpool()
+
+ # We arbitrarily limit concurrent event creation for a room to 5.
+ # This is to stop us from diverging history *too* much.
+ self.limiter = Limiter(max_count=5)
+
+ self.action_generator = hs.get_action_generator()
+
+ self.spam_checker = hs.get_spam_checker()
+
+ @defer.inlineCallbacks
+ def create_event(self, requester, event_dict, token_id=None, txn_id=None,
+ prev_events_and_hashes=None):
"""
Given a dict from a client, create a new event.
@@ -185,49 +443,56 @@ class MessageHandler(BaseHandler):
Adds display names to Join membership events.
Args:
+ requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
- prev_event_ids (list): The prev event ids to use when creating the event
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
- with (yield self.limiter.queue(builder.room_id)):
- self.validator.validate_new(builder)
-
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in {Membership.JOIN, Membership.INVITE}:
- # If event doesn't include a display name, add one.
- profile = self.hs.get_handlers().profile_handler
- content = builder.content
-
- try:
- if "displayname" not in content:
- content["displayname"] = yield profile.get_displayname(target)
- if "avatar_url" not in content:
- content["avatar_url"] = yield profile.get_avatar_url(target)
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s",
- target, e
- )
+ self.validator.validate_new(builder)
+
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ target = UserID.from_string(builder.state_key)
+
+ if membership in {Membership.JOIN, Membership.INVITE}:
+ # If event doesn't include a display name, add one.
+ profile = self.profile_handler
+ content = builder.content
+
+ try:
+ if "displayname" not in content:
+ content["displayname"] = yield profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+ except Exception as e:
+ logger.info(
+ "Failed to get profile information for %r: %s",
+ target, e
+ )
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ if token_id is not None:
+ builder.internal_metadata.token_id = token_id
- if txn_id is not None:
- builder.internal_metadata.txn_id = txn_id
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
- event, context = yield self._create_new_client_event(
- builder=builder,
- prev_event_ids=prev_event_ids,
- )
+ event, context = yield self.create_new_client_event(
+ builder=builder,
+ requester=requester,
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
defer.returnValue((event, context))
@@ -248,21 +513,6 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath"
)
- # We check here if we are currently being rate limited, so that we
- # don't do unnecessary work. We check again just before we actually
- # send the event.
- time_now = self.clock.time()
- allowed, time_allowed = self.ratelimiter.send_message(
- event.sender, time_now,
- msg_rate_hz=self.hs.config.rc_messages_per_second,
- burst_count=self.hs.config.rc_message_burst_count,
- update=False,
- )
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
- )
-
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -279,12 +529,6 @@ class MessageHandler(BaseHandler):
ratelimit=ratelimit,
)
- if event.type == EventTypes.Message:
- presence = self.hs.get_presence_handler()
- # We don't want to block sending messages on any presence code. This
- # matters as sometimes presence code can take a while.
- preserve_fn(presence.bump_presence_active_time)(user)
-
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
@@ -318,144 +562,87 @@ class MessageHandler(BaseHandler):
See self.create_event and self.send_nonmember_event.
"""
- event, context = yield self.create_event(
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
- )
- yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- )
- defer.returnValue(event)
-
- @defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
- """ Get data from a room.
-
- Args:
- event : The room path event
- Returns:
- The path data content.
- Raises:
- SynapseError if something went wrong.
- """
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
- if membership == Membership.JOIN:
- data = yield self.state_handler.get_current_state(
- room_id, event_type, state_key
+ # We limit the number of concurrent event sends in a room so that we
+ # don't fork the DAG too much. If we don't limit then we can end up in
+ # a situation where event persistence can't keep up, causing
+ # extremities to pile up, which in turn leads to state resolution
+ # taking longer.
+ with (yield self.limiter.queue(event_dict["room_id"])):
+ event, context = yield self.create_event(
+ requester,
+ event_dict,
+ token_id=requester.access_token_id,
+ txn_id=txn_id
)
- elif membership == Membership.LEAVE:
- key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
- )
- data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
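+ # check_event_for_spam may return a boolean or a custom rejection
+ # string; a bare truthy value is mapped to the generic message below.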
+ spam_error = self.spam_checker.check_event_for_spam(event)
+ if spam_error:
+ if not isinstance(spam_error, basestring):
+ spam_error = "Spam is not permitted here"
+ raise SynapseError(
+ 403, spam_error, Codes.FORBIDDEN
+ )
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
- try:
- # check_user_was_in_room will return the most recent membership
- # event for the user if:
- # * The user is a non-guest user, and was ever in the room
- # * The user is a guest user, and has joined the room
- # else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
- return
- except AuthError:
- visibility = yield self.state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
- ):
- defer.returnValue((Membership.JOIN, None))
- return
- raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ yield self.send_nonmember_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
)
+ defer.returnValue(event)
+ @measure_func("create_new_client_event")
@defer.inlineCallbacks
- def get_state_events(self, user_id, room_id, is_guest=False):
- """Retrieve all state events for a given room. If the user is
- joined to the room then return the current state. If the user has
- left the room return the state events from when they left.
+ def create_new_client_event(self, builder, requester=None,
+ prev_events_and_hashes=None):
+ """Create a new event for a local client
Args:
- user_id(str): The user requesting state events.
- room_id(str): The room ID to get all state events from.
+ builder (EventBuilder):
+
+ requester (synapse.types.Requester|None):
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
+
Returns:
- A list of dicts representing state events. [{}, {}, {}]
+ Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
"""
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
- if membership == Membership.JOIN:
- room_state = yield self.state_handler.get_current_state(room_id)
- elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], None
+ if prev_events_and_hashes is not None:
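+ # Callers supplying prev events explicitly should not be pointing the
+ # new event at an excessive number of forward extremities.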
+ assert len(prev_events_and_hashes) <= 10, \
+ "Attempting to create an event with %i prev_events" % (
+ len(prev_events_and_hashes),
)
- room_state = room_state[membership_event_id]
-
- now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
- )
-
- @measure_func("_create_new_client_event")
- @defer.inlineCallbacks
- def _create_new_client_event(self, builder, prev_event_ids=None):
- if prev_event_ids:
- prev_events = yield self.store.add_event_hashes(prev_event_ids)
- prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
- depth = prev_max_depth + 1
else:
- latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
- builder.room_id,
- )
-
- # We want to limit the max number of prev events we point to in our
- # new event
- if len(latest_ret) > 10:
- # Sort by reverse depth, so we point to the most recent.
- latest_ret.sort(key=lambda a: -a[2])
- new_latest_ret = latest_ret[:5]
-
- # We also randomly point to some of the older events, to make
- # sure that we don't completely ignore the older events.
- if latest_ret[5:]:
- sample_size = min(5, len(latest_ret[5:]))
- new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
- latest_ret = new_latest_ret
-
- if latest_ret:
- depth = max([d for _, _, d in latest_ret]) + 1
- else:
- depth = 1
+ prev_events_and_hashes = \
+ yield self.store.get_prev_events_for_room(builder.room_id)
+
+ if prev_events_and_hashes:
+ depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(depth, MAX_DEPTH)
+ else:
+ depth = 1
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in latest_ret
- ]
+ prev_events = [
+ (event_id, prev_hashes)
+ for event_id, prev_hashes, _ in prev_events_and_hashes
+ ]
builder.prev_events = prev_events
builder.depth = depth
- state_handler = self.state_handler
-
- context = yield state_handler.compute_event_context(builder)
+ context = yield self.state.compute_event_context(builder)
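+ # Remember which application service (if any) is acting on behalf of
+ # the requester.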
+ if requester:
+ context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
@@ -488,12 +675,21 @@ class MessageHandler(BaseHandler):
event,
context,
ratelimit=True,
- extra_users=[]
+ extra_users=[],
):
- # We now need to go and hit out to wherever we need to hit out to.
+ """Processes a new event. This includes checking auth, persisting it,
+ notifying users, sending to remote servers, etc.
- if ratelimit:
- self.ratelimit(requester)
+ If called from a worker, this will hit out to the master process for
+ final processing.
+
+ Args:
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
try:
yield self.auth.check_from_context(event, context)
@@ -501,7 +697,72 @@ class MessageHandler(BaseHandler):
logger.warn("Denying new event %r because %s", event, err)
raise err
- yield self.maybe_kick_guest_users(event, context)
+ # Ensure that we can round trip before trying to persist in db
+ try:
+ dump = frozendict_json_encoder.encode(event.content)
+ simplejson.loads(dump)
+ except Exception:
+ logger.exception("Failed to encode content: %r", event.content)
+ raise
+
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ try:
+ # If we're a worker we need to hit out to the master.
+ if self.config.worker_app:
+ yield send_event_to_master(
+ self.http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ event=event,
+ context=context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ return
+
+ yield self.persist_and_notify_client_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ except: # noqa: E722, as we reraise the exception this is fine.
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ tp, value, tb = sys.exc_info()
+
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
+
+ @defer.inlineCallbacks
+ def persist_and_notify_client_event(
+ self,
+ requester,
+ event,
+ context,
+ ratelimit=True,
+ extra_users=[],
+ ):
+ """Called when we have fully built the event, have already
+ calculated the push actions for the event, and checked auth.
+
+ This should only be run on master.
+ """
+ assert not self.config.worker_app
+
+ if ratelimit:
+ yield self.base_handler.ratelimit(requester)
+
+ yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is actually valid (at this time at least)
@@ -531,9 +792,9 @@ class MessageHandler(BaseHandler):
state_to_include_ids = [
e_id
- for k, e_id in context.current_state_ids.items()
+ for k, e_id in context.current_state_ids.iteritems()
if k[0] in self.hs.config.room_invite_state_types
- or k[0] == EventTypes.Member and k[1] == event.sender
+ or k == (EventTypes.Member, event.sender)
]
state_to_include = yield self.store.get_events(state_to_include_ids)
@@ -545,7 +806,7 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
- for e in state_to_include.values()
+ for e in state_to_include.itervalues()
]
invitee = UserID.from_string(event.state_key)
@@ -594,30 +855,39 @@ class MessageHandler(BaseHandler):
"Changing the room create event is forbidden",
)
- action_generator = ActionGenerator(self.hs)
- yield action_generator.handle_push_actions_for_event(
- event, context
- )
-
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ run_in_background(
+ self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
- yield self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ try:
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room event")
- preserve_fn(_notify)()
+ run_in_background(_notify)
- # If invite, remove room_state from unsigned before sending.
- event.unsigned.pop("invite_room_state", None)
+ if event.type == EventTypes.Message:
+ # We don't want to block sending messages on any presence code. This
+ # matters as sometimes presence code can take a while.
+ run_in_background(self._bump_active_time, requester.user)
+
+ @defer.inlineCallbacks
+ def _bump_active_time(self, user):
+ try:
+ presence = self.hs.get_presence_handler()
+ yield presence.bump_presence_active_time(user)
+ except Exception:
+ logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da610e430f..585f3e4da2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -29,7 +29,9 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.async import Linearizer
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -91,29 +93,30 @@ class PresenceHandler(object):
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
- self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
- self.replication.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
@@ -186,6 +189,7 @@ class PresenceHandler(object):
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = {}
+ self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
@@ -251,6 +255,14 @@ class PresenceHandler(object):
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
+ def _update_states_and_catch_exception(self, new_states):
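+ """Like _update_states, but logs exceptions rather than propagating
+ them, so that it can safely be fired off via run_in_background.
+ """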
+ try:
+ res = yield self._update_states(new_states)
+ defer.returnValue(res)
+ except Exception:
+ logger.exception("Error updating presence")
+
+ @defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -315,11 +327,7 @@ class PresenceHandler(object):
if to_federation_ping:
federation_presence_out_counter.inc_by(len(to_federation_ping))
- _, _, hosts_to_states = yield self._get_interested_parties(
- to_federation_ping.values()
- )
-
- self._push_to_remotes(hosts_to_states)
+ self._push_to_remotes(to_federation_ping.values())
def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
@@ -364,8 +372,8 @@ class PresenceHandler(object):
now=now,
)
- preserve_fn(self._update_states)(changes)
- except:
+ run_in_background(self._update_states_and_catch_exception, changes)
+ except Exception:
logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks
@@ -422,20 +430,23 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def _end():
- if affect_presence:
+ try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states([prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec(),
)])
+ except Exception:
+ logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
try:
yield
finally:
- preserve_fn(_end)()
+ if affect_presence:
+ run_in_background(_end)
defer.returnValue(_user_syncing())
@@ -508,6 +519,73 @@ class PresenceHandler(object):
self.external_process_to_current_syncs[process_id] = syncing_user_ids
@defer.inlineCallbacks
+ def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
+ """Update the syncing users for an external process as a delta.
+
+ Args:
+ process_id (str): An identifier for the process the users are
+ syncing against. This allows synapse to process updates
+ as users start and stop syncing against a given process.
+ user_id (str): The user who has started or stopped syncing
+ is_syncing (bool): Whether or not the user is now syncing
+ sync_time_msec(int): Time in ms when the user was last syncing
+ """
+ with (yield self.external_sync_linearizer.queue(process_id)):
+ prev_state = yield self.current_state_for_user(user_id)
+
+ process_presence = self.external_process_to_current_syncs.setdefault(
+ process_id, set()
+ )
+
+ updates = []
+ if is_syncing and user_id not in process_presence:
+ if prev_state.state == PresenceState.OFFLINE:
+ updates.append(prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=sync_time_msec,
+ last_user_sync_ts=sync_time_msec,
+ ))
+ else:
+ updates.append(prev_state.copy_and_replace(
+ last_user_sync_ts=sync_time_msec,
+ ))
+ process_presence.add(user_id)
+ elif user_id in process_presence:
+ updates.append(prev_state.copy_and_replace(
+ last_user_sync_ts=sync_time_msec,
+ ))
+
+ if not is_syncing:
+ process_presence.discard(user_id)
+
+ if updates:
+ yield self._update_states(updates)
+
+ self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
+
+ @defer.inlineCallbacks
+ def update_external_syncs_clear(self, process_id):
+ """Marks all users that had been marked as syncing by a given process
+ as offline.
+
+ Used when the process has stopped/disappeared.
+ """
+ with (yield self.external_sync_linearizer.queue(process_id)):
+ process_presence = self.external_process_to_current_syncs.pop(
+ process_id, set()
+ )
+ prev_states = yield self.current_state_for_users(process_presence)
+ time_now_ms = self.clock.time_msec()
+
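+ # Bumping last_user_sync_ts (rather than setting OFFLINE directly) lets
+ # the normal sync timeout logic transition these users to offline.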
+ yield self._update_states([
+ prev_state.copy_and_replace(
+ last_user_sync_ts=time_now_ms,
+ )
+ for prev_state in prev_states.itervalues()
+ ])
+ self.external_process_last_updated_ms.pop(process_id, None)
+
+ @defer.inlineCallbacks
def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
"""
@@ -526,14 +604,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
- missing = [user_id for user_id, state in states.items() if not state]
+ missing = [user_id for user_id, state in states.iteritems() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.items() if not state]
+ missing = [user_id for user_id, state in states.iteritems() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -545,89 +623,39 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
- def _get_interested_parties(self, states, calculate_remote_hosts=True):
- """Given a list of states return which entities (rooms, users, servers)
- are interested in the given states.
-
- Returns:
- 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`,
- with each item being a dict of `entity_name` -> `[UserPresenceState]`
- """
- room_ids_to_states = {}
- users_to_states = {}
- for state in states:
- events = yield self.store.get_rooms_for_user(state.user_id)
- for e in events:
- room_ids_to_states.setdefault(e.room_id, []).append(state)
-
- plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
- for u in plist:
- users_to_states.setdefault(u, []).append(state)
-
- # Always notify self
- users_to_states.setdefault(state.user_id, []).append(state)
-
- hosts_to_states = {}
- if calculate_remote_hosts:
- for room_id, states in room_ids_to_states.items():
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- if not local_states:
- continue
-
- users = yield self.store.get_users_in_room(room_id)
- hosts = set(get_domain_from_id(u) for u in users)
-
- for host in hosts:
- hosts_to_states.setdefault(host, []).extend(local_states)
-
- for user_id, states in users_to_states.items():
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- if not local_states:
- continue
-
- host = get_domain_from_id(user_id)
- hosts_to_states.setdefault(host, []).extend(local_states)
-
- # TODO: de-dup hosts_to_states, as a single host might have multiple
- # of same presence
-
- defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states))
-
- @defer.inlineCallbacks
def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
"""
stream_id, max_token = yield self.store.update_presence(states)
- parties = yield self._get_interested_parties(states)
- room_ids_to_states, users_to_states, hosts_to_states = parties
+ parties = yield get_interested_parties(self.store, states)
+ room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states.keys()]
+ users=[UserID.from_string(u) for u in users_to_states]
)
- self._push_to_remotes(hosts_to_states)
+ self._push_to_remotes(states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
- parties = yield self._get_interested_parties([state])
- room_ids_to_states, users_to_states, hosts_to_states = parties
+ parties = yield get_interested_parties(self.store, [state])
+ room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states.keys()]
+ users=[UserID.from_string(u) for u in users_to_states]
)
- def _push_to_remotes(self, hosts_to_states):
+ def _push_to_remotes(self, states):
"""Sends state updates to remote servers.
Args:
- hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
+ states (list(UserPresenceState))
"""
- for host, states in hosts_to_states.items():
- self.federation.send_presence(host, states)
+ self.federation.send_presence(states)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
@@ -719,9 +747,7 @@ class PresenceHandler(object):
for state in updates
])
else:
- defer.returnValue([
- format_user_presence_state(state, now) for state in updates
- ])
+ defer.returnValue(updates)
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
@@ -766,18 +792,17 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
- user_ids = yield self.store.get_users_in_room(room_id)
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
- hosts = set(get_domain_from_id(u) for u in user_ids)
- self._push_to_remotes({host: (state,) for host in hosts})
+ self._push_to_remotes([state])
else:
+ user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids)
- self._push_to_remotes({user.domain: states.values()})
+ self._push_to_remotes(states.values())
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@@ -795,6 +820,9 @@ class PresenceHandler(object):
as_event=False,
)
+ now = self.clock.time_msec()
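+ # Convert the raw UserPresenceState rows into the client-facing wire format.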
+ results[:] = [format_user_presence_state(r, now) for r in results]
+
is_accepted = {
row["observed_user_id"]: row["accepted"] for row in presence_list
}
@@ -847,6 +875,7 @@ class PresenceHandler(object):
)
state_dict = yield self.get_state(observed_user, as_event=False)
+ state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.send_edu(
destination=observer_user.domain,
@@ -910,11 +939,12 @@ class PresenceHandler(object):
def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
- observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string())
- observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string())
-
- observer_room_ids = set(r.room_id for r in observer_rooms)
- observed_room_ids = set(r.room_id for r in observed_rooms)
+ observer_room_ids = yield self.store.get_rooms_for_user(
+ observer_user.to_string()
+ )
+ observed_room_ids = yield self.store.get_rooms_for_user(
+ observed_user.to_string()
+ )
if observer_room_ids & observed_room_ids:
defer.returnValue(True)
@@ -979,14 +1009,18 @@ def should_notify(old_state, new_state):
return False
-def format_user_presence_state(state, now):
+def format_user_presence_state(state, now, include_user_id=True):
"""Convert UserPresenceState to a format that can be sent down to clients
and to other servers.
+
+ The "user_id" is optional so that this function can be used to format presence
+ updates for client /sync responses and for federation /send requests.
"""
content = {
"presence": state.state,
- "user_id": state.user_id,
}
+ if include_user_id:
+ content["user_id"] = state.user_id
if state.last_active_ts:
content["last_active_ago"] = now - state.last_active_ts
if state.status_msg and state.state != PresenceState.OFFLINE:
@@ -1025,7 +1059,6 @@ class PresenceEventSource(object):
# sending down the rare duplicate is not a concern.
with Measure(self.clock, "presence.get_new_events"):
- user_id = user.to_string()
if from_key is not None:
from_key = int(from_key)
@@ -1034,18 +1067,7 @@ class PresenceEventSource(object):
max_token = self.store.get_current_presence_token()
- plist = yield self.store.get_presence_list_accepted(user.localpart)
- users_interested_in = set(row["observed_user_id"] for row in plist)
- users_interested_in.add(user_id) # So that we receive our own presence
-
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
- users_interested_in.update(users_who_share_room)
-
- if explicit_room_id:
- user_ids = yield self.store.get_users_in_room(explicit_room_id)
- users_interested_in.update(user_ids)
+ users_interested_in = yield self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
changed = None
@@ -1073,16 +1095,13 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
- now = self.clock.time_msec()
-
- defer.returnValue(([
- {
- "type": "m.presence",
- "content": format_user_presence_state(s, now),
- }
- for s in updates.values()
- if include_offline or s.state != PresenceState.OFFLINE
- ], max_token))
+ if include_offline:
+ defer.returnValue((updates.values(), max_token))
+ else:
+ defer.returnValue(([
+ s for s in updates.itervalues()
+ if s.state != PresenceState.OFFLINE
+ ], max_token))
def get_current_key(self):
return self.store.get_current_presence_token()
@@ -1090,6 +1109,31 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False)
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _get_interested_in(self, user, explicit_room_id, cache_context):
+ """Returns the set of users that the given user should see presence
+ updates for
+ """
+ user_id = user.to_string()
+ plist = yield self.store.get_presence_list_accepted(
+ user.localpart, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in = set(row["observed_user_id"] for row in plist)
+ users_interested_in.add(user_id) # So that we receive our own presence
+
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in.update(users_who_share_room)
+
+ if explicit_room_id:
+ user_ids = yield self.store.get_users_in_room(
+ explicit_room_id, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in.update(user_ids)
+
+ defer.returnValue(users_interested_in)
+
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as
@@ -1157,14 +1201,17 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# If there have been no syncs for a while (and none ongoing),
# set presence to offline
if user_id not in syncing_user_ids:
- if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
+ # If the user has done something recently but hasn't synced,
+ # don't set them as offline.
+ sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
+ if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
state=PresenceState.OFFLINE,
status_msg=None,
)
changed = True
else:
- # We expect to be poked occaisonally by the other side.
+ # We expect to be poked occasionally by the other side.
# This is to protect against forgetful/buggy servers, so that
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
@@ -1255,3 +1302,66 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
persist_and_notify = True
return new_state, persist_and_notify, federation_ping
+
+
+@defer.inlineCallbacks
+def get_interested_parties(store, states):
+ """Given a list of states return which entities (rooms, users)
+ are interested in the given states.
+
+ Args:
+ states (list(UserPresenceState))
+
+ Returns:
+ 2-tuple: `(room_ids_to_states, users_to_states)`,
+ with each item being a dict of `entity_name` -> `[UserPresenceState]`
+ """
+ room_ids_to_states = {}
+ users_to_states = {}
+ for state in states:
+ room_ids = yield store.get_rooms_for_user(state.user_id)
+ for room_id in room_ids:
+ room_ids_to_states.setdefault(room_id, []).append(state)
+
+ plist = yield store.get_presence_list_observers_accepted(state.user_id)
+ for u in plist:
+ users_to_states.setdefault(u, []).append(state)
+
+ # Always notify self
+ users_to_states.setdefault(state.user_id, []).append(state)
+
+ defer.returnValue((room_ids_to_states, users_to_states))
+
+
+@defer.inlineCallbacks
+def get_interested_remotes(store, states, state_handler):
+ """Given a list of presence states figure out which remote servers
+ should be sent which states.
+
+ All the presence states should be for local users only.
+
+ Args:
+ store (DataStore)
+ states (list(UserPresenceState))
+ state_handler (synapse.state.StateHandler)
+
+ Returns:
+ Deferred list of ([destinations], [UserPresenceState]), where for
+ each row the list of UserPresenceState should be sent to each
+ destination
+ """
+ hosts_and_states = []
+
+ # First we look up the rooms each user is in (as well as any explicit
+ # subscriptions), then for each distinct room we look up the remote
+ # hosts in those rooms.
+ room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
+
+ for room_id, states in room_ids_to_states.iteritems():
+ hosts = yield state_handler.get_current_hosts_in_room(room_id)
+ hosts_and_states.append((hosts, states))
+
+ for user_id, states in users_to_states.iteritems():
+ host = get_domain_from_id(user_id)
+ hosts_and_states.append(([host], states))
+
+ defer.returnValue(hosts_and_states)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 87f74dfb8e..3465a787ab 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,25 +17,87 @@ import logging
from twisted.internet import defer
-import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler):
+ PROFILE_UPDATE_MS = 60 * 1000
+ PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(ProfileHandler, self).__init__(hs)
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query
)
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ )
+
+ @defer.inlineCallbacks
+ def get_profile(self, user_id):
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ try:
+ result = yield self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ defer.returnValue(result)
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get displayname")
+
+ raise
+
+ @defer.inlineCallbacks
+ def get_profile_from_cache(self, user_id):
+ """Get the profile information from our local cache. If the user is
+ ours then the profile information will always be corect. Otherwise,
+ it may be out of date/missing.
+ """
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ profile = yield self.store.get_from_remote_profile_cache(user_id)
+ defer.returnValue(profile or {})
+
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
@@ -52,14 +114,15 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "displayname",
- }
+ },
+ ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:
logger.exception("Failed to get displayname")
raise
- except:
+ except Exception:
logger.exception("Failed to get displayname")
else:
defer.returnValue(result["displayname"])
@@ -81,7 +144,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -99,13 +168,14 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "avatar_url",
- }
+ },
+ ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:
logger.exception("Failed to get avatar_url")
raise
- except:
+ except Exception:
logger.exception("Failed to get avatar_url")
defer.returnValue(result["avatar_url"])
@@ -124,7 +194,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -149,34 +225,71 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response)
@defer.inlineCallbacks
- def _update_join_states(self, requester):
- user = requester.user
- if not self.hs.is_mine(user):
+ def _update_join_states(self, requester, target_user):
+ if not self.hs.is_mine(target_user):
return
- self.ratelimit(requester)
+ yield self.ratelimit(requester)
- joins = yield self.store.get_rooms_for_user(
- user.to_string(),
+ room_ids = yield self.store.get_rooms_for_user(
+ target_user.to_string(),
)
- for j in joins:
- handler = self.hs.get_handlers().room_member_handler
+ for room_id in room_ids:
+ handler = self.hs.get_room_member_handler()
try:
- # Assume the user isn't a guest because we don't let guests set
- # profile or avatar data.
- # XXX why are we recreating `requester` here for each room?
- # what was wrong with the `requester` we were passed?
- requester = synapse.types.create_requester(user)
+ # Assume the target_user isn't a guest,
+ # because we don't let guests set profile or avatar data.
yield handler.update_membership(
requester,
- user,
- j.room_id,
+ target_user,
+ room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
- j.room_id, str(e.message)
+ room_id, str(e.message)
)
+
+ @defer.inlineCallbacks
+ def _update_remote_profile_cache(self):
+ """Called periodically to check profiles of remote users we haven't
+ checked in a while.
+ """
+ entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
+ )
+
+ for user_id, displayname, avatar_url in entries:
+ is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ user_id,
+ )
+ if not is_subscribed:
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+ continue
+
+ try:
+ profile = yield self.federation.make_query(
+ destination=get_domain_from_id(user_id),
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ except Exception:
+ logger.exception("Failed to get avatar_url")
+
+ yield self.store.update_remote_profile_cache(
+ user_id, displayname, avatar_url
+ )
+ continue
+
+ new_name = profile.get("displayname")
+ new_avatar = profile.get("avatar_url")
+
+ # We always hit update to update the last_check timestamp
+ yield self.store.update_remote_profile_cache(
+ user_id, new_name, new_avatar
+ )
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
new file mode 100644
index 0000000000..5142ae153d
--- /dev/null
+++ b/synapse/handlers/read_marker.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseHandler
+
+from twisted.internet import defer
+
+from synapse.util.async import Linearizer
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class ReadMarkerHandler(BaseHandler):
+ def __init__(self, hs):
+ super(ReadMarkerHandler, self).__init__(hs)
+ self.server_name = hs.config.server_name
+ self.store = hs.get_datastore()
+ self.read_marker_linearizer = Linearizer(name="read_marker")
+ self.notifier = hs.get_notifier()
+
+ @defer.inlineCallbacks
+ def received_client_read_marker(self, room_id, user_id, event_id):
+ """Updates the read marker for a given user in a given room if the event ID given
+ is ahead in the stream relative to the current read marker.
+
+ This uses a notifier to indicate that account data should be sent down /sync if
+ the read marker has changed.
+ """
+
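+ # Serialise updates per (room_id, user_id) so concurrent requests for
+ # the same user/room cannot race when comparing and writing the marker.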
+ with (yield self.read_marker_linearizer.queue((room_id, user_id))):
+ existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+ user_id, room_id, "m.fully_read",
+ )
+
+ should_update = True
+
+ if existing_read_marker:
+ # Only update if the new marker is ahead in the stream
+ should_update = yield self.store.is_event_after(
+ event_id,
+ existing_read_marker['event_id']
+ )
+
+ if should_update:
+ content = {
+ "event_id": event_id
+ }
+ max_id = yield self.store.add_account_data_to_room(
+ user_id, room_id, "m.fully_read", content
+ )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 50aa513935..2e0672161c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.util import logcontext
from ._base import BaseHandler
@@ -34,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
@@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
is_new = yield self._handle_new_receipts([receipt])
if is_new:
+ # fire off a process in the background to send the receipt to
+ # remote servers
self._push_remotes([receipt])
@defer.inlineCallbacks
@@ -126,42 +129,46 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
+ @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
- # TODO: Some of this stuff should be coallesced.
- for receipt in receipts:
- room_id = receipt["room_id"]
- receipt_type = receipt["receipt_type"]
- user_id = receipt["user_id"]
- event_ids = receipt["event_ids"]
- data = receipt["data"]
-
- users = yield self.state.get_current_user_in_room(room_id)
- remotedomains = set(get_domain_from_id(u) for u in users)
- remotedomains = remotedomains.copy()
- remotedomains.discard(self.server_name)
-
- logger.debug("Sending receipt to: %r", remotedomains)
-
- for domain in remotedomains:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.receipt",
- content={
- room_id: {
- receipt_type: {
- user_id: {
- "event_ids": event_ids,
- "data": data,
+ try:
+ # TODO: Some of this stuff should be coalesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ remotedomains = set(get_domain_from_id(u) for u in users)
+ remotedomains = remotedomains.copy()
+ remotedomains.discard(self.server_name)
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
}
- }
+ },
},
- },
- key=(room_id, receipt_type, user_id),
- )
+ key=(room_id, receipt_type, user_id),
+ )
+ except Exception:
+ logger.exception("Error pushing receipts to remote servers")
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
@@ -210,10 +217,9 @@ class ReceiptEventSource(object):
else:
from_key = None
- rooms = yield self.store.get_rooms_for_user(user.to_string())
- rooms = [room.room_id for room in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
- rooms,
+ room_ids,
from_key=from_key,
to_key=to_key,
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 03c6a85fc6..f83c6b3cf8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,7 +15,6 @@
"""Contains functions for registering clients."""
import logging
-import urllib
from twisted.internet import defer
@@ -23,8 +22,10 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
+from synapse import types
+from synapse.types import UserID, create_requester, RoomID, RoomAlias
+from synapse.util.async import run_on_reactor, Linearizer
+from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -36,21 +37,33 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self.profile_handler = hs.get_profile_handler()
+ self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
+ self._generate_user_id_linearizer = Linearizer(
+ name="_generate_user_id_linearizer",
+ )
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
- yield run_on_reactor()
+ if types.contains_invalid_mxid_characters(localpart):
+ raise SynapseError(
+ 400,
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
+ Codes.INVALID_USERNAME
+ )
- if urllib.quote(localpart.encode('utf-8')) != localpart:
+ if not localpart:
raise SynapseError(
400,
- "User ID can only contain characters a-z, 0-9, or '_-./'",
+ "User ID cannot be empty",
Codes.INVALID_USERNAME
)
@@ -73,7 +86,7 @@ class RegistrationHandler(BaseHandler):
"A different user ID has already been registered for this session",
)
- yield self.check_user_id_not_appservice_exclusive(user_id)
+ self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
@@ -123,7 +136,7 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor()
password_hash = None
if password:
- password_hash = self.auth_handler().hash(password)
+ password_hash = yield self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -158,6 +171,13 @@ class RegistrationHandler(BaseHandler):
),
admin=admin,
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
else:
# autogen a sequential user ID
attempts = 0
@@ -185,10 +205,17 @@ class RegistrationHandler(BaseHandler):
token = None
attempts += 1
+ # auto-join the user to any rooms we're supposed to dump them into
+ fake_requester = create_requester(user_id)
+ for r in self.hs.config.auto_join_rooms:
+ try:
+ yield self._join_user_to_room(fake_requester, r)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
# rather than there being consistent matrix-wide ones, so we don't.
-
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -246,11 +273,10 @@ class RegistrationHandler(BaseHandler):
"""
Registers email_id as SAML2 Based Auth.
"""
- if urllib.quote(localpart) != localpart:
+ if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
- "User ID must only contain characters which do not"
- " require URL encoding."
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -279,12 +305,12 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating theeepidcred sid %s on id server %s",
+ logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
- except:
+ except Exception:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@@ -293,6 +319,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
+ if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+ raise RegistrationError(
+ 403, "Third party identifier is not allowed"
+ )
+
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
@@ -325,9 +356,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
- self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
- )
+ with (yield self._generate_user_id_linearizer.queue(())):
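+ # Re-check under the linearizer: another request may already have
+ # reseeded while we were waiting for the lock.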
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
id = self._next_generated_user_id
self._next_generated_user_id += 1
@@ -411,13 +444,12 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart,
)
else:
- yield self.store.user_delete_access_tokens(user_id=user_id)
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
- profile_handler = self.hs.get_handlers().profile_handler
- yield profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True,
)
@@ -427,16 +459,59 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
- def guest_access_token_for(self, medium, address, inviter_user_id):
+ def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
- defer.returnValue(access_token)
+ user_info = yield self.auth.get_user_by_access_token(
+ access_token
+ )
- _, access_token = yield self.register(
+ defer.returnValue((user_info["user"].to_string(), access_token))
+
+ user_id, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
- defer.returnValue(access_token)
+
+ defer.returnValue((user_id, access_token))
+
+ @defer.inlineCallbacks
+ def _join_user_to_room(self, requester, room_identifier):
+ room_id = None
+ # remote_room_hosts is only populated when resolving a room alias below
+ remote_room_hosts = None
+ room_member_handler = self.hs.get_room_member_handler()
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = (
+ yield room_member_handler.lookup_room_alias(room_alias)
+ )
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(400, "%s was not legal room ID or room alias" % (
+ room_identifier,
+ ))
+
+ yield room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7e7671c9a2..8df8fcbbad 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,8 +61,14 @@ class RoomCreationHandler(BaseHandler):
},
}
+ def __init__(self, hs):
+ super(RoomCreationHandler, self).__init__(hs)
+
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
+
@defer.inlineCallbacks
- def create_room(self, requester, config):
+ def create_room(self, requester, config, ratelimit=True):
""" Creates a new room.
Args:
@@ -75,14 +82,18 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- self.ratelimit(requester)
+ if not self.spam_checker.user_may_create_room(user_id):
+ raise SynapseError(403, "You are not permitted to create rooms")
+
+ if ratelimit:
+ yield self.ratelimit(requester)
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias.create(
+ room_alias = RoomAlias(
config["room_alias_name"],
self.hs.hostname,
)
@@ -99,7 +110,7 @@ class RoomCreationHandler(BaseHandler):
for i in invite_list:
try:
UserID.from_string(i)
- except:
+ except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
invite_3pid_list = config.get("invite_3pid", [])
@@ -114,7 +125,7 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID.create(
+ gen_room_id = RoomID(
random_string,
self.hs.hostname,
)
@@ -154,24 +165,23 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
- msg_handler = self.hs.get_handlers().message_handler
- room_member_handler = self.hs.get_handlers().room_member_handler
+ room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room(
requester,
room_id,
- msg_handler,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
+ power_level_content_override=config.get("power_level_content_override", {})
)
if "name" in config:
name = config["name"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -184,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -195,12 +205,12 @@ class RoomCreationHandler(BaseHandler):
},
ratelimit=False)
- content = {}
- is_direct = config.get("is_direct", None)
- if is_direct:
- content["is_direct"] = is_direct
-
for invitee in invite_list:
+ content = {}
+ is_direct = config.get("is_direct", None)
+ if is_direct:
+ content["is_direct"] = is_direct
+
yield room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -214,7 +224,7 @@ class RoomCreationHandler(BaseHandler):
id_server = invite_3pid["id_server"]
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
+ yield self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -239,13 +249,13 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
- msg_handler,
room_member_handler,
preset_config,
invite_list,
initial_state,
creation_content,
- room_alias
+ room_alias,
+ power_level_content_override,
):
def create(etype, content, **kwargs):
e = {
@@ -261,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False
@@ -291,7 +301,15 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- if (EventTypes.PowerLevels, '') not in initial_state:
+ # We treat the power levels override specially as this needs to be one
+ # of the first events that get sent into a room.
+ pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+ if pl_content is not None:
+ yield send(
+ etype=EventTypes.PowerLevels,
+ content=pl_content,
+ )
+ else:
power_level_content = {
"users": {
creator_id: 100,
@@ -316,6 +334,8 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list:
power_level_content["users"][invitee] = 100
+ power_level_content.update(power_level_content_override)
+
yield send(
etype=EventTypes.PowerLevels,
content=power_level_content,
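The two hunks above give the new power_level_content_override config key the last word: the generated defaults (the creator, and in the hunk above the listed invitees, at power level 100) are built first, and the override is then applied with a plain dict.update(), i.e. a shallow merge in which any top-level key supplied by the override replaces the generated value wholesale. A minimal, self-contained sketch of that merge behaviour (function and variable names here are illustrative, not the handler's API):

    def merge_power_levels(creator_id, invitees, override):
        # Defaults as generated in the hunk above: the creator (and, where the
        # preset calls for it, each invitee) gets power level 100.
        content = {"users": {creator_id: 100}}
        for invitee in invitees:
            content["users"][invitee] = 100
        # The override is applied last. dict.update() is shallow, so a
        # top-level key such as "users" in the override replaces the whole
        # generated map rather than being merged into it.
        content.update(override)
        return content

    # Overriding "events_default" keeps the generated "users" map intact;
    # overriding "users" would discard it entirely.
    print(merge_power_levels("@alice:hs", ["@bob:hs"], {"events_default": 50}))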
@@ -356,7 +376,7 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, is_guest):
+ def get_event_context(self, user, room_id, event_id, limit):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -375,12 +395,15 @@ class RoomContextHandler(BaseHandler):
now_token = yield self.hs.get_event_sources().get_current_token()
+ users = yield self.store.get_users_in_room(room_id)
+ is_peeking = user.to_string() not in users
+
def filter_evts(events):
return filter_events_for_client(
self.store,
user.to_string(),
events,
- is_peeking=is_guest
+ is_peeking=is_peeking
)
event = yield self.store.get_event(event_id, get_prev_content=True,
@@ -452,12 +475,9 @@ class RoomEventSource(object):
user.to_string()
)
if app_service:
- events, end_key = yield self.store.get_appservice_room_stream(
- service=app_service,
- from_key=from_key,
- to_key=to_key,
- limit=limit,
- )
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
else:
room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 19eebbd43f..5757bb7f8a 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,12 +15,15 @@
from twisted.internet import defer
+from six.moves import range
+
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
@@ -42,8 +45,9 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache(hs)
- self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+ self.response_cache = ResponseCache(hs, "room_list")
+ self.remote_response_cache = ResponseCache(hs, "remote_room_list",
+ timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
@@ -62,23 +66,24 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
+ logger.info(
+ "Getting public room list: limit=%r, since=%r, search=%r, network=%r",
+ limit, since_token, bool(search_filter), network_tuple,
+ )
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
+ logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
key = (limit, since_token, network_tuple)
- result = self.response_cache.get(key)
- if not result:
- result = self.response_cache.set(
- key,
- self._get_public_room_list(
- limit, since_token, network_tuple=network_tuple
- )
- )
- return result
+ return self.response_cache.wrap(
+ key,
+ self._get_public_room_list,
+ limit, since_token, network_tuple=network_tuple,
+ )
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
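Both room-list caches now go through ResponseCache.wrap(key, callback, *args), replacing the manual get/set dance: if a request for the same key is already in flight, later callers are handed the existing result rather than kicking off a second computation. A rough, self-contained sketch of that shape, written with asyncio purely for illustration (the real class is Deferred-based and supports a timeout):

    import asyncio

    class MiniResponseCache:
        """Toy stand-in for the wrap() pattern: concurrent callers sharing a
        key await the same in-flight task instead of duplicating the work."""

        def __init__(self):
            self._pending = {}

        def wrap(self, key, callback, *args, **kwargs):
            task = self._pending.get(key)
            if task is None:
                task = asyncio.ensure_future(callback(*args, **kwargs))
                # Forget the entry once it completes (roughly timeout_ms=0).
                task.add_done_callback(lambda _: self._pending.pop(key, None))
                self._pending[key] = task
            return task

    async def main():
        cache = MiniResponseCache()

        async def slow_lookup(key):
            await asyncio.sleep(0.1)
            return "rooms for %s" % (key,)

        # Two concurrent requests for the same page, one underlying lookup.
        first, second = await asyncio.gather(
            cache.wrap("page-1", slow_lookup, "page-1"),
            cache.wrap("page-1", slow_lookup, "page-1"),
        )
        assert first == second

    asyncio.run(main())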
@@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
rooms_to_order_value = {}
rooms_to_num_joined = {}
- rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
@@ -116,19 +120,26 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_order_for_room(room_id):
- latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
- if not latest_event_ids:
+ # Most of the rooms won't have changed between the since token and
+ # now (especially if the since token is "now"). So, we can ask which
+ # users are currently in the room (that will hit a cache) and then
+ # check if the room has changed since the since token. (We have to
+ # do it in that order to avoid races).
+ # If things have changed then fall back to getting the current state
+ # at the since token.
+ joined_users = yield self.store.get_users_in_room(room_id)
+ if self.store.has_room_changed_since(room_id, stream_token):
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
- rooms_to_latest_event_ids[room_id] = latest_event_ids
- if not latest_event_ids:
- return
+ if not latest_event_ids:
+ return
+
+ joined_users = yield self.state_handler.get_current_user_in_room(
+ room_id, latest_event_ids,
+ )
- joined_users = yield self.state_handler.get_current_user_in_room(
- room_id, latest_event_ids,
- )
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
@@ -138,6 +149,8 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+ logger.info("Getting ordering for %i rooms since %s",
+ len(room_ids), stream_token)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -165,34 +178,43 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
- # Actually generate the entries. _generate_room_entry will append to
- # chunk but will stop if len(chunk) > limit
- chunk = []
- if limit and not search_filter:
+ logger.info("After sorting and filtering, %i rooms remain",
+ len(rooms_to_scan))
+
+ # _append_room_entry_to_chunk will append to chunk but will stop if
+ # len(chunk) > limit
+ #
+ # Normally we will generate enough results on the first iteration here,
+ # but if there is a search filter, _append_room_entry_to_chunk may
+ # filter some results out, in which case we loop again.
+ #
+ # We don't want to scan over the entire range either as that
+ # would potentially waste a lot of work.
+ #
+ # XXX if there is no limit, we may end up DoSing the server with
+ # calls to get_current_state_ids for every single room on the
+ # server. Surely we should cap this somehow?
+ #
+ if limit:
step = limit + 1
- for i in xrange(0, len(rooms_to_scan), step):
- # We iterate here because the vast majority of cases we'll stop
- # at first iteration, but occaisonally _generate_room_entry
- # won't append to the chunk and so we need to loop again.
- # We don't want to scan over the entire range either as that
- # would potentially waste a lot of work.
- yield concurrently_execute(
- lambda r: self._generate_room_entry(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter
- ),
- rooms_to_scan[i:i + step], 10
- )
- if len(chunk) >= limit + 1:
- break
else:
+ # step cannot be zero
+ step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
+
+ chunk = []
+ for i in range(0, len(rooms_to_scan), step):
+ batch = rooms_to_scan[i:i + step]
+ logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
- lambda r: self._generate_room_entry(
+ lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
- rooms_to_scan, 5
+ batch, 5,
)
+ logger.info("Now %i rooms in result", len(chunk))
+ if len(chunk) >= limit + 1:
+ break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
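The rewritten loop scans rooms_to_scan in batches of limit + 1 (or the whole list when no limit is given) and stops as soon as the chunk holds one more entry than the page size, so the caller can tell whether a further page exists. A compact, synchronous sketch of the same control flow, with the concurrency and the room-entry filtering stubbed out (names are illustrative):

    def scan_in_batches(rooms_to_scan, limit, accept):
        if limit:
            step = limit + 1
        else:
            # step cannot be zero
            step = len(rooms_to_scan) if rooms_to_scan else 1

        chunk = []
        for i in range(0, len(rooms_to_scan), step):
            for room_id in rooms_to_scan[i:i + step]:
                if accept(room_id):  # stands in for the search-filter check
                    chunk.append(room_id)
            if limit and len(chunk) >= limit + 1:
                break
        return chunk

    # With a filter that rejects half the rooms, a second batch is needed
    # before the page (limit=3, so 4 entries) is full.
    rooms = ["!r%d" % i for i in range(20)]
    print(scan_in_batches(rooms, 3, lambda r: int(r[2:]) % 2 == 0))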
@@ -256,21 +278,36 @@ class RoomListHandler(BaseHandler):
defer.returnValue(results)
@defer.inlineCallbacks
- def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
- search_filter):
+ def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
+ search_filter):
+ """Generate the entry for a room in the public room list and append it
+ to the `chunk` if it matches the search filter
+ """
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
+ result = yield self.generate_room_entry(room_id, num_joined_users)
+
+ if result and _matches_room_entry(result, search_filter):
+ chunk.append(result)
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def generate_room_entry(self, room_id, num_joined_users, cache_context,
+ with_alias=True, allow_private=False):
+ """Returns the entry for a room
+ """
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
- current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+ current_state_ids = yield self.store.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
event_map = yield self.store.get_events([
- event_id for key, event_id in current_state_ids.items()
+ event_id for key, event_id in current_state_ids.iteritems()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
@@ -291,12 +328,15 @@ class RoomListHandler(BaseHandler):
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
- if join_rule and join_rule != JoinRules.PUBLIC:
+ if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
- aliases = yield self.store.get_aliases_for_room(room_id)
- if aliases:
- result["aliases"] = aliases
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
@@ -334,8 +374,7 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
- if _matches_room_entry(result, search_filter):
- chunk.append(result)
+ defer.returnValue(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
@@ -365,7 +404,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
- repl_layer = self.hs.get_replication_layer()
+ repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
@@ -378,18 +417,14 @@ class RoomListHandler(BaseHandler):
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
- result = self.remote_response_cache.get(key)
- if not result:
- result = self.remote_response_cache.set(
- key,
- repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
- )
- return result
+ return self.remote_response_cache.wrap(
+ key,
+ repl_layer.get_public_rooms,
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
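The refactor above splits the old _generate_room_entry in two: generate_room_entry builds the per-room summary and is cached by room_id (invalidated via on_invalidate when the room's state or aliases change), while _append_room_entry_to_chunk does the per-request work of applying the search filter and the limit. A rough sketch of why that split pays off, using a plain memo dict in place of the @cachedInlineCallbacks descriptor (illustrative only):

    room_state = {"!room:hs": {"name": "Synapse Admins", "num_joined_members": 42}}
    _entry_cache = {}

    def generate_room_entry(room_id):
        # Expensive, room-scoped part: in the handler this reads current state
        # and aliases, and is cached per room until the room changes.
        if room_id not in _entry_cache:
            _entry_cache[room_id] = dict(room_state[room_id], room_id=room_id)
        return _entry_cache[room_id]

    def append_room_entry_to_chunk(room_id, chunk, limit, matches):
        # Cheap, request-scoped part: the limit and the search filter.
        if limit and len(chunk) > limit + 1:
            return
        entry = generate_room_entry(room_id)
        if entry and matches(entry):
            chunk.append(entry)

    chunk = []
    append_room_entry_to_chunk("!room:hs", chunk, 10, lambda e: "Admins" in e["name"])
    print(chunk)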
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index b2806555cf..714583f1d5 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import abc
import logging
from signedjson.key import decode_verify_key_bytes
@@ -29,47 +30,139 @@ from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
+ __metaclass__ = abc.ABCMeta
+
def __init__(self, hs):
- super(RoomMemberHandler, self).__init__(hs)
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.state_handler = hs.get_state_handler()
+ self.config = hs.config
+ self.simple_http_client = hs.get_simple_http_client()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.directory_handler = hs.get_handlers().directory_handler
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.profile_handler = hs.get_profile_handler()
+ self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
+ self.spam_checker = hs.get_spam_checker()
- self.distributor = hs.get_distributor()
- self.distributor.declare("user_joined_room")
- self.distributor.declare("user_left_room")
+ @abc.abstractmethod
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Try and join a room that this server is not in
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers that can be used
+ to join via.
+ room_id (str): Room that we are trying to join
+ user (UserID): User who is trying to join
+ content (dict): A dict that should be used as the content of the
+ join event.
+
+ Returns:
+ Deferred
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Attempt to reject an invite for a room this server is not in. If we
+ fail to do so we locally mark the invite as rejected.
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers to use to try and
+ reject invite
+ room_id (str)
+ target (UserID): The user rejecting the invite
+
+ Returns:
+ Deferred[dict]: A dictionary to be returned to the client, may
+ include event_id etc, or nothing if we locally rejected
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ requester (Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_joined_room(self, target, room_id):
+ """Notifies distributor on master process that the user has joined the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_left_room(self, target, room_id):
+ """Notifies distributor on master process that the user has left the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
- prev_event_ids,
+ prev_events_and_hashes,
txn_id=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
- msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_hander.create_event(
+ requester,
{
"type": EventTypes.Member,
"content": content,
@@ -82,16 +175,18 @@ class RoomMemberHandler(BaseHandler):
},
token_id=requester.access_token_id,
txn_id=txn_id,
- prev_event_ids=prev_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield msg_handler.deduplicate_state_event(event, context)
+ duplicate = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield msg_handler.handle_new_client_event(
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -113,40 +208,16 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target, room_id)
+ yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target, room_id)
+ yield self._user_left_room(target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks
- def remote_join(self, remote_room_hosts, room_id, user, content):
- if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
-
- # We don't do an auth check if we are doing an invite
- # join dance for now, since we're kinda implicitly checking
- # that we are allowed to join when we decide whether or not we
- # need to do the invite/join dance.
- yield self.hs.get_handlers().federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
- )
- yield user_joined_room(self.distributor, user, room_id)
-
- def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
- return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- user_id
- )
-
- @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -192,14 +263,19 @@ class RoomMemberHandler(BaseHandler):
content_specified = bool(content)
if content is None:
content = {}
+ else:
+ # We do a copy here as we potentially change some keys
+ # later on.
+ content = dict(content)
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
+ # if this is a join with a 3pid signature, we may need to turn a 3pid
+ # invite into a normal invite before we can handle the join.
if third_party_signed is not None:
- replication = self.hs.get_replication_layer()
- yield replication.exchange_third_party_invite(
+ yield self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@@ -209,7 +285,41 @@ class RoomMemberHandler(BaseHandler):
if not remote_room_hosts:
remote_room_hosts = []
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ if effective_membership_state not in ("leave", "ban",):
+ is_blocked = yield self.store.is_room_blocked(room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
+ if effective_membership_state == "invite":
+ block_invite = False
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ if self.config.block_non_admin_invites:
+ logger.info(
+ "Blocking invite: user is not admin and non-admin "
+ "invites disabled"
+ )
+ block_invite = True
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(), target.to_string(), room_id,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ block_invite = True
+
+ if block_invite:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ )
+
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ room_id,
+ )
+ latest_event_ids = (
+ event_id for (event_id, _, _) in prev_events_and_hashes
+ )
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
@@ -250,13 +360,13 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content["membership"] = Membership.JOIN
- profile = self.hs.get_handlers().profile_handler
+ profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
@@ -264,15 +374,15 @@ class RoomMemberHandler(BaseHandler):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self.remote_join(
- remote_room_hosts, room_id, target, content
+ ret = yield self._remote_join(
+ requester, remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@@ -286,20 +396,10 @@ class RoomMemberHandler(BaseHandler):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
-
- try:
- ret = yield self.reject_remote_invite(
- target.to_string(), room_id, remote_room_hosts
- )
- defer.returnValue(ret)
- except SynapseError as e:
- logger.warn("Failed to reject invite: %s", e)
-
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
-
- defer.returnValue({})
+ res = yield self._remote_reject_invite(
+ requester, remote_room_hosts, room_id, target,
+ )
+ defer.returnValue(res)
res = yield self._local_membership_update(
requester=requester,
@@ -308,7 +408,7 @@ class RoomMemberHandler(BaseHandler):
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
- prev_event_ids=latest_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
content=content,
)
defer.returnValue(res)
@@ -354,8 +454,9 @@ class RoomMemberHandler(BaseHandler):
else:
requester = synapse.types.create_requester(target_user)
- message_handler = self.hs.get_handlers().message_handler
- prev_event = yield message_handler.deduplicate_state_event(event, context)
+ prev_event = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if prev_event is not None:
return
@@ -367,7 +468,12 @@ class RoomMemberHandler(BaseHandler):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- yield message_handler.handle_new_client_event(
+ if event.membership not in (Membership.LEAVE, Membership.BAN):
+ is_blocked = yield self.store.is_room_blocked(room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -389,12 +495,12 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target_user, room_id)
+ yield self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target_user, room_id)
+ yield self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@@ -428,7 +534,7 @@ class RoomMemberHandler(BaseHandler):
Raises:
SynapseError if room alias could not be found.
"""
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
@@ -440,7 +546,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
- def get_inviter(self, user_id, room_id):
+ def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
@@ -459,6 +565,16 @@ class RoomMemberHandler(BaseHandler):
requester,
txn_id
):
+ if self.config.block_non_admin_invites:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
invitee = yield self._lookup_3pid(
id_server, medium, address
)
@@ -496,7 +612,7 @@ class RoomMemberHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.hs.get_simple_http_client().get_json(
+ data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
@@ -507,7 +623,7 @@ class RoomMemberHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- self.verify_any_signature(data, id_server)
+ yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
@@ -515,11 +631,11 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def verify_any_signature(self, data, server_hostname):
+ def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.hs.get_simple_http_client().get_json(
+ key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
@@ -544,7 +660,7 @@ class RoomMemberHandler(BaseHandler):
user,
txn_id
):
- room_state = yield self.hs.get_state_handler().get_current_state(room_id)
+ room_state = yield self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -575,6 +691,7 @@ class RoomMemberHandler(BaseHandler):
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
+ requester=requester,
id_server=id_server,
medium=medium,
address=address,
@@ -589,8 +706,7 @@ class RoomMemberHandler(BaseHandler):
)
)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -612,6 +728,7 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
+ requester,
id_server,
medium,
address,
@@ -628,6 +745,7 @@ class RoomMemberHandler(BaseHandler):
Asks an identity server for a third party invite.
Args:
+ requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
@@ -669,24 +787,20 @@ class RoomMemberHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
- if self.hs.config.invite_3pid_guest:
- registration_handler = self.hs.get_handlers().registration_handler
- guest_access_token = yield registration_handler.guest_access_token_for(
+ if self.config.invite_3pid_guest:
+ guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+ requester=requester,
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
- guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
- guest_access_token
- )
-
invite_config.update({
"guest_access_token": guest_access_token,
- "guest_user_id": guest_user_info["user"].to_string(),
+ "guest_user_id": guest_user_id,
})
- data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
+ data = yield self.simple_http_client.post_urlencoded_get_json(
is_url,
invite_config
)
@@ -709,25 +823,6 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
- def forget(self, user, room_id):
- user_id = user.to_string()
-
- member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
- )
- membership = member.membership if member else None
-
- if membership is not None and membership != Membership.LEAVE:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
-
- if membership:
- yield self.store.forget(user_id, room_id)
-
- @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
@@ -735,10 +830,11 @@ class RoomMemberHandler(BaseHandler):
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
- for (etype, state_key), event_id in current_state_ids.items():
+ for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
+ event_id = current_state_ids[(etype, state_key)]
event = yield self.store.get_event(event_id, allow_none=True)
if not event:
continue
@@ -747,3 +843,102 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(True)
defer.returnValue(False)
+
+
+class RoomMemberMasterHandler(RoomMemberHandler):
+ def __init__(self, hs):
+ super(RoomMemberMasterHandler, self).__init__(hs)
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("user_joined_room")
+ self.distributor.declare("user_left_room")
+
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ # filter ourselves out of remote_room_hosts: do_invite_join ignores it
+ # and if it is the only entry we'd like to return a 404 rather than a
+ # 500.
+
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ # We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+ yield self.federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user.to_string(),
+ content,
+ )
+ yield self._user_joined_room(user, room_id)
+
+ @defer.inlineCallbacks
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ fed_handler = self.federation_handler
+ try:
+ ret = yield fed_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ target.to_string(),
+ )
+ defer.returnValue(ret)
+ except Exception as e:
+ # if we were unable to reject the invite, just mark
+ # it as rejected on our end and plough ahead.
+ #
+ # The 'except' clause is very broad, but we need to
+ # capture everything from DNS failures upwards
+ #
+ logger.warn("Failed to reject invite: %s", e)
+
+ yield self.store.locally_reject_invite(
+ target.to_string(), room_id
+ )
+ defer.returnValue({})
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ rg = self.registration_handler
+ return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return user_joined_room(self.distributor, target, room_id)
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return user_left_room(self.distributor, target, room_id)
+
+ @defer.inlineCallbacks
+ def forget(self, user, room_id):
+ user_id = user.to_string()
+
+ member = yield self.state_handler.get_current_state(
+ room_id=room_id,
+ event_type=EventTypes.Member,
+ state_key=user_id
+ )
+ membership = member.membership if member else None
+
+ if membership is not None and membership not in [
+ Membership.LEAVE, Membership.BAN
+ ]:
+ raise SynapseError(400, "User %s in room %s" % (
+ user_id, room_id
+ ))
+
+ if membership:
+ yield self.store.forget(user_id, room_id)
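The net effect of the room_member.py changes is that RoomMemberHandler becomes an abstract base: the shared membership logic stays in the base class, while the operations that must behave differently on the master process and on workers (_remote_join, _remote_reject_invite, get_or_register_3pid_guest, _user_joined_room, _user_left_room) are declared abstract and filled in by RoomMemberMasterHandler above and RoomMemberWorkerHandler below. A stripped-down sketch of the pattern with toy classes, not the real handlers:

    import abc

    class MemberHandlerBase(object):
        # the diff uses the Python 2 spelling; on Python 3 this attribute is
        # inert and one would write `metaclass=abc.ABCMeta` instead
        __metaclass__ = abc.ABCMeta

        def update_membership(self, user_id, room_id):
            # Shared logic lives in the base class and only calls the hooks.
            if room_id.endswith(":local"):
                return "local join of %s" % (user_id,)
            return self._remote_join(user_id, room_id)

        @abc.abstractmethod
        def _remote_join(self, user_id, room_id):
            raise NotImplementedError()

    class MasterHandler(MemberHandlerBase):
        def _remote_join(self, user_id, room_id):
            return "federation join of %s into %s" % (user_id, room_id)

    class WorkerHandler(MemberHandlerBase):
        def _remote_join(self, user_id, room_id):
            # a worker forwards the request to the master over replication HTTP
            return "replicated join of %s into %s" % (user_id, room_id)

    print(MasterHandler().update_membership("@alice:hs", "!r:remote"))
    print(WorkerHandler().update_membership("@alice:hs", "!r:remote"))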
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
new file mode 100644
index 0000000000..493aec1e48
--- /dev/null
+++ b/synapse/handlers/room_member_worker.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.room_member import RoomMemberHandler
+from synapse.replication.http.membership import (
+ remote_join, remote_reject_invite, get_or_register_3pid_guest,
+ notify_user_membership_change,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberWorkerHandler(RoomMemberHandler):
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ ret = yield remote_join(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=user.to_string(),
+ content=content,
+ )
+
+ yield self._user_joined_room(user, room_id)
+
+ defer.returnValue(ret)
+
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ return remote_reject_invite(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=target.to_string(),
+ )
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="joined",
+ )
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="left",
+ )
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ return get_or_register_3pid_guest(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ medium=medium,
+ address=address,
+ inviter_user_id=inviter_user_id,
+ )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index df75d70fac..9772ed1a0e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -61,7 +61,7 @@ class SearchHandler(BaseHandler):
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
- except:
+ except Exception:
raise SynapseError(400, "Invalid batch")
try:
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
new file mode 100644
index 0000000000..e057ae54c9
--- /dev/null
+++ b/synapse/handlers/set_password.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError, SynapseError
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SetPasswordHandler(BaseHandler):
+ """Handler which deals with changing user account passwords"""
+ def __init__(self, hs):
+ super(SetPasswordHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def set_password(self, user_id, newpassword, requester=None):
+ password_hash = yield self._auth_handler.hash(newpassword)
+
+ except_device_id = requester.device_id if requester else None
+ except_access_token_id = requester.access_token_id if requester else None
+
+ try:
+ yield self.store.user_set_password_hash(user_id, password_hash)
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+ raise e
+
+ # we want to log out all of the user's other sessions. First delete
+ # all his other devices.
+ yield self._device_handler.delete_all_devices_for_user(
+ user_id, except_device_id=except_device_id,
+ )
+
+ # and now delete any access tokens which weren't associated with
+ # devices (or were associated with this device).
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id,
+ )
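The new SetPasswordHandler is deliberate about ordering: hash the new password, store it, then delete every other device and every access token not tied to the requester's current session, so a password change logs out all other sessions without dropping the one that made the request. A small sketch of the except_device_id semantics against a fake in-memory session map (illustrative only):

    def surviving_sessions(sessions, requester_device_id):
        """sessions maps device_id -> access_token; everything except the
        requester's own device (and its token) is revoked."""
        return {
            device_id: token
            for device_id, token in sessions.items()
            if device_id == requester_device_id
        }

    sessions = {"LAPTOP": "tok1", "PHONE": "tok2", "TABLET": "tok3"}
    print(surviving_sessions(sessions, "LAPTOP"))  # only the requesting session survives
    print(surviving_sessions(sessions, None))      # admin reset: everything is logged out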
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d7dcd1ce5b..b52e4c2aff 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
+from synapse.types import RoomStreamToken
from twisted.internet import defer
@@ -51,6 +52,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
@@ -75,6 +77,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+ __bool__ = __nonzero__ # python3
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
@@ -94,6 +97,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
or self.state
or self.account_data
)
+ __bool__ = __nonzero__ # python3
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@@ -105,6 +109,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+ __bool__ = __nonzero__ # python3
+
+
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+ "join",
+ "invite",
+ "leave",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.join or self.invite or self.leave)
+ __bool__ = __nonzero__ # python3
+
+
+class DeviceLists(collections.namedtuple("DeviceLists", [
+ "changed", # list of user_ids whose devices may have changed
+ "left", # list of user_ids whose devices we no longer track
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.changed or self.left)
+ __bool__ = __nonzero__ # python3
class SyncResult(collections.namedtuple("SyncResult", [
@@ -116,6 +144,9 @@ class SyncResult(collections.namedtuple("SyncResult", [
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
+ "device_one_time_keys_count", # Dict of algorithm to count for one time keys
+ # for this device
+ "groups",
])):
__slots__ = []
@@ -131,8 +162,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or
self.account_data or
self.to_device or
- self.device_lists
+ self.device_lists or
+ self.groups
)
+ __bool__ = __nonzero__ # python3
class SyncHandler(object):
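Each sync result namedtuple above now aliases __bool__ = __nonzero__ so that the "is there anything to send to the client?" truthiness check works on both Python 2 (__nonzero__) and Python 3 (__bool__). A self-contained example of the pattern:

    import collections

    class TimelineBatch(collections.namedtuple("TimelineBatch", ["events", "prev_batch"])):
        __slots__ = []

        def __nonzero__(self):
            # Empty batches are falsy, so callers can skip empty rooms.
            return bool(self.events)
        __bool__ = __nonzero__  # python3

    print(bool(TimelineBatch(events=[], prev_batch="t1")))       # False
    print(bool(TimelineBatch(events=["$ev"], prev_batch="t1")))  # True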
@@ -143,7 +176,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(hs)
+ self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
@@ -154,15 +187,11 @@ class SyncHandler(object):
Returns:
A Deferred SyncResult.
"""
- result = self.response_cache.get(sync_config.request_key)
- if not result:
- result = self.response_cache.set(
- sync_config.request_key,
- self._wait_for_sync_for_user(
- sync_config, since_token, timeout, full_state
- )
- )
- return result
+ return self.response_cache.wrap(
+ sync_config.request_key,
+ self._wait_for_sync_for_user,
+ sync_config, since_token, timeout, full_state,
+ )
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -209,10 +238,10 @@ class SyncHandler(object):
defer.returnValue(rules)
@defer.inlineCallbacks
- def ephemeral_by_room(self, sync_config, now_token, since_token=None):
+ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
- sync_config (SyncConfig): The flags, filters and user for the sync.
+ sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
last synced.
@@ -222,11 +251,12 @@ class SyncHandler(object):
typing events for that room.
"""
+ sync_config = sync_result_builder.sync_config
+
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
- rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
- room_ids = [room.room_id for room in rooms]
+ room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
@@ -288,10 +318,20 @@ class SyncHandler(object):
if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+ # We check if there are any state events, if there are then we pass
+ # all current state events to the filter_events function. This is to
+ # ensure that we always include current state in the timeline
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(current_state_ids.itervalues())
+
recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
recents,
+ always_include_ids=current_state_ids,
)
else:
recents = []
@@ -323,10 +363,20 @@ class SyncHandler(object):
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
+
+ # We check if there are any state events, if there are then we pass
+ # all current state events to the filter_events function. This is to
+ # ensure that we always include current state in the timeline
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in loaded_recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(current_state_ids.itervalues())
+
loaded_recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
loaded_recents,
+ always_include_ids=current_state_ids,
)
loaded_recents.extend(recents)
recents = loaded_recents
@@ -520,10 +570,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+ else:
+ joined_room_ids = yield self.get_rooms_for_user_at(
+ user_id, now_token.room_stream_id,
+ )
+
sync_result_builder = SyncResultBuilder(
sync_config, full_state,
since_token=since_token,
now_token=now_token,
+ joined_room_ids=joined_room_ids,
)
account_data_by_room = yield self._generate_sync_entry_for_account_data(
@@ -533,7 +595,8 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users = res
+ newly_joined_rooms, newly_joined_users, _, _ = res
+ _, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
since_token is None and
@@ -547,9 +610,22 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
- sync_result_builder
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_users=newly_joined_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
)
+ device_id = sync_config.device_id
+ one_time_key_counts = {}
+ if device_id:
+ one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+
+ yield self._generate_sync_entry_for_groups(sync_result_builder)
+
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -558,31 +634,103 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
+ groups=sync_result_builder.groups,
+ device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
+ @measure_func("_generate_sync_entry_for_groups")
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_groups(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+
+ if since_token and since_token.groups_key:
+ results = yield self.store.get_groups_changes_for_user(
+ user_id, since_token.groups_key, now_token.groups_key,
+ )
+ else:
+ results = yield self.store.get_all_groups_for_user(
+ user_id, now_token.groups_key,
+ )
+
+ invited = {}
+ joined = {}
+ left = {}
+ for result in results:
+ membership = result["membership"]
+ group_id = result["group_id"]
+ gtype = result["type"]
+ content = result["content"]
+
+ if membership == "join":
+ if gtype == "membership":
+ # TODO: Add profile
+ content.pop("membership", None)
+ joined[group_id] = content["content"]
+ else:
+ joined.setdefault(group_id, {})[gtype] = content
+ elif membership == "invite":
+ if gtype == "membership":
+ content.pop("membership", None)
+ invited[group_id] = content["content"]
+ else:
+ if gtype == "membership":
+ left[group_id] = content["content"]
+
+ sync_result_builder.groups = GroupsSyncResult(
+ join=joined,
+ invite=invited,
+ leave=left,
+ )
+
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder):
+ def _generate_sync_entry_for_device_list(self, sync_result_builder,
+ newly_joined_rooms, newly_joined_users,
+ newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = set(r.room_id for r in rooms)
-
- user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
- for other_user_id in changed:
- other_rooms = yield self.store.get_rooms_for_user(other_user_id)
- if room_ids.intersection(e.room_id for e in other_rooms):
- user_ids_changed.add(other_user_id)
- defer.returnValue(user_ids_changed)
+ # TODO: Be more clever than this, i.e. remove users who we already
+ # share a room with?
+ for room_id in newly_joined_rooms:
+ joined_users = yield self.state.get_current_user_in_room(room_id)
+ newly_joined_users.update(joined_users)
+
+ for room_id in newly_left_rooms:
+ left_users = yield self.state.get_current_user_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # TODO: Check that these users are actually new, i.e. either they
+ # weren't in the previous sync *or* they left and rejoined.
+ changed.update(newly_joined_users)
+
+ if not changed and not newly_left_users:
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=newly_left_users,
+ ))
+
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
+
+ defer.returnValue(DeviceLists(
+ changed=users_who_share_room & changed,
+ left=set(newly_left_users) - users_who_share_room,
+ ))
else:
- defer.returnValue([])
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=[],
+ ))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
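The device-list entry now boils down to set algebra: users whose devices changed (plus everyone in newly joined rooms) are reported only if we still share a room with them, while users from newly left rooms are reported as "left" only if we no longer share any room with them. A tiny sketch of those two expressions (function name is illustrative):

    def device_lists(changed, newly_joined_users, newly_left_users, users_who_share_room):
        changed = set(changed) | set(newly_joined_users)
        return {
            # only tell the client about users it can still see...
            "changed": users_who_share_room & changed,
            # ...and about users it can no longer see at all
            "left": set(newly_left_users) - users_who_share_room,
        }

    print(device_lists(
        changed={"@bob:hs"},
        newly_joined_users={"@carol:hs"},
        newly_left_users={"@dave:hs"},
        users_who_share_room={"@bob:hs", "@carol:hs"},
    ))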
@@ -609,14 +757,14 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
- logger.info("Deleted %d to-device messages up to %d",
- deleted, since_stream_id)
+ logger.debug("Deleted %d to-device messages up to %d",
+ deleted, since_stream_id)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
- logger.info(
+ logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
len(messages), since_stream_id, stream_id, now_token.to_device_key
)
@@ -721,14 +869,14 @@ class SyncHandler(object):
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
- states = yield self.presence_handler.get_states(
- extra_users_ids,
- as_event=True,
- )
- presence.extend(states)
+ if extra_users_ids:
+ states = yield self.presence_handler.get_states(
+ extra_users_ids,
+ )
+ presence.extend(states)
- # Deduplicate the presence entries so that there's at most one per user
- presence = {p["content"]["user_id"]: p for p in presence}.values()
+ # Deduplicate the presence entries so that there's at most one per user
+ presence = {p.user_id: p for p in presence}.values()
presence = sync_config.filter_collection.filter_presence(
presence
@@ -746,8 +894,8 @@ class SyncHandler(object):
account_data_by_room(dict): Dictionary of per room account data
Returns:
- Deferred(tuple): Returns a 2-tuple of
- `(newly_joined_rooms, newly_joined_users)`
+ Deferred(tuple): Returns a 4-tuple of
+ `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -759,12 +907,27 @@ class SyncHandler(object):
ephemeral_by_room = {}
else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_result_builder.sync_config,
+ sync_result_builder,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
)
sync_result_builder.now_token = now_token
+ # We check up front whether anything has changed; if it hasn't, then
+ # there is no point in going further.
+ since_token = sync_result_builder.since_token
+ if not sync_result_builder.full_state:
+ if since_token and not ephemeral_by_room and not account_data_by_room:
+ have_changed = yield self._have_rooms_changed(sync_result_builder)
+ if not have_changed:
+ tags_by_room = yield self.store.get_updated_tags(
+ user_id,
+ since_token.account_data_key,
+ )
+ if not tags_by_room:
+ logger.debug("no-oping sync")
+ defer.returnValue(([], [], [], []))
+
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id,
)
@@ -774,17 +937,17 @@ class SyncHandler(object):
else:
ignored_users = frozenset()
- if sync_result_builder.since_token:
+ if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
- room_entries, invited, newly_joined_rooms = res
+ room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
- user_id,
- sync_result_builder.since_token.account_data_key,
+ user_id, since_token.account_data_key,
)
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res
+ newly_left_rooms = []
tags_by_room = yield self.store.get_tags_for_user(user_id)
@@ -805,17 +968,55 @@ class SyncHandler(object):
# Now we want to get any newly joined users
newly_joined_users = set()
- if sync_result_builder.since_token:
+ newly_left_users = set()
+ if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.values()
+ joined_sync.timeline.events, joined_sync.state.itervalues()
)
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
newly_joined_users.add(event.state_key)
+ else:
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership == Membership.JOIN:
+ newly_left_users.add(event.state_key)
+
+ newly_left_users -= newly_joined_users
+
+ defer.returnValue((
+ newly_joined_rooms,
+ newly_joined_users,
+ newly_left_rooms,
+ newly_left_users,
+ ))
+
+ @defer.inlineCallbacks
+ def _have_rooms_changed(self, sync_result_builder):
+ """Returns whether there may be any new events that should be sent down
+ the sync. Returns True if there are.
+ """
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
- defer.returnValue((newly_joined_rooms, newly_joined_users))
+ assert since_token
+
+ # Get a list of membership change events that have happened.
+ rooms_changed = yield self.store.get_membership_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key
+ )
+
+ if rooms_changed:
+ defer.returnValue(True)
+
+ stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+ for room_id in sync_result_builder.joined_room_ids:
+ if self.store.has_room_changed_since(room_id, stream_id):
+ defer.returnValue(True)
+ defer.returnValue(False)
@defer.inlineCallbacks
def _get_rooms_changed(self, sync_result_builder, ignored_users):
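The new _have_rooms_changed above is a cheap pre-check before the full room machinery runs: if there were no membership changes and none of the currently joined rooms has had an event since the `since` token, the whole sync can short-circuit to a no-op. Sketched as a pure function over a stubbed store (names are illustrative):

    def have_rooms_changed(membership_changes, joined_room_ids, has_room_changed_since, stream_id):
        if membership_changes:
            return True
        # any() stops at the first room with new events, like the loop above
        return any(has_room_changed_since(room_id, stream_id) for room_id in joined_room_ids)

    changed_rooms = {"!busy:hs"}
    print(have_rooms_changed(
        membership_changes=[],
        joined_room_ids=["!quiet:hs", "!busy:hs"],
        has_room_changed_since=lambda room_id, _stream_id: room_id in changed_rooms,
        stream_id=1234,
    ))  # True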
@@ -836,14 +1037,6 @@ class SyncHandler(object):
assert since_token
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- rooms = yield self.store.get_rooms_for_user(user_id)
- joined_room_ids = set(r.room_id for r in rooms)
-
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
@@ -854,16 +1047,29 @@ class SyncHandler(object):
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
+ newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in mem_change_events_by_room_id.items():
+ for room_id, events in mem_change_events_by_room_id.iteritems():
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
- if room_id in joined_room_ids or has_join:
+
+ old_state_ids = None
+ if room_id in sync_result_builder.joined_room_ids and non_joins:
+ # Always include if the user (re)joined the room, especially
+ # important so that device list changes are calculated correctly.
+ # If there are non-join member events but we are still in the room,
+ # then the user must have left and rejoined
+ newly_joined_rooms.append(room_id)
+
+ # User is in the room so we don't need to do the invite/leave checks
+ continue
+
+ if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
@@ -874,12 +1080,33 @@ class SyncHandler(object):
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
- if room_id in joined_room_ids:
- continue
+ # If user is in the room then we don't need to do the invite/leave checks
+ if room_id in sync_result_builder.joined_room_ids:
+ continue
if not non_joins:
continue
+ # Check if we have left the room. This can either be because we were
+ # joined before, *or* because we joined and then left since the last sync.
+ if events[-1].membership != Membership.JOIN:
+ if has_join:
+ newly_left_rooms.append(room_id)
+ else:
+ if not old_state_ids:
+ old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_mem_ev_id = old_state_ids.get(
+ (EventTypes.Member, user_id),
+ None,
+ )
+ old_mem_ev = None
+ if old_mem_ev_id:
+ old_mem_ev = yield self.store.get_event(
+ old_mem_ev_id, allow_none=True
+ )
+ if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
+ newly_left_rooms.append(room_id)
+
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
@@ -921,7 +1148,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
- room_ids=joined_room_ids,
+ room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
@@ -929,7 +1156,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case
# there are non-room events that we need to notify about.
- for room_id in joined_room_ids:
+ for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
@@ -957,7 +1184,7 @@ class SyncHandler(object):
upto_token=since_token,
))
- defer.returnValue((room_entries, invited, newly_joined_rooms))
+ defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1137,6 +1364,54 @@ class SyncHandler(object):
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
+ @defer.inlineCallbacks
+ def get_rooms_for_user_at(self, user_id, stream_ordering):
+ """Get set of joined rooms for a user at the given stream ordering.
+
+ The stream ordering *must* be recent; if it is more than a month old
+ this may throw an exception. (This function is called with the
+ current token, which should be perfectly fine.)
+
+ Args:
+ user_id (str)
+ stream_ordering (int)
+
+ Returns:
+ Deferred[frozenset[str]]: Set of room_ids the user is in at given
+ stream_ordering.
+ """
+ joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
+ user_id,
+ )
+
+ joined_room_ids = set()
+
+ # We need to check that the stream ordering of the join for each room
+ # is before the stream_ordering asked for. This might not be the case
+ # if the user joins a room between us getting the current token and
+ # calling `get_rooms_for_user_with_stream_ordering`.
+ # If the membership's stream ordering is after the given stream
+ # ordering, we need to go and work out if the user was in the room
+ # before.
+ for room_id, membership_stream_ordering in joined_rooms:
+ if membership_stream_ordering <= stream_ordering:
+ joined_room_ids.add(room_id)
+ continue
+
+ logger.info("User joined room after current token: %s", room_id)
+
+ extrems = yield self.store.get_forward_extremeties_for_room(
+ room_id, stream_ordering,
+ )
+ users_in_room = yield self.state.get_current_user_in_room(
+ room_id, extrems,
+ )
+ if user_id in users_in_room:
+ joined_room_ids.add(room_id)
+
+ joined_room_ids = frozenset(joined_room_ids)
+ defer.returnValue(joined_room_ids)
+
def _action_has_highlight(actions):
for action in actions:
@@ -1186,7 +1461,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user"
- def __init__(self, sync_config, full_state, since_token, now_token):
+ def __init__(self, sync_config, full_state, since_token, now_token,
+ joined_room_ids):
"""
Args:
sync_config(SyncConfig)
@@ -1198,6 +1474,7 @@ class SyncResultBuilder(object):
self.full_state = full_state
self.since_token = since_token
self.now_token = now_token
+ self.joined_room_ids = joined_room_ids
self.presence = []
self.account_data = []
@@ -1205,6 +1482,8 @@ class SyncResultBuilder(object):
self.invited = []
self.archived = []
self.device = []
+ self.groups = None
+ self.to_device = []
class RoomSyncResultBuilder(object):
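
The new get_rooms_for_user_at above only trusts a membership row whose stream ordering is at or before the requested token; joins recorded after that point fall back to a state check against the room's forward extremities. A minimal standalone sketch of just that filtering step (plain Python; the function and callback names here are illustrative, not Synapse APIs):

def rooms_at_stream_ordering(joined_rooms, stream_ordering, was_joined_before):
    """joined_rooms: iterable of (room_id, membership_stream_ordering) pairs.

    was_joined_before stands in for the fallback check against the room's
    forward extremities for joins newer than the requested ordering.
    """
    joined_room_ids = set()
    for room_id, membership_stream_ordering in joined_rooms:
        if membership_stream_ordering <= stream_ordering:
            # The recorded join is not newer than the token, so it counts.
            joined_room_ids.add(room_id)
        elif was_joined_before(room_id, stream_ordering):
            # Join happened after the token; only count the room if the
            # user was already a member at the requested point.
            joined_room_ids.add(room_id)
    return frozenset(joined_room_ids)

# A join recorded at ordering 120 is ignored for a token at 100 unless the
# fallback check says the user was already in the room.
assert rooms_at_stream_ordering(
    [("!a:hs", 90), ("!b:hs", 120)], 100, lambda room_id, so: False,
) == frozenset(["!a:hs"])
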
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0eea7f8f9c..5d9736e88f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
@@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id
import logging
from collections import namedtuple
-import ujson as json
logger = logging.getLogger(__name__)
@@ -57,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
+ hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -90,7 +89,7 @@ class TypingHandler(object):
until = self._member_typing_until.get(member, None)
if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id)
- preserve_fn(self._stopped_typing)(member)
+ self._stopped_typing(member)
continue
# Check if we need to resend a keep alive over federation for this
@@ -98,7 +97,8 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- preserve_fn(self._push_remote)(
+ run_in_background(
+ self._push_remote,
member=member,
typing=True
)
@@ -148,7 +148,7 @@ class TypingHandler(object):
# No point sending another notification
defer.returnValue(None)
- yield self._push_update(
+ self._push_update(
member=member,
typing=True,
)
@@ -172,7 +172,7 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id)
- yield self._stopped_typing(member)
+ self._stopped_typing(member)
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -181,7 +181,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member)
- @defer.inlineCallbacks
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
@@ -190,16 +189,15 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
- yield self._push_update(
+ self._push_update(
member=member,
typing=False,
)
- @defer.inlineCallbacks
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- yield self._push_remote(member, typing)
+ run_in_background(self._push_remote, member, typing)
self._push_update_local(
member=member,
@@ -208,28 +206,31 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
- users = yield self.state.get_current_user_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
+ try:
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
- )
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- for domain in set(get_domain_from_id(u) for u in users):
- if domain != self.server_name:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
@@ -288,11 +289,13 @@ class TypingHandler(object):
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
- typing_bytes = json.dumps(list(typing), ensure_ascii=False)
- rows.append((serial, room_id, typing_bytes))
+ rows.append((serial, room_id, list(typing)))
rows.sort()
return rows
+ def get_current_token(self):
+ return self._latest_room_serial
+
class TypingNotificationEventSource(object):
def __init__(self, hs):
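
Several of the typing-handler changes above switch from yielding on _push_remote to firing it off with run_in_background, so a remote federation failure can no longer break the local typing update; the trade-off is that nothing awaits the result, which is why the body now catches and logs every exception itself. A rough, thread-based illustration of that fire-and-forget shape (the real helper schedules a Deferred on the reactor, not a thread):

import logging
import threading

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("typing_sketch")

def fire_and_forget(f, *args, **kwargs):
    # Illustrative stand-in for run_in_background: the caller never waits on
    # the result, so the target must deal with its own errors.
    t = threading.Thread(target=f, args=args, kwargs=kwargs)
    t.daemon = True
    t.start()
    return t

def push_remote(member, typing):
    try:
        raise IOError("remote unreachable")  # simulate a federation failure
    except Exception:
        # Mirrors the new _push_remote: log and swallow, because no caller
        # will ever see an exception raised from here.
        logger.exception("Error pushing typing notif to remotes")

fire_and_forget(push_remote, member=("!room:hs", "@user:hs"), typing=True).join()
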
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
new file mode 100644
index 0000000000..714f0195c8
--- /dev/null
+++ b/synapse/handlers/user_directory.py
@@ -0,0 +1,681 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.storage.roommember import ProfileInfo
+from synapse.util.metrics import Measure
+from synapse.util.async import sleep
+from synapse.types import get_localpart_from_id
+
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryHandler(object):
+ """Handles querying of and keeping updated the user_directory.
+
+ N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
+
+ The user directory is filled with users who this server can see are joined to a
+ world_readable or publicly joinable room. We keep a database table up to date
+ by streaming changes of the current state and recalculating whether users should
+ be in the directory or not when necessary.
+
+ For each user in the directory we also store the room_id of a public room that the
+ user is joined to. This allows us to ignore history_visibility and join_rules changes
+ for that user in all other public rooms, as we know they'll still be in at least
+ one public room.
+ """
+
+ INITIAL_ROOM_SLEEP_MS = 50
+ INITIAL_ROOM_SLEEP_COUNT = 100
+ INITIAL_ROOM_BATCH_SIZE = 100
+ INITIAL_USER_SLEEP_MS = 10
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.update_user_directory = hs.config.update_user_directory
+ self.search_all_users = hs.config.user_directory_search_all_users
+
+ # When we start up for the first time we need to populate the user_directory.
+ # This is a set of user_ids we've already inserted.
+ self.initially_handled_users = set()
+ self.initially_handled_users_in_public = set()
+
+ self.initially_handled_users_share = set()
+ self.initially_handled_users_share_private_room = set()
+
+ # The current position in the current_state_delta stream
+ self.pos = None
+
+ # Guard to ensure we only process deltas one at a time
+ self._is_processing = False
+
+ if self.update_user_directory:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # We kick this off so that we don't have to wait for a change before
+ # we start populating the user directory
+ self.clock.call_later(0, self.notify_new_event)
+
+ def search_users(self, user_id, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ return self.store.search_user_dir(user_id, search_term, limit)
+
+ @defer.inlineCallbacks
+ def notify_new_event(self):
+ """Called when there may be more deltas to process
+ """
+ if not self.update_user_directory:
+ return
+
+ if self._is_processing:
+ return
+
+ self._is_processing = True
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
+ @defer.inlineCallbacks
+ def handle_local_profile_change(self, user_id, profile):
+ """Called to update index of our local user profiles when they change
+ irrespective of any rooms the user may be in.
+ """
+ yield self.store.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url, None,
+ )
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # If self.pos is None it means we haven't fetched it from the DB yet
+ if self.pos is None:
+ self.pos = yield self.store.get_user_directory_stream_pos()
+
+ # If still None then we need to do the initial fill of directory
+ if self.pos is None:
+ yield self._do_initial_spam()
+ self.pos = yield self.store.get_user_directory_stream_pos()
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "user_dir_delta"):
+ deltas = yield self.store.get_current_state_deltas(self.pos)
+ if not deltas:
+ return
+
+ logger.info("Handling %d state deltas", len(deltas))
+ yield self._handle_deltas(deltas)
+
+ self.pos = deltas[-1]["stream_id"]
+ yield self.store.update_user_directory_stream_pos(self.pos)
+
+ @defer.inlineCallbacks
+ def _do_initial_spam(self):
+ """Populates the user_directory from the current state of the DB, used
+ when synapse first starts with user_directory support
+ """
+ new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
+
+ # Delete any existing entries just in case there are any
+ yield self.store.delete_all_from_user_dir()
+
+ # We process by going through each existing room at a time.
+ room_ids = yield self.store.get_all_rooms()
+
+ logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
+ num_processed_rooms = 0
+
+ for room_id in room_ids:
+ logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
+ yield self._handle_initial_room(room_id)
+ num_processed_rooms += 1
+ yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+
+ logger.info("Processed all rooms.")
+
+ if self.search_all_users:
+ num_processed_users = 0
+ user_ids = yield self.store.get_all_local_users()
+ logger.info("Doing initial update of user directory. %d users", len(user_ids))
+ for user_id in user_ids:
+ # We add profiles for all users even if they don't match the
+ # include pattern, just in case we want to change it in future
+ logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
+ yield self._handle_local_user(user_id)
+ num_processed_users += 1
+ yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+
+ logger.info("Processed all users")
+
+ self.initially_handled_users = None
+ self.initially_handled_users_in_public = None
+ self.initially_handled_users_share = None
+ self.initially_handled_users_share_private_room = None
+
+ yield self.store.update_user_directory_stream_pos(new_pos)
+
+ @defer.inlineCallbacks
+ def _handle_initial_room(self, room_id):
+ """Called when we initially fill out user_directory one room at a time
+ """
+ is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
+ if not is_in_room:
+ return
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ user_ids = set(users_with_profile)
+ unhandled_users = user_ids - self.initially_handled_users
+
+ yield self.store.add_profiles_to_user_dir(
+ room_id, {
+ user_id: users_with_profile[user_id] for user_id in unhandled_users
+ }
+ )
+
+ self.initially_handled_users |= unhandled_users
+
+ if is_public:
+ yield self.store.add_users_to_public_room(
+ room_id,
+ user_ids=user_ids - self.initially_handled_users_in_public
+ )
+ self.initially_handled_users_in_public |= user_ids
+
+ # We now go and figure out the new users who share rooms with existing user entries.
+ # We sleep aggressively here as otherwise it can starve resources.
+ # We also batch up inserts/updates, but try to avoid too many at once.
+ to_insert = set()
+ to_update = set()
+ count = 0
+ for user_id in user_ids:
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+
+ if not self.is_mine_id(user_id):
+ count += 1
+ continue
+
+ if self.store.get_if_app_services_interested_in_user(user_id):
+ count += 1
+ continue
+
+ for other_user_id in user_ids:
+ if user_id == other_user_id:
+ continue
+
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+ count += 1
+
+ user_set = (user_id, other_user_id)
+
+ if user_set in self.initially_handled_users_share_private_room:
+ continue
+
+ if user_set in self.initially_handled_users_share:
+ if is_public:
+ continue
+ to_update.add(user_set)
+ else:
+ to_insert.add(user_set)
+
+ if is_public:
+ self.initially_handled_users_share.add(user_set)
+ else:
+ self.initially_handled_users_share_private_room.add(user_set)
+
+ if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+ to_insert.clear()
+
+ if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+ to_update.clear()
+
+ if to_insert:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+ to_insert.clear()
+
+ if to_update:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+ to_update.clear()
+
+ @defer.inlineCallbacks
+ def _handle_deltas(self, deltas):
+ """Called with the state deltas to process
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ prev_event_id = delta["prev_event_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ # For join rule and visibility changes we need to check if the room
+ # may have become public or not and add/remove the users in said room
+ if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
+ yield self._handle_room_publicity_change(
+ room_id, prev_event_id, event_id, typ,
+ )
+ elif typ == EventTypes.Member:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="membership",
+ public_value=Membership.JOIN,
+ )
+
+ if change is None:
+ # Handle any profile changes
+ yield self._handle_profile_change(
+ state_key, room_id, prev_event_id, event_id,
+ )
+ continue
+
+ if not change:
+ # Need to check if the server left the room entirely, if so
+ # we might need to remove all the users in that room
+ is_in_room = yield self.store.is_host_joined(
+ room_id, self.server_name,
+ )
+ if not is_in_room:
+ logger.info("Server left room: %r", room_id)
+ # Fetch all the users that we marked as being in the user
+ # directory due to being in the room, and then check whether
+ # we need to remove those users or not
+ user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
+ for user_id in user_ids:
+ yield self._handle_remove_user(room_id, user_id)
+ return
+ else:
+ logger.debug("Server is still in room: %r", room_id)
+
+ if change: # The user joined
+ event = yield self.store.get_event(event_id, allow_none=True)
+ profile = ProfileInfo(
+ avatar_url=event.content.get("avatar_url"),
+ display_name=event.content.get("displayname"),
+ )
+
+ yield self._handle_new_user(room_id, state_key, profile)
+ else: # The user left
+ yield self._handle_remove_user(room_id, state_key)
+ else:
+ logger.debug("Ignoring irrelevant type: %r", typ)
+
+ @defer.inlineCallbacks
+ def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
+ """Handle a room having potentially changed from/to world_readable/publically
+ joinable.
+
+ Args:
+ room_id (str)
+ prev_event_id (str|None): The previous event before the state change
+ event_id (str|None): The new event after the state change
+ typ (str): Type of the event
+ """
+ logger.debug("Handling change for %s: %s", typ, room_id)
+
+ if typ == EventTypes.RoomHistoryVisibility:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="history_visibility",
+ public_value="world_readable",
+ )
+ elif typ == EventTypes.JoinRules:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="join_rule",
+ public_value=JoinRules.PUBLIC,
+ )
+ else:
+ raise Exception("Invalid event type")
+ # If change is None, no change. True => become world_readable/public,
+ # False => was world_readable/public
+ if change is None:
+ logger.debug("No change")
+ return
+
+ # There's been a change to or from being world readable.
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ logger.debug("Change: %r, is_public: %r", change, is_public)
+
+ if change and not is_public:
+ # If we became world readable but room isn't currently public then
+ # we ignore the change
+ return
+ elif not change and is_public:
+ # If we stopped being world readable but are still public,
+ # ignore the change
+ return
+
+ if change:
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ for user_id, profile in users_with_profile.iteritems():
+ yield self._handle_new_user(room_id, user_id, profile)
+ else:
+ users = yield self.store.get_users_in_public_due_to_room(room_id)
+ for user_id in users:
+ yield self._handle_remove_user(room_id, user_id)
+
+ @defer.inlineCallbacks
+ def _handle_local_user(self, user_id):
+ """Adds a new local roomless user into the user_directory_search table.
+ Used to populate up the user index when we have an
+ user_directory_search_all_users specified.
+ """
+ logger.debug("Adding new local user to dir, %r", user_id)
+
+ profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
+
+ row = yield self.store.get_user_in_directory(user_id)
+ if not row:
+ yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+
+ @defer.inlineCallbacks
+ def _handle_new_user(self, room_id, user_id, profile):
+ """Called when we might need to add user to directory
+
+ Args:
+ room_id (str): room_id that user joined or started being public
+ user_id (str)
+ """
+ logger.debug("Adding new user to dir, %r", user_id)
+
+ row = yield self.store.get_user_in_directory(user_id)
+ if not row:
+ yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ if is_public:
+ row = yield self.store.get_user_in_public_room(user_id)
+ if not row:
+ yield self.store.add_users_to_public_room(room_id, [user_id])
+ else:
+ logger.debug("Not adding new user to public dir, %r", user_id)
+
+ # Now we update users who share rooms with users. We do this by getting
+ # all the current users in the room and seeing which aren't already
+ # marked in the database as sharing with `user_id`
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ to_insert = set()
+ to_update = set()
+
+ is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+
+ # First, if they're our user then we need to update for every user
+ if self.is_mine_id(user_id) and not is_appservice:
+ # Returns a map of other_user_id -> shared_private. We only need
+ # to update mappings for users that either don't already share a
+ # room (aren't in the map) or, if this room is private, so far
+ # only share a public room.
+ user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
+ user_id
+ )
+
+ for other_user_id in users_with_profile:
+ if user_id == other_user_id:
+ continue
+
+ shared_is_private = user_ids_shared.get(other_user_id)
+ if shared_is_private is True:
+ # We've already marked in the database they share a private room
+ continue
+ elif shared_is_private is False:
+ # They already share a public room, so only update if this is
+ # a private room
+ if not is_public:
+ to_update.add((user_id, other_user_id))
+ elif shared_is_private is None:
+ # This is the first time they both share a room
+ to_insert.add((user_id, other_user_id))
+
+ # Next we need to update for every local user in the room
+ for other_user_id in users_with_profile:
+ if user_id == other_user_id:
+ continue
+
+ is_appservice = self.store.get_if_app_services_interested_in_user(
+ other_user_id
+ )
+ if self.is_mine_id(other_user_id) and not is_appservice:
+ shared_is_private = yield self.store.get_if_users_share_a_room(
+ other_user_id, user_id,
+ )
+ if shared_is_private is True:
+ # We've already marked in the database they share a private room
+ continue
+ elif shared_is_private is False:
+ # They already share a public room, so only update if this is
+ # a private room
+ if not is_public:
+ to_update.add((other_user_id, user_id))
+ elif shared_is_private is None:
+ # This is the first time they both share a room
+ to_insert.add((other_user_id, user_id))
+
+ if to_insert:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+
+ if to_update:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+
+ @defer.inlineCallbacks
+ def _handle_remove_user(self, room_id, user_id):
+ """Called when we might need to remove user to directory
+
+ Args:
+ room_id (str): room_id that user left or stopped being public that
+ user_id (str)
+ """
+ logger.debug("Maybe removing user %r", user_id)
+
+ row = yield self.store.get_user_in_directory(user_id)
+ update_user_dir = row and row["room_id"] == room_id
+
+ row = yield self.store.get_user_in_public_room(user_id)
+ update_user_in_public = row and row["room_id"] == room_id
+
+ if (update_user_in_public or update_user_dir):
+ # XXX: Make this faster?
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ for j_room_id in rooms:
+ if (not update_user_in_public and not update_user_dir):
+ break
+
+ is_in_room = yield self.store.is_host_joined(
+ j_room_id, self.server_name,
+ )
+
+ if not is_in_room:
+ continue
+
+ if update_user_dir:
+ update_user_dir = False
+ yield self.store.update_user_in_user_dir(user_id, j_room_id)
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ j_room_id
+ )
+
+ if update_user_in_public and is_public:
+ yield self.store.update_user_in_public_user_list(user_id, j_room_id)
+ update_user_in_public = False
+
+ if update_user_dir:
+ yield self.store.remove_from_user_dir(user_id)
+ elif update_user_in_public:
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ # Now handle users_who_share_rooms.
+
+ # Get a list of user tuples that were in the DB due to this room and
+ # this user (this includes tuples where the other user matches `user_id`)
+ user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
+ user_id, room_id,
+ )
+
+ for user_id, other_user_id in user_tuples:
+ # For each user tuple get a list of rooms that they still share,
+ # trying to find a private room, and update the entry in the DB
+ rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
+
+ # If they don't share a room anymore, remove the mapping
+ if not rooms:
+ yield self.store.remove_user_who_share_room(
+ user_id, other_user_id,
+ )
+ continue
+
+ found_public_share = None
+ for j_room_id in rooms:
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ j_room_id
+ )
+
+ if is_public:
+ found_public_share = j_room_id
+ else:
+ found_public_share = None
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, [(user_id, other_user_id)],
+ )
+ break
+
+ if found_public_share:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, [(user_id, other_user_id)],
+ )
+
+ @defer.inlineCallbacks
+ def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ """Check member event changes for any profile changes and update the
+ database if there are.
+ """
+ if not prev_event_id or not event_id:
+ return
+
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ event = yield self.store.get_event(event_id, allow_none=True)
+
+ if not prev_event or not event:
+ return
+
+ if event.membership != Membership.JOIN:
+ return
+
+ prev_name = prev_event.content.get("displayname")
+ new_name = event.content.get("displayname")
+
+ prev_avatar = prev_event.content.get("avatar_url")
+ new_avatar = event.content.get("avatar_url")
+
+ if prev_name != new_name or prev_avatar != new_avatar:
+ yield self.store.update_profile_in_user_dir(
+ user_id, new_name, new_avatar, room_id,
+ )
+
+ @defer.inlineCallbacks
+ def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ """Given two events check if the `key_name` field in content changed
+ from not matching `public_value` to doing so.
+
+ For example, check if `history_visibility` (`key_name`) changed from
+ `shared` to `world_readable` (`public_value`).
+
+ Returns:
+ None if the field in the events either both match `public_value`
+ or if neither do, i.e. there has been no change.
+ True if it didn't match `public_value` but now does
+ False if it did match `public_value` but now doesn't
+ """
+ prev_event = None
+ event = None
+ if prev_event_id:
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+
+ if event_id:
+ event = yield self.store.get_event(event_id, allow_none=True)
+
+ if not event and not prev_event:
+ logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
+ defer.returnValue(None)
+
+ prev_value = None
+ value = None
+
+ if prev_event:
+ prev_value = prev_event.content.get(key_name)
+
+ if event:
+ value = event.content.get(key_name)
+
+ logger.debug("prev_value: %r -> value: %r", prev_value, value)
+
+ if value == public_value and prev_value != public_value:
+ defer.returnValue(True)
+ elif value != public_value and prev_value == public_value:
+ defer.returnValue(False)
+ else:
+ defer.returnValue(None)
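
_get_key_change above boils two events down to a three-valued answer: True means the room just became public in the relevant sense, False means it just stopped being so, and None means no transition either way. A standalone sketch of only the comparison, with the cases it distinguishes (the helper name is illustrative):

def key_change(prev_value, value, public_value):
    """Three-valued comparison used to spot publicity transitions."""
    if value == public_value and prev_value != public_value:
        return True   # e.g. history_visibility: "shared" -> "world_readable"
    if value != public_value and prev_value == public_value:
        return False  # e.g. join_rule: "public" -> "invite"
    return None       # no transition in either direction

assert key_change("shared", "world_readable", "world_readable") is True
assert key_change("public", "invite", "public") is False
assert key_change("invite", "invite", "public") is None
assert key_change("world_readable", "world_readable", "world_readable") is None
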
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..054372e179 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+ """Exception representing timeout of an outbound request"""
+ def __init__(self):
+ super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+ """Turns CancelledErrors into RequestTimedOutErrors.
+
+ For use with async.add_timeout_to_deferred
+ """
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise RequestTimedOutError()
+ return value
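
cancelled_to_request_timed_out_error is meant to be handed to add_timeout_to_deferred: when the timeout fires and cancels the request, the resulting CancelledError failure is translated into a 504, while any other failure passes through trap() unchanged. A small sketch of the translation, assuming Twisted is importable; RequestTimedOutError here is a local stand-in, not the synapse class:

from twisted.internet.defer import CancelledError
from twisted.python import failure

class RequestTimedOutError(Exception):
    """Local stand-in for synapse.http.RequestTimedOutError."""

def cancelled_to_timed_out(value, timeout):
    if isinstance(value, failure.Failure):
        value.trap(CancelledError)  # re-raises anything that is not a cancel
        raise RequestTimedOutError("Timed out after %ss" % timeout)
    return value

try:
    cancelled_to_timed_out(failure.Failure(CancelledError()), 60)
except RequestTimedOutError as e:
    print(e)  # the cancel surfaces as a timeout the caller can map to a 504
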
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
new file mode 100644
index 0000000000..343e932cb1
--- /dev/null
+++ b/synapse/http/additional_resource.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import wrap_request_handler
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class AdditionalResource(Resource):
+ """Resource wrapper for additional_resources
+
+ If the user has configured additional_resources, we need to wrap the
+ handler class with a Resource so that we can map it into the resource tree.
+
+ This class is also where we wrap the request handler with logging, metrics,
+ and exception handling.
+ """
+ def __init__(self, hs, handler):
+ """Initialise AdditionalResource
+
+ The ``handler`` should return a deferred which completes when it has
+ done handling the request. It should write a response with
+ ``request.write()``, and call ``request.finish()``.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+ function to be called to handle the request.
+ """
+ Resource.__init__(self)
+ self._handler = handler
+
+ # these are required by the request_handler wrapper
+ self.version_string = hs.version_string
+ self.clock = hs.get_clock()
+
+ def render(self, request):
+ self._async_render(request)
+ return NOT_DONE_YET
+
+ @wrap_request_handler
+ def _async_render(self, request):
+ return self._handler(request)
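
A sketch of what a handler given to AdditionalResource might look like: it receives the Twisted request, returns a Deferred, and is responsible for writing and finishing the response itself. The handler body and the mounting lines are illustrative only, since the actual wiring depends on the homeserver's resource tree:

from twisted.internet import defer

@defer.inlineCallbacks
def hello_handler(request):
    yield defer.succeed(None)  # any asynchronous work would happen here
    request.setHeader(b"Content-Type", b"text/plain")
    request.write(b"hello from an additional resource\n")
    request.finish()

# Mounting (illustrative):
#   resource = AdditionalResource(hs, hello_handler)
#   root_resource.putChild(b"hello", resource)
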
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2f770f5d..70a19d9b74 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +17,12 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
- CodeMessageException, SynapseError, Codes,
+ CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.logcontext import make_deferred_yieldable
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
@@ -29,13 +33,14 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
+ HTTPConnectionPool,
)
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
-from StringIO import StringIO
+from six import StringIO
import simplejson as json
import logging
@@ -63,92 +68,139 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
+
+ pool = HTTPConnectionPool(reactor)
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ pool.cachedConnectionTimeout = 2 * 60
+
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory()
+ contextFactory=hs.get_http_client_context_factory(),
+ pool=pool,
)
self.user_agent = hs.version_string
+ self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
+ @defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
- d = preserve_context_over_fn(
- self.agent.request,
- method, uri, *args, **kwargs
- )
logger.info("Sending request %s %s", method, uri)
- def _cb(response):
+ try:
+ request_deferred = self.agent.request(
+ method, uri, *args, **kwargs
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ 60, cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(request_deferred)
+
incoming_responses_counter.inc(method, response.code)
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
- return response
-
- def _eb(failure):
+ defer.returnValue(response)
+ except Exception as e:
incoming_responses_counter.inc(method, "ERR")
logger.info(
"Error sending request to %s %s: %s %s",
- method, uri, failure.type, failure.getErrorMessage()
+ method, uri, type(e).__name__, e.message
)
- return failure
+ raise e
- d.addCallbacks(_cb, _eb)
+ @defer.inlineCallbacks
+ def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ """
+ Args:
+ uri (str):
+ args (dict[str, str|List[str]]): query params
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
- return d
+ Returns:
+ Deferred[object]: parsed json
+ """
- @defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ actual_headers = {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json):
+ def post_json_get_json(self, uri, post_json, headers=None):
+ """
+
+ Args:
+ uri (str):
+ post_json (object):
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/json"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
+
+ if 200 <= response.code < 300:
+ defer.returnValue(json.loads(body))
+ else:
+ raise self._exceptionFromFailedRequest(response, body)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, uri, args={}):
+ def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -157,6 +209,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -164,11 +218,14 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
- body = yield self.get_raw(uri, args)
- defer.returnValue(json.loads(body))
+ try:
+ body = yield self.get_raw(uri, args, headers=headers)
+ defer.returnValue(json.loads(body))
+ except CodeMessageException as e:
+ raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}):
+ def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -178,6 +235,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -190,17 +249,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"PUT",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- "Content-Type": ["application/json"]
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -211,7 +274,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
- def get_raw(self, uri, args={}):
+ def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -220,6 +283,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as text.
@@ -231,46 +296,65 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
+ def _exceptionFromFailedRequest(self, response, body):
+ try:
+ jsonBody = json.loads(body)
+ errcode = jsonBody['errcode']
+ error = jsonBody['error']
+ return MatrixCodeMessageException(response.code, error, errcode)
+ except (ValueError, KeyError):
+ return CodeMessageException(response.code, body)
+
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
@defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None):
+ def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
url.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- headers = dict(response.headers.getAllRawHeaders())
+ resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -291,10 +375,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ length = yield make_deferred_yieldable(_readBodyToFile(
+ response, output_stream, max_size,
+ ))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@@ -303,7 +386,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
- defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+ defer.returnValue(
+ (length, resp_headers, response.request.absoluteURI, response.code),
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -371,7 +456,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
@@ -422,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
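
Each SimpleHttpClient helper above now takes an optional headers map and merges it over its defaults, so callers can add e.g. an Authorization header without losing Content-Type or User-Agent. A minimal sketch of the merge plus an illustrative call site (the token value is made up):

def merge_headers(defaults, overrides=None):
    # Same shape as the actual_headers pattern above: start from the default
    # header map and let the caller's entries win on any clash.
    actual = dict(defaults)
    if overrides:
        actual.update(overrides)
    return actual

defaults = {
    b"Content-Type": [b"application/json"],
    b"User-Agent": [b"Synapse/0.28.1"],
}
print(merge_headers(defaults, {b"Authorization": [b"Bearer some-token"]}))

# An illustrative caller would then pass the extra header straight through:
#   client.post_json_get_json(uri, body, headers={b"Authorization": [b"Bearer some-token"]})
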
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 564ae4c10d..87a482650d 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
@@ -30,7 +29,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
@@ -224,9 +226,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -236,11 +239,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
@@ -277,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
@@ -285,26 +299,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
-
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
-
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + answer.ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
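
For reference, a standalone version of the weighted pick that the XXX comment above is describing: servers sharing the lowest priority are weighted by weight + 1 and the chosen entry is removed, which is why two servers weighted 100:1 alternate roughly ABABAB rather than AAAA...B. Names here are illustrative:

import collections
import random

Server = collections.namedtuple("Server", "priority weight host port")

def pick_server(servers):
    """Weighted random pick among the servers sharing the lowest priority.

    Mirrors the selection loop in SRVClientEndpoint.pick_server: each
    server's chance is proportional to weight + 1 and the chosen entry is
    removed from the list.
    """
    servers.sort()  # namedtuples sort by priority first
    min_priority = servers[0].priority
    candidates = [
        (i, s.weight + 1)
        for i, s in enumerate(servers)
        if s.priority == min_priority
    ]
    target = random.randint(0, sum(weight for _, weight in candidates))
    for index, weight in candidates:
        target -= weight
        if target <= 0:
            return servers.pop(index)

servers = [
    Server(priority=0, weight=100, host="a.example.com", port=8448),
    Server(priority=0, weight=1, host="b.example.com", port=8448),
]
print(pick_server(servers).host)  # almost always a.example.com on the first pick
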
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 78b92cef36..4b2b85464d 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,23 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
+from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils
from canonicaljson import encode_canonical_json
from synapse.api.errors import (
- SynapseError, Codes, HttpResponseException,
+ SynapseError, Codes, HttpResponseException, FederationDeniedError,
)
from signedjson.sign import sign_json
@@ -39,8 +41,7 @@ import logging
import random
import sys
import urllib
-import urlparse
-
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
@@ -94,6 +95,7 @@ class MatrixFederationHttpClient(object):
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock()
+ self._store = hs.get_datastore()
self.version_string = hs.version_string
self._next_id = 1
@@ -103,123 +105,161 @@ class MatrixFederationHttpClient(object):
)
@defer.inlineCallbacks
- def _create_request(self, destination, method, path_bytes,
- body_callback, headers_dict={}, param_bytes=b"",
- query_bytes=b"", retry_on_dns_fail=True,
- timeout=None, long_retries=False):
- """ Creates and sends a request to the given url
- """
- headers_dict[b"User-Agent"] = [self.version_string]
- headers_dict[b"Host"] = [destination]
+ def _request(self, destination, method, path,
+ body_callback, headers_dict={}, param_bytes=b"",
+ query_bytes=b"", retry_on_dns_fail=True,
+ timeout=None, long_retries=False,
+ ignore_backoff=False,
+ backoff_on_404=False):
+ """ Creates and sends a request to the given server
+ Args:
+ destination (str): The remote server to send the HTTP request to.
+ method (str): HTTP method
+ path (str): The HTTP path
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
+ backoff_on_404 (bool): Back off if we get a 404
- url_bytes = self._create_url(
- destination, path_bytes, param_bytes, query_bytes
- )
+ Returns:
+ Deferred: resolves with the http response object on success.
- txn_id = "%s-O-%s" % (method, self._next_id)
- self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+ Fails with ``HTTPRequestException``: if we get an HTTP response
+ code >= 300.
- outbound_logger.info(
- "{%s} [%s] Sending request: %s %s",
- txn_id, destination, method, url_bytes
- )
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
- # XXX: Would be much nicer to retry only at the transaction-layer
- # (once we have reliable transactions in place)
- if long_retries:
- retries_left = MAX_LONG_RETRIES
- else:
- retries_left = MAX_SHORT_RETRIES
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
- http_url_bytes = urlparse.urlunparse(
- ("", "", path_bytes, param_bytes, query_bytes, "")
+ (May also fail with plenty of other Exceptions for things like DNS
+ failures, connection failures, SSL failures.)
+ """
+ if (
+ self.hs.config.federation_domain_whitelist and
+ destination not in self.hs.config.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(destination)
+
+ limiter = yield synapse.util.retryutils.get_retry_limiter(
+ destination,
+ self.clock,
+ self._store,
+ backoff_on_404=backoff_on_404,
+ ignore_backoff=ignore_backoff,
)
- log_result = None
- try:
- while True:
- producer = None
- if body_callback:
- producer = body_callback(method, http_url_bytes, headers_dict)
-
- try:
- def send_request():
- request_deferred = preserve_context_over_fn(
- self.agent.request,
+ destination = destination.encode("ascii")
+ path_bytes = path.encode("ascii")
+ with limiter:
+ headers_dict[b"User-Agent"] = [self.version_string]
+ headers_dict[b"Host"] = [destination]
+
+ url_bytes = self._create_url(
+ destination, path_bytes, param_bytes, query_bytes
+ )
+
+ txn_id = "%s-O-%s" % (method, self._next_id)
+ self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+
+ outbound_logger.info(
+ "{%s} [%s] Sending request: %s %s",
+ txn_id, destination, method, url_bytes
+ )
+
+ # XXX: Would be much nicer to retry only at the transaction-layer
+ # (once we have reliable transactions in place)
+ if long_retries:
+ retries_left = MAX_LONG_RETRIES
+ else:
+ retries_left = MAX_SHORT_RETRIES
+
+ http_url_bytes = urlparse.urlunparse(
+ ("", "", path_bytes, param_bytes, query_bytes, "")
+ )
+
+ log_result = None
+ try:
+ while True:
+ producer = None
+ if body_callback:
+ producer = body_callback(method, http_url_bytes, headers_dict)
+
+ try:
+ request_deferred = self.agent.request(
method,
url_bytes,
Headers(headers_dict),
producer
)
-
- return self.clock.time_bound_deferred(
+ add_timeout_to_deferred(
+ request_deferred,
+ timeout / 1000. if timeout else 60,
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(
request_deferred,
- time_out=timeout / 1000. if timeout else 60,
)
- response = yield preserve_context_over_fn(send_request)
+ log_result = "%d %s" % (response.code, response.phrase,)
+ break
+ except Exception as e:
+ if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ logger.warn(
+ "DNS Lookup failed to %s with %s",
+ destination,
+ e
+ )
+ log_result = "DNS Lookup failed to %s with %s" % (
+ destination, e
+ )
+ raise
- log_result = "%d %s" % (response.code, response.phrase,)
- break
- except Exception as e:
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
- "DNS Lookup failed to %s with %s",
+ "{%s} Sending request failed to %s: %s %s: %s",
+ txn_id,
destination,
- e
- )
- log_result = "DNS Lookup failed to %s with %s" % (
- destination, e
+ method,
+ url_bytes,
+ _flatten_response_never_received(e),
)
- raise
-
- logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
- txn_id,
- destination,
- method,
- url_bytes,
- type(e).__name__,
- _flatten_response_never_received(e),
- )
-
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
-
- if retries_left and not timeout:
- if long_retries:
- delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
- delay = min(delay, 60)
- delay *= random.uniform(0.8, 1.4)
+
+ log_result = _flatten_response_never_received(e)
+
+ if retries_left and not timeout:
+ if long_retries:
+ delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+ delay = min(delay, 60)
+ delay *= random.uniform(0.8, 1.4)
+ else:
+ delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+ delay = min(delay, 2)
+ delay *= random.uniform(0.8, 1.4)
+
+ yield sleep(delay)
+ retries_left -= 1
else:
- delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
- delay = min(delay, 2)
- delay *= random.uniform(0.8, 1.4)
-
- yield sleep(delay)
- retries_left -= 1
- else:
- raise
- finally:
- outbound_logger.info(
- "{%s} [%s] Result: %s",
- txn_id,
- destination,
- log_result,
- )
+ raise
+ finally:
+ outbound_logger.info(
+ "{%s} [%s] Result: %s",
+ txn_id,
+ destination,
+ log_result,
+ )
- if 200 <= response.code < 300:
- pass
- else:
- # :'(
- # Update transactions table?
- body = yield preserve_context_over_fn(readBody, response)
- raise HttpResponseException(
- response.code, response.phrase, body
- )
+ if 200 <= response.code < 300:
+ pass
+ else:
+ # :'(
+ # Update transactions table?
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+ raise HttpResponseException(
+ response.code, response.phrase, body
+ )
- defer.returnValue(response)
+ defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
@@ -247,14 +287,18 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, data={}, json_data_callback=None,
- long_retries=False, timeout=None):
+ def put_json(self, destination, path, args={}, data={},
+ json_data_callback=None,
+ long_retries=False, timeout=None,
+ ignore_backoff=False,
+ backoff_on_404=False):
""" Sends the specifed json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
+ args (dict): query params
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
@@ -263,11 +307,24 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
+ backoff_on_404 (bool): True if we should count a 404 response as
+ a failure of the server (and should therefore back off future
+ requests)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body. On a 4xx or 5xx error response a
- CodeMessageException is raised.
+ will be the decoded JSON body.
+
+ Fails with ``HttpResponseException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
if not json_data_callback:
@@ -282,26 +339,30 @@ class MatrixFederationHttpClient(object):
producer = _JsonProducer(json_data)
return producer
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"PUT",
- path.encode("ascii"),
+ path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ query_bytes=encode_query_args(args),
long_retries=long_retries,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ backoff_on_404=backoff_on_404,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None):
+ timeout=None, ignore_backoff=False, args={}):
""" Sends the specifed json data using POST
Args:
@@ -314,11 +375,21 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
-
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
+ args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body. On a 4xx or 5xx error response a
- CodeMessageException is raised.
+ will be the decoded JSON body.
+
+ Fails with ``HttpResponseException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -327,27 +398,30 @@ class MatrixFederationHttpClient(object):
)
return _JsonProducer(data)
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"POST",
- path.encode("ascii"),
+ path,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
- timeout=None):
+ timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path
Args:
@@ -359,57 +433,122 @@ class MatrixFederationHttpClient(object):
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
- Deferred: Succeeds when we get *any* HTTP response.
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body.
+
+ Fails with ``HttpResponseException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
- The result of the deferred is a tuple of `(code, response)`,
- where `response` is a dict representing the decoded JSON body.
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, basestring):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"GET",
- path.encode("ascii"),
- query_bytes=query_bytes,
+ path,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
+ def delete_json(self, destination, path, long_retries=False,
+ timeout=None, ignore_backoff=False, args={}):
+ """Send a DELETE request to the remote expecting some json response
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
+ timeout(int): How long to try (in ms) the destination for before
+ giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body.
+
+ Fails with ``HttpResponseException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+ """
+
+ response = yield self._request(
+ destination,
+ "DELETE",
+ path,
+ query_bytes=encode_query_args(args),
+ headers_dict={"Content-Type": ["application/json"]},
+ long_retries=long_retries,
+ timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ )
+
+ if 200 <= response.code < 300:
+ # We need to update the transactions table to say it was sent?
+ check_content_type_is_json(response.headers)
+
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True, max_size=None):
+ retry_on_dns_fail=True, max_size=None,
+ ignore_backoff=False):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
- A (int,dict) tuple of the file length and a dict of the response
- headers.
+ Deferred: resolves with an (int,dict) tuple of the file length and
+ a dict of the response headers.
+
+ Fails with ``HttpResponseException`` if we get an HTTP response code
+ >= 300
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
encoded_args = {}
@@ -419,29 +558,30 @@ class MatrixFederationHttpClient(object):
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
- logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+ logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"GET",
- path.encode("ascii"),
+ path,
query_bytes=query_bytes,
body_callback=body_callback,
- retry_on_dns_fail=retry_on_dns_fail
+ retry_on_dns_fail=retry_on_dns_fail,
+ ignore_backoff=ignore_backoff,
)
headers = dict(response.headers.getAllRawHeaders())
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
- except:
+ with logcontext.PreserveLoggingContext():
+ length = yield _readBodyToFile(
+ response, output_stream, max_size
+ )
+ except Exception:
logger.exception("Failed to download body")
raise
@@ -506,12 +646,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
- return ", ".join(
+ reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
+
+ return "%s:[%s]" % (type(e).__name__, reasons)
else:
- return "%s: %s" % (type(e).__name__, e.message,)
+ return repr(e)
def check_content_type_is_json(headers):
@@ -538,3 +680,15 @@ def check_content_type_is_json(headers):
raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)
+
+
+def encode_query_args(args):
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, basestring):
+ vs = [vs]
+ encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+ query_bytes = urllib.urlencode(encoded_args, True)
+
+ return query_bytes
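
As a point of reference, the retry loop in `_request` above backs off exponentially with jitter between attempts. A minimal standalone sketch of that schedule, assuming the MAX_LONG_RETRIES/MAX_SHORT_RETRIES constants defined earlier in this module (their real values are not shown in this hunk, so the numbers below are placeholders):

    import random

    MAX_LONG_RETRIES = 10   # placeholder; actual value defined elsewhere in this file
    MAX_SHORT_RETRIES = 3   # placeholder

    def retry_delay(retries_left, long_retries):
        # Long retries: 4**n seconds capped at 60s; short retries: 0.5 * 2**n
        # seconds capped at 2s. Both get 0.8x-1.4x jitter to spread retries out.
        if long_retries:
            delay = min(4 ** (MAX_LONG_RETRIES + 1 - retries_left), 60)
        else:
            delay = min(0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left), 2)
        return delay * random.uniform(0.8, 1.4)
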
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 14715878c5..55b9ad5251 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,6 +29,7 @@ from canonicaljson import (
)
from twisted.internet import defer
+from twisted.python import failure
from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo
@@ -35,42 +37,86 @@ from twisted.web.util import redirectTo
import collections
import logging
import urllib
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-incoming_requests_counter = metrics.register_counter(
- "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+ "response_count",
labels=["method", "servlet", "tag"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_requests",
+ "_response_time:count",
+ "_response_ru_utime:count",
+ "_response_ru_stime:count",
+ "_response_db_txn_count:count",
+ "_response_db_txn_duration:count",
+ )
+ )
+)
+
+requests_counter = metrics.register_counter(
+ "requests_received",
+ labels=["method", "servlet", ],
)
+
outgoing_responses_counter = metrics.register_counter(
"responses",
labels=["method", "code"],
)
-response_timer = metrics.register_distribution(
- "response_time",
- labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+ "response_time_seconds",
+ labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_time:total",
+ ),
)
-response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+ "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_utime:total",
+ ),
)
-response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+ "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_ru_stime:total",
+ ),
)
-response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+ "response_db_txn_count", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_count:total",
+ ),
)
-response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+ "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+ alternative_names=(
+ metrics.name_prefix + "_response_db_txn_duration:total",
+ ),
)
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+ "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
+
+# size in bytes of the response written
+response_size = metrics.register_counter(
+ "response_size", labels=["method", "servlet", "tag"]
+)
_next_request_id = 0
@@ -106,7 +152,12 @@ def wrap_request_handler(request_handler, include_metrics=False):
with LoggingContext(request_id) as request_context:
with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics()
- request_metrics.start(self.clock, name=self.__class__.__name__)
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
+ servlet_name = self.__class__.__name__
+ request_metrics.start(self.clock, name=servlet_name)
request_context.request = request_id
with request.processing():
@@ -115,6 +166,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
if include_metrics:
yield request_handler(self, request, request_metrics)
else:
+ requests_counter.inc(request.method, servlet_name)
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
@@ -130,13 +182,18 @@ def wrap_request_handler(request_handler, include_metrics=False):
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
- except:
- logger.exception(
- "Failed handle request %s.%s on %r: %r",
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ logger.error(
+ "Failed handle request %s.%s on %r: %r: %s",
request_handler.__module__,
request_handler.__name__,
self,
- request
+ request,
+ f.getTraceback().rstrip(),
)
respond_with_json(
request,
@@ -145,7 +202,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
- send_cors=True
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ version_string=self.version_string,
)
finally:
try:
@@ -183,7 +242,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
- Register callbacks via register_path()
+ Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
@@ -230,57 +289,62 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and
path.
"""
- if request.method == "OPTIONS":
- self._send_response(request, 200, {})
- return
+ callback, group_dict = self._get_handler_for_request(request)
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path)
- if not m:
- continue
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
- # We found a match! Trigger callback and then return the
- # returned response. We pass both the request and any
- # matched groups from the regex to the callback.
+ request_metrics.name = servlet_classname
+ requests_counter.inc(request.method, servlet_classname)
- callback = path_entry.callback
+ # Now trigger the callback. If it returns a response, we send it
+ # here. If it throws an exception, that is handled by the wrapper
+ # installed by @request_handler.
- kwargs = intern_dict({
- name: urllib.unquote(value).decode("UTF-8") if value else value
- for name, value in m.groupdict().items()
- })
+ kwargs = intern_dict({
+ name: urllib.unquote(value).decode("UTF-8") if value else value
+ for name, value in group_dict.items()
+ })
- callback_return = yield callback(request, **kwargs)
- if callback_return is not None:
- code, response = callback_return
- self._send_response(request, code, response)
+ callback_return = yield callback(request, **kwargs)
+ if callback_return is not None:
+ code, response = callback_return
+ self._send_response(request, code, response)
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
+ def _get_handler_for_request(self, request):
+ """Finds a callback method to handle the given request
- request_metrics.name = servlet_classname
+ Args:
+ request (twisted.web.http.Request):
+
+ Returns:
+ Tuple[Callable, dict[str, str]]: callback method, and the dict
+ mapping keys to path components as specified in the handler's
+ path match regexp.
- return
+ The callback will normally be a method registered via
+ register_paths, so will return (possibly via Deferred) either
+ None, or a tuple of (http code, response body).
+ """
+ if request.method == b"OPTIONS":
+ return _options_handler, {}
+
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request.path)
+ if m:
+ # We found a match!
+ return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- raise UnrecognizedRequestError()
+ return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it.
@@ -294,6 +358,34 @@ class JsonResource(HttpServer, resource.Resource):
)
+def _options_handler(request):
+ """Request handler for OPTIONS requests
+
+ This is a request handler suitable for return from
+ _get_handler_for_request. It returns a 200 and an empty body.
+
+ Args:
+ request (twisted.web.http.Request):
+
+ Returns:
+ Tuple[int, dict]: http code, response body.
+ """
+ return 200, {}
+
+
+def _unrecognised_request_handler(request):
+ """Request handler for unrecognised requests
+
+ This is a request handler suitable for return from
+ _get_handler_for_request. It actually just raises an
+ UnrecognizedRequestError.
+
+ Args:
+ request (twisted.web.http.Request):
+ """
+ raise UnrecognizedRequestError()
+
+
class RequestMetrics(object):
def start(self, clock, name):
self.start = clock.time_msec()
@@ -314,7 +406,7 @@ class RequestMetrics(object):
)
return
- incoming_requests_counter.inc(request.method, self.name, tag)
+ response_count.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
@@ -333,9 +425,14 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
+ context.db_txn_duration_ms / 1000., request.method, self.name, tag
+ )
+ response_db_sched_duration.inc_by(
+ context.db_sched_duration_ms / 1000., request.method, self.name, tag
)
+ response_size.inc_by(request.sentLength, request.method, self.name, tag)
+
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
@@ -356,14 +453,22 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- # ujson doesn't like frozen_dicts.
- json_bytes = ujson.dumps(json_object, ensure_ascii=False)
+ json_bytes = simplejson.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
@@ -390,6 +495,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
+ request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
@@ -412,7 +518,7 @@ def set_cors_headers(request):
)
request.setHeader(
"Access-Control-Allow-Headers",
- "Origin, X-Requested-With, Content-Type, Accept"
+ "Origin, X-Requested-With, Content-Type, Accept, Authorization"
)
@@ -437,9 +543,9 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
- "User-Agent", default=[]
+ b"User-Agent", default=[]
)
for user_agent in user_agents:
- if "curl" in user_agent:
+ if b"curl" in user_agent:
return True
return False
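
To illustrate the refactored dispatch: JsonResource._get_handler_for_request now always hands back a callback, with OPTIONS and unmatched paths routed to the two module-level handlers added above. A rough sketch of inspecting that behaviour (importing the private handlers here is for illustration only):

    from synapse.http.server import (
        _options_handler, _unrecognised_request_handler,
    )

    def describe_dispatch(json_resource, request):
        callback, group_dict = json_resource._get_handler_for_request(request)
        if callback is _options_handler:
            return "OPTIONS: answered directly with (200, {})"
        if callback is _unrecognised_request_handler:
            return "no registered path matched: raises UnrecognizedRequestError"
        return "dispatching to %r with path args %r" % (callback, group_dict)
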
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 8c22d6f00f..ef8e62901b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -48,7 +48,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try:
return int(args[name][0])
- except:
+ except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
@@ -88,7 +88,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
"true": True,
"false": False,
}[args[name][0]]
- except:
+ except Exception:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into None
Returns:
The JSON value.
@@ -162,28 +164,39 @@ def parse_json_value_from_request(request):
"""
try:
content_bytes = request.content.read()
- except:
+ except Exception:
raise SynapseError(400, "Error reading JSON content.")
+ if not content_bytes and allow_empty_body:
+ return None
+
try:
content = simplejson.loads(content_bytes)
- except simplejson.JSONDecodeError:
+ except Exception as e:
+ logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into an empty dict.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(request)
+ content = parse_json_value_from_request(
+ request, allow_empty_body=allow_empty_body,
+ )
+
+ if allow_empty_body and content is None:
+ return {}
if type(content) != dict:
message = "Content must be a JSON object."
@@ -192,6 +205,16 @@ def parse_json_object_from_request(request):
return content
+def assert_params_in_request(body, required):
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if len(absent) > 0:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+
class RestServlet(object):
""" A Synapse REST Servlet.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4b09d7ee66..c8b46e1af2 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ import logging
import re
import time
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
class SynapseRequest(Request):
@@ -43,12 +43,12 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
+ br'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
- return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+ return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
- except:
+ db_txn_duration_ms = context.db_txn_duration_ms
+ db_sched_duration_ms = context.db_sched_duration_ms
+ except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+ db_txn_count, db_txn_duration_ms = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
+ " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
- int(db_txn_duration * 1000),
+ db_sched_duration_ms,
+ db_txn_duration_ms,
int(db_txn_count),
self.sentLength,
self.code,
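
The move to byte patterns above keeps access-token redaction working against twisted's byte-string URIs. A quick illustration (the URI is made up):

    import re

    ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')

    uri = b"/_matrix/client/r0/sync?access_token=secret123&since=s72_1"
    redacted = ACCESS_TOKEN_RE.sub(br'\1<redacted>\3', uri)
    # -> /_matrix/client/r0/sync?access_token=<redacted>&since=s72_1
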
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..e3b831db67 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -17,12 +17,13 @@ import logging
import functools
import time
import gc
+import platform
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
- MemoryUsageMetric,
+ MemoryUsageMetric, GaugeMetric,
)
from .process_collector import register_process_collector
@@ -30,6 +31,7 @@ from .process_collector import register_process_collector
logger = logging.getLogger(__name__)
+running_on_pypy = platform.python_implementation() == 'PyPy'
all_metrics = []
all_collectors = []
@@ -57,15 +59,38 @@ class Metrics(object):
return metric
def register_counter(self, *args, **kwargs):
+ """
+ Returns:
+ CounterMetric
+ """
return self._register(CounterMetric, *args, **kwargs)
+ def register_gauge(self, *args, **kwargs):
+ """
+ Returns:
+ GaugeMetric
+ """
+ return self._register(GaugeMetric, *args, **kwargs)
+
def register_callback(self, *args, **kwargs):
+ """
+ Returns:
+ CallbackMetric
+ """
return self._register(CallbackMetric, *args, **kwargs)
def register_distribution(self, *args, **kwargs):
+ """
+ Returns:
+ DistributionMetric
+ """
return self._register(DistributionMetric, *args, **kwargs)
def register_cache(self, *args, **kwargs):
+ """
+ Returns:
+ CacheMetric
+ """
return self._register(CacheMetric, *args, **kwargs)
@@ -126,6 +151,32 @@ reactor_metrics = get_metrics_for("python.twisted.reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+synapse_metrics = get_metrics_for("synapse")
+
+# Used to track where various components have processed in the event stream,
+# e.g. federation sending, appservice sending, etc.
+event_processing_positions = synapse_metrics.register_gauge(
+ "event_processing_positions", labels=["name"],
+)
+
+# Used to track the current max events stream position
+event_persisted_position = synapse_metrics.register_gauge(
+ "event_persisted_position",
+)
+
+# Used to track the received_ts of the last event processed by various
+# components
+event_processing_last_ts = synapse_metrics.register_gauge(
+ "event_processing_last_ts", labels=["name"],
+)
+
+# Used to track the lag processing events. This is the time difference
+# between the last processed event's received_ts and the time it was
+# finished being processed.
+event_processing_lag = synapse_metrics.register_gauge(
+ "event_processing_lag", labels=["name"],
+)
+
def runUntilCurrentTimer(func):
@@ -146,13 +197,21 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
-
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
+
+ # record the amount of wallclock time spent running pending calls.
+ # This is a proxy for the actual amount of time between reactor polls,
+ # since about 25% of time is actually spent running things triggered by
+ # I/O events, but that is harder to capture without rewriting half the
+ # reactor.
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
+ if running_on_pypy:
+ return ret
+
# Check if we need to do a manual GC (since its been disabled), and do
# one if necessary.
threshold = gc.get_threshold()
@@ -185,6 +244,7 @@ try:
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
- gc.disable()
+ if not running_on_pypy:
+ gc.disable()
except AttributeError:
pass
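
A short sketch of how a consumer of the event stream might drive the new gauges (the "federation_sender" label and the timestamp arguments are assumptions for illustration):

    from synapse.metrics import (
        event_processing_lag, event_processing_last_ts, event_processing_positions,
    )

    def record_event_stream_progress(now_ms, stream_position, event_received_ts):
        name = "federation_sender"  # example label only
        event_processing_positions.set(stream_position, name)
        event_processing_last_ts.set(event_received_ts, name)
        # lag between the event being received and us finishing with it
        event_processing_lag.set(now_ms - event_received_ts, name)
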
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..fbba94e633 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -15,18 +15,39 @@
from itertools import chain
+import logging
+import re
+logger = logging.getLogger(__name__)
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
- # flatten a list-of-lists
- return list(chain.from_iterable(map(func, items)))
+
+def flatten(items):
+ """Flatten a list of lists
+
+ Args:
+ items: iterable[iterable[X]]
+
+ Returns:
+ list[X]: flattened list
+ """
+ return list(chain.from_iterable(items))
class BaseMetric(object):
+ """Base class for metrics which report a single value per label set
+ """
- def __init__(self, name, labels=[]):
- self.name = name
+ def __init__(self, name, labels=[], alternative_names=[]):
+ """
+ Args:
+ name (str): principal name for this metric
+ labels (list(str)): names of the labels which will be reported
+ for this metric
+ alternative_names (iterable(str)): list of alternative names for
+ this metric. This can be useful to provide a migration path
+ when renaming metrics.
+ """
+ self._names = [name] + list(alternative_names)
self.labels = labels # OK not to clone as we never write it
def dimension(self):
@@ -36,8 +57,7 @@ class BaseMetric(object):
return not len(self.labels)
def _render_labelvalue(self, value):
- # TODO: some kind of value escape
- return '"%s"' % (value)
+ return '"%s"' % (_escape_label_value(value),)
def _render_key(self, values):
if self.is_scalar():
@@ -47,19 +67,60 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
+ def _render_for_labels(self, label_values, value):
+ """Render this metric for a single set of labels
+
+ Args:
+ label_values (list[str]): values for each of the labels
+ value: value of the metric with these labels
+
+ Returns:
+ iterable[str]: rendered metric
+ """
+ rendered_labels = self._render_key(label_values)
+ return (
+ "%s%s %.12g" % (name, rendered_labels, value)
+ for name in self._names
+ )
+
+ def render(self):
+ """Render this metric
+
+ Each metric is rendered as:
+
+ name{label1="val1",label2="val2"} value
+
+ https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+ Returns:
+ iterable[str]: rendered metrics
+ """
+ raise NotImplementedError()
+
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
- integer that counts events."""
+ value that counts events or running totals.
+
+ Example use cases for Counters:
+ - Number of requests processed
+ - Number of items that were inserted into a queue
+ - Total amount of data that a system has processed
+ Counters can only go up (and be reset when the process restarts).
+ """
def __init__(self, *args, **kwargs):
super(CounterMetric, self).__init__(*args, **kwargs)
+ # dict[list[str]]: value for each set of label values. the keys are the
+ # label values, in the same order as the labels in self.labels.
+ #
+ # (if the metric is a scalar, the (single) key is the empty tuple).
self.counts = {}
# Scalar metrics are never empty
if self.is_scalar():
- self.counts[()] = 0
+ self.counts[()] = 0.
def inc_by(self, incr, *values):
if len(values) != self.dimension():
@@ -77,11 +138,41 @@ class CounterMetric(BaseMetric):
def inc(self, *values):
self.inc_by(1, *values)
- def render_item(self, k):
- return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
+ def render(self):
+ return flatten(
+ self._render_for_labels(k, self.counts[k])
+ for k in sorted(self.counts.keys())
+ )
+
+
+class GaugeMetric(BaseMetric):
+ """A metric that can go up or down
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(GaugeMetric, self).__init__(*args, **kwargs)
+
+ # dict[list[str]]: value for each set of label values. the keys are the
+ # label values, in the same order as the labels in self.labels.
+ #
+ # (if the metric is a scalar, the (single) key is the empty tuple).
+ self.gauges = {}
+
+ def set(self, v, *values):
+ if len(values) != self.dimension():
+ raise ValueError(
+ "Expected as many values to inc() as labels (%d)" % (self.dimension())
+ )
+
+ # TODO: should assert that the tag values are all strings
+
+ self.gauges[values] = v
def render(self):
- return map_concat(self.render_item, sorted(self.counts.keys()))
+ return flatten(
+ self._render_for_labels(k, self.gauges[k])
+ for k in sorted(self.gauges.keys())
+ )
class CallbackMetric(BaseMetric):
@@ -95,13 +186,19 @@ class CallbackMetric(BaseMetric):
self.callback = callback
def render(self):
- value = self.callback()
+ try:
+ value = self.callback()
+ except Exception:
+ logger.exception("Failed to render %s", self.name)
+ return ["# FAILED to render " + self.name]
if self.is_scalar():
- return ["%s %.12g" % (self.name, value)]
+ return list(self._render_for_labels([], value))
- return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
- for k in sorted(value.keys())]
+ return flatten(
+ self._render_for_labels(k, value[k])
+ for k in sorted(value.keys())
+ )
class DistributionMetric(object):
@@ -126,7 +223,9 @@ class DistributionMetric(object):
class CacheMetric(object):
- __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+ __slots__ = (
+ "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+ )
def __init__(self, name, size_callback, cache_name):
self.name = name
@@ -134,6 +233,7 @@ class CacheMetric(object):
self.hits = 0
self.misses = 0
+ self.evicted_size = 0
self.size_callback = size_callback
@@ -143,6 +243,9 @@ class CacheMetric(object):
def inc_misses(self):
self.misses += 1
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
def render(self):
size = self.size_callback()
hits = self.hits
@@ -152,6 +255,9 @@ class CacheMetric(object):
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+ """%s:evicted_size{name="%s"} %d""" % (
+ self.name, self.cache_name, self.evicted_size
+ ),
]
@@ -193,3 +299,29 @@ class MemoryUsageMetric(object):
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]
+
+
+def _escape_character(m):
+ """Replaces a single character with its escape sequence.
+
+ Args:
+ m (re.MatchObject): A match object whose first group is the single
+ character to replace
+
+ Returns:
+ str
+ """
+ c = m.group(1)
+ if c == "\\":
+ return "\\\\"
+ elif c == "\"":
+ return "\\\""
+ elif c == "\n":
+ return "\\n"
+ return c
+
+
+def _escape_label_value(value):
+ """Takes a label value and escapes quotes, newlines and backslashes
+ """
+ return re.sub(r"([\n\"\\])", _escape_character, value)
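
To see how alternative_names and the new label-value escaping show up in the exported text format, a small invented example (this assumes inc_by keeps recording per-label totals as before):

    from synapse.metrics.metric import CounterMetric

    counter = CounterMetric(
        "response_count",
        labels=["servlet"],
        alternative_names=["response_time:count"],  # deprecated alias
    )
    counter.inc_by(2, 'Sync"Servlet')  # the quote gets escaped on output

    print("\n".join(counter.render()))
    # response_count{servlet="Sync\"Servlet"} 2
    # response_time:count{servlet="Sync\"Servlet"} 2
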
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
new file mode 100644
index 0000000000..097c844d31
--- /dev/null
+++ b/synapse/module_api/__init__.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+
+class ModuleApi(object):
+ """A proxy object that gets passed to password auth providers so they
+ can register new users etc if necessary.
+ """
+ def __init__(self, hs, auth_handler):
+ self.hs = hs
+
+ self._store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = auth_handler
+
+ def get_user_by_req(self, req, allow_guest=False):
+ """Check the access_token provided for a request
+
+ Args:
+ req (twisted.web.server.Request): Incoming HTTP request
+ allow_guest (bool): True if guest users should be allowed. If this
+ is False, and the access token is for a guest user, an
+ AuthError will be thrown
+ Returns:
+ twisted.internet.defer.Deferred[synapse.types.Requester]:
+ the requester for this request
+ Raises:
+ synapse.api.errors.AuthError: if no user by that token exists,
+ or the token is invalid.
+ """
+ return self._auth.get_user_by_req(req, allow_guest)
+
+ def get_qualified_user_id(self, username):
+ """Qualify a user id, if necessary
+
+ Takes a user id provided by the user and adds the @ and :domain to
+ qualify it, if necessary
+
+ Args:
+ username (str): provided user id
+
+ Returns:
+ str: qualified @user:id
+ """
+ if username.startswith('@'):
+ return username
+ return UserID(username, self.hs.hostname).to_string()
+
+ def check_user_exists(self, user_id):
+ """Check if user exists.
+
+ Args:
+ user_id (str): Complete @user:id
+
+ Returns:
+ Deferred[str|None]: Canonical (case-corrected) user_id, or None
+ if the user is not registered.
+ """
+ return self._auth_handler.check_user_exists(user_id)
+
+ def register(self, localpart):
+ """Registers a new user with given localpart
+
+ Returns:
+ Deferred: a 2-tuple of (user_id, access_token)
+ """
+ reg = self.hs.get_handlers().registration_handler
+ return reg.register(localpart=localpart)
+
+ @defer.inlineCallbacks
+ def invalidate_access_token(self, access_token):
+ """Invalidate an access token for a user
+
+ Args:
+ access_token(str): access token
+
+ Returns:
+ twisted.internet.defer.Deferred - resolves once the access token
+ has been removed.
+
+ Raises:
+ synapse.api.errors.AuthError: the access token is invalid
+ """
+ # see if the access token corresponds to a device
+ user_info = yield self._auth.get_user_by_access_token(access_token)
+ device_id = user_info.get("device_id")
+ user_id = user_info["user"].to_string()
+ if device_id:
+ # delete the device, which will also delete its access tokens
+ yield self.hs.get_device_handler().delete_device(user_id, device_id)
+ else:
+ # no associated device. Just delete the access token.
+ yield self._auth_handler.delete_access_token(access_token)
+
+ def run_db_interaction(self, desc, func, *args, **kwargs):
+ """Run a function with a database connection
+
+ Args:
+ desc (str): description for the transaction, for metrics etc
+ func (func): function to be run. Passed a database cursor object
+ as well as *args and **kwargs
+ *args: positional args to be passed to func
+ **kwargs: named args to be passed to func
+
+ Returns:
+ Deferred[object]: result of func
+ """
+ return self._store.runInteraction(desc, func, *args, **kwargs)
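
For context, a hypothetical password auth provider consuming this proxy could look like the following; the provider class, its check_password hook wiring, and the _check_with_backend helper are illustrative assumptions rather than part of this change:

    from twisted.internet import defer

    class ExamplePasswordProvider(object):
        def __init__(self, config, account_handler):
            # account_handler is the ModuleApi instance synapse passes in
            self.account_handler = account_handler

        @defer.inlineCallbacks
        def check_password(self, user_id, password):
            user_id = self.account_handler.get_qualified_user_id(user_id)
            if not self._check_with_backend(user_id, password):  # assumed helper
                defer.returnValue(False)

            canonical_id = yield self.account_handler.check_user_exists(user_id)
            if canonical_id is None:
                # first successful login: register the account
                localpart = user_id.split(":", 1)[0][1:]
                canonical_id, _token = yield self.account_handler.register(
                    localpart=localpart,
                )
            defer.returnValue(bool(canonical_id))
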
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 8051a7a842..8355c7d621 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,13 +14,17 @@
# limitations under the License.
from twisted.internet import defer
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
+from synapse.handlers.presence import format_user_presence_state
-from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.async import (
+ ObservableDeferred, add_timeout_to_deferred,
+ DeferredTimeoutError,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
from synapse.util.metrics import Measure
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
@@ -37,6 +41,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
notified_events_counter = metrics.register_counter("notified_events")
+users_woken_by_stream_counter = metrics.register_counter(
+ "users_woken_by_stream", labels=["stream"]
+)
+
# TODO(paul): Should be shared somewhere
def count(func, l):
@@ -73,6 +81,13 @@ class _NotifierUserStream(object):
self.user_id = user_id
self.rooms = set(rooms)
self.current_token = current_token
+
+ # The last token for which we should wake up any streams that have a
+ # token that comes before it. This gets updated every time we get poked.
+ # We start it at the current token since if we get any streams
+ # that have a token from before, we have no idea whether they should be
+ # woken up or not, so let's just wake them up.
+ self.last_notified_token = current_token
self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
@@ -89,9 +104,12 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
+ self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
+ users_woken_by_stream_counter.inc(stream_key)
+
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
@@ -113,8 +131,14 @@ class _NotifierUserStream(object):
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
+
+ Args:
+ token: The token from which we are streaming, i.e. we shouldn't
+ notify for things that happened before this.
"""
- if self.current_token.is_after(token):
+ # Immediately wake up the stream if something has already happened
+ # since their last token.
+ if self.last_notified_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
@@ -123,6 +147,7 @@ class _NotifierUserStream(object):
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class Notifier(object):
@@ -142,6 +167,8 @@ class Notifier(object):
self.store = hs.get_datastore()
self.pending_new_room_events = []
+ self.replication_callbacks = []
+
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
@@ -181,7 +208,12 @@ class Notifier(object):
lambda: len(self.user_to_user_stream),
)
- @preserve_fn
+ def add_replication_callback(self, cb):
+ """Add a callback that will be called when some new data is available.
+ Callback is not given any arguments.
+ """
+ self.replication_callbacks.append(cb)
+
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -195,15 +227,13 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- with PreserveLoggingContext():
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
- self._notify_pending_new_room_events(max_room_stream_id)
+ self.pending_new_room_events.append((
+ room_stream_id, event, extra_users
+ ))
+ self._notify_pending_new_room_events(max_room_stream_id)
- self.notify_replication()
+ self.notify_replication()
- @preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
@@ -221,11 +251,10 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
- @preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- self.appservice_handler.notify_interested_services(room_stream_id)
+ run_in_background(self._notify_app_services, room_stream_id)
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
@@ -239,7 +268,13 @@ class Notifier(object):
rooms=[event.room_id],
)
- @preserve_fn
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
+
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
@@ -261,17 +296,15 @@ class Notifier(object):
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
- except:
+ except Exception:
logger.exception("Failed to notify listener")
self.notify_replication()
- @preserve_fn
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
- with PreserveLoggingContext():
- self.notify_replication()
+ self.notify_replication()
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -283,8 +316,7 @@ class Notifier(object):
if user_stream is None:
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = [room.room_id for room in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
@@ -294,40 +326,45 @@ class Notifier(object):
self._register_with_keys(user_stream)
result = None
+ prev_token = from_token
if timeout:
end_time = self.clock.time_msec() + timeout
- prev_token = from_token
while not result:
try:
- current_token = user_stream.current_token
-
- result = yield callback(prev_token, current_token)
- if result:
- break
-
now = self.clock.time_msec()
if end_time <= now:
break
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
- # We need to supply the token we supplied to callback so
- # that we don't miss any current_token updates.
- prev_token = current_token
listener = user_stream.new_listener(prev_token)
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ )
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
- except DeferredTimedOutError:
+ yield listener.deferred
+
+ current_token = user_stream.current_token
+
+ result = yield callback(prev_token, current_token)
+ if result:
+ break
+
+ # Update the prev_token to the current_token since nothing
+ # has happened between the old prev_token and the current_token
+ prev_token = current_token
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
- else:
+
+ if result is None:
+ # This happens if there was no timeout or if the timeout had
+ # already expired.
current_token = user_stream.current_token
- result = yield callback(from_token, current_token)
+ result = yield callback(prev_token, current_token)
defer.returnValue(result)
@@ -388,6 +425,15 @@ class Notifier(object):
new_events,
is_peeking=is_peeking,
)
+ elif name == "presence":
+ now = self.clock.time_msec()
+ new_events[:] = [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(event, now),
+ }
+ for event in new_events
+ ]
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
@@ -420,8 +466,7 @@ class Notifier(object):
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
- joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
- joined_room_ids = map(lambda r: r.room_id, joined_rooms)
+ joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
defer.returnValue(([explicit_room_id], True))
@@ -478,6 +523,15 @@ class Notifier(object):
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
+ # the callbacks may well outlast the current request, so we run
+ # them in the sentinel logcontext.
+ #
+ # (ideally it would be up to the callbacks to know if they were
+ # starting off background processes and drop the logcontext
+ # accordingly, but that requires more changes)
+ for cb in self.replication_callbacks:
+ cb()
+
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
@@ -506,13 +560,14 @@ class Notifier(object):
if end_time <= now:
break
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ )
try:
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
- except DeferredTimedOutError:
+ yield listener.deferred
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
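
A minimal sketch of a consumer of the new add_replication_callback hook (the streamer object and its method name are assumptions):

    def wire_up_replication(hs, streamer):
        notifier = hs.get_notifier()
        # The callback takes no arguments; it is invoked from the sentinel
        # logcontext whenever notify_replication() fires.
        notifier.add_replication_callback(streamer.on_notifier_poke)
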
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 3f75d3f921..8f619a7a1b 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from .bulk_push_rule_evaluator import evaluator_for_event
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
@@ -24,11 +24,12 @@ import logging
logger = logging.getLogger(__name__)
-class ActionGenerator:
+class ActionGenerator(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user.
@@ -38,16 +39,7 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
- with Measure(self.clock, "evaluator_for_event"):
- bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context
- )
-
with Measure(self.clock, "action_for_event_by_user"):
- actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+ yield self.bulk_evaluator.action_for_event_by_user(
event, context
)
-
- context.push_actions = [
- (uid, actions) for uid, actions in actions_by_user.items()
- ]
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 85effdfa46..7a18afe5f9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
}
]
},
+ {
+ 'rule_id': 'global/override/.m.rule.roomnotif',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern': '@room',
+ '_id': '_roomnotif_content',
+ },
+ {
+ 'kind': 'sender_notification_permission',
+ 'key': 'room',
+ '_id': '_roomnotif_pl',
+ },
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': True,
+ }
+ ]
+ }
]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 78b095c903..7c680659b6 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,88 +20,166 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes
-from synapse.visibility import filter_events_for_clients_context
+from synapse.event_auth import get_user_power_level
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import get_metrics_for
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.caches.descriptors import cached
+from synapse.util.async import Linearizer
+from synapse.state import POWER_KEY
+
+from collections import namedtuple
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
- rules_by_user = yield store.bulk_get_push_rules_for_room(
- event, context
- )
-
- # if this event is an invite event, we may need to run rules for the user
- # who's been invited, otherwise they won't get told they've been invited
- if event.type == 'm.room.member' and event.content['membership'] == 'invite':
- invited_user = event.state_key
- if invited_user and hs.is_mine_id(invited_user):
- has_pusher = yield store.user_has_pusher(invited_user)
- if has_pusher:
- rules_by_user = dict(rules_by_user)
- rules_by_user[invited_user] = yield store.get_push_rules_for_user(
- invited_user
- )
+rules_by_room = {}
- defer.returnValue(BulkPushRuleEvaluator(
- event.room_id, rules_by_user, store
- ))
+push_metrics = get_metrics_for(__name__)
+push_rules_invalidation_counter = push_metrics.register_counter(
+ "push_rules_invalidation_counter"
+)
+push_rules_state_size_counter = push_metrics.register_counter(
+ "push_rules_state_size_counter"
+)
-class BulkPushRuleEvaluator:
- """
- Runs push rules for all users in a room.
- This is faster than running PushRuleEvaluator for each user because it
- fetches all the rules for all the users in one (batched) db query
- rather than doing multiple queries per-user. It currently uses
- the same logic to run the actual rules, but could be optimised further
- (see https://matrix.org/jira/browse/SYN-562)
+# Measures whether we use the fast path of using state deltas, or if we have to
+# recalculate from scratch
+push_rules_delta_state_cache_metric = cache_metrics.register_cache(
+ "cache",
+ size_callback=lambda: 0, # Meaningless size, as this isn't a cache that stores values
+ cache_name="push_rules_delta_state_cache_metric",
+)
+
+
+class BulkPushRuleEvaluator(object):
+ """Calculates the outcome of push rules for an event for all users in the
+ room at once.
"""
- def __init__(self, room_id, rules_by_user, store):
- self.room_id = room_id
- self.rules_by_user = rules_by_user
- self.store = store
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ self.room_push_rule_cache_metrics = cache_metrics.register_cache(
+ "cache",
+ size_callback=lambda: 0, # There's no good value for this
+ cache_name="room_push_rule_cache",
+ )
@defer.inlineCallbacks
- def action_for_event_by_user(self, event, context):
- actions_by_user = {}
+ def _get_rules_for_event(self, event, context):
+ """This gets the rules for all users in the room at the time of the event,
+ as well as the push rules for the invitee if the event is an invite.
+
+ Returns:
+ dict of user_id -> push_rules
+ """
+ room_id = event.room_id
+ rules_for_room = self._get_rules_for_room(room_id)
+
+ rules_by_user = yield rules_for_room.get_rules(event, context)
+
+ # if this event is an invite event, we may need to run rules for the user
+ # who's been invited, otherwise they won't get told they've been invited
+ if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+ invited = event.state_key
+ if invited and self.hs.is_mine_id(invited):
+ has_pusher = yield self.store.user_has_pusher(invited)
+ if has_pusher:
+ rules_by_user = dict(rules_by_user)
+ rules_by_user[invited] = yield self.store.get_push_rules_for_user(
+ invited
+ )
- # None of these users can be peeking since this list of users comes
- # from the set of users in the room, so we know for sure they're all
- # actually in the room.
- user_tuples = [
- (u, False) for u in self.rules_by_user.keys()
- ]
+ defer.returnValue(rules_by_user)
- filtered_by_user = yield filter_events_for_clients_context(
- self.store, user_tuples, [event], {event.event_id: context}
+ @cached()
+ def _get_rules_for_room(self, room_id):
+ """Get the current RulesForRoom object for the given room id
+
+ Returns:
+ RulesForRoom
+ """
+ # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+ # before any lookup methods get called on it as otherwise there may be
+ # a race if invalidate_all gets called (which assumes it's in the cache)
+ return RulesForRoom(
+ self.hs, room_id, self._get_rules_for_room.cache,
+ self.room_push_rule_cache_metrics,
)
+ @defer.inlineCallbacks
+ def _get_power_levels_and_sender_level(self, event, context):
+ pl_event_id = context.prev_state_ids.get(POWER_KEY)
+ if pl_event_id:
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
+ pl_event = yield self.store.get_event(pl_event_id)
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.prev_state_ids, for_verification=False,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.itervalues()
+ }
+
+ sender_level = get_user_power_level(event.sender, auth_events)
+
+ pl_event = auth_events.get(POWER_KEY)
+
+ defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+
+ @defer.inlineCallbacks
+ def action_for_event_by_user(self, event, context):
+ """Given an event and context, evaluate the push rules and insert the
+ results into the event_push_actions_staging table.
+
+ Returns:
+ Deferred
+ """
+ rules_by_user = yield self._get_rules_for_event(event, context)
+ actions_by_user = {}
+
room_members = yield self.store.get_joined_users_from_context(
event, context
)
- evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
+ (power_levels, sender_power_level) = (
+ yield self._get_power_levels_and_sender_level(event, context)
+ )
+
+ evaluator = PushRuleEvaluatorForEvent(
+ event, len(room_members), sender_power_level, power_levels,
+ )
condition_cache = {}
- for uid, rules in self.rules_by_user.items():
- display_name = room_members.get(uid, {}).get("display_name", None)
+ for uid, rules in rules_by_user.iteritems():
+ if event.sender == uid:
+ continue
+
+ if not event.is_state():
+ is_ignored = yield self.store.is_ignored_by(event.sender, uid)
+ if is_ignored:
+ continue
+
+ display_name = None
+ profile_info = room_members.get(uid)
+ if profile_info:
+ display_name = profile_info.display_name
+
if not display_name:
# Handle the case where we are pushing a membership event to
# that user, as they might not be already joined.
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
- filtered = filtered_by_user[uid]
- if len(filtered) == 0:
- continue
-
- if filtered[0].sender == uid:
- continue
-
for rule in rules:
if 'enabled' in rule and not rule['enabled']:
continue
@@ -111,9 +190,16 @@ class BulkPushRuleEvaluator:
if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify']
if actions and 'notify' in actions:
+ # Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
- defer.returnValue(actions_by_user)
+
+ # Mark in the DB staging area the push actions for users who should be
+ # notified for this event. (This will then get handled when we persist
+ # the event)
+ yield self.store.add_push_actions_to_staging(
+ event.event_id, actions_by_user,
+ )
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -134,3 +220,264 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
return False
return True
+
+
+class RulesForRoom(object):
+ """Caches push rules for users in a room.
+
+ This efficiently handles users joining/leaving the room by not invalidating
+ the entire cache for the room.
+ """
+
+ def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+ """
+ Args:
+ hs (HomeServer)
+ room_id (str)
+ rules_for_room_cache(Cache): The cache object that caches these
+ RulesForRoom objects.
+ room_push_rule_cache_metrics (CacheMetric)
+ """
+ self.room_id = room_id
+ self.is_mine_id = hs.is_mine_id
+ self.store = hs.get_datastore()
+ self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
+
+ self.linearizer = Linearizer(name="rules_for_room")
+
+ self.member_map = {} # event_id -> (user_id, state)
+ self.rules_by_user = {} # user_id -> rules
+
+ # The last state group we updated the caches for. If the state_group of
+ # a new event comes along, we know that we can just return the cached
+ # result.
+ # On invalidation of the rules themselves (if the user changes them),
+ # we invalidate everything and set state_group to `object()`
+ self.state_group = object()
+
+ # A sequence number to keep track of when we're allowed to update the
+ # cache. We bump the sequence number when we invalidate the cache. If
+ # the sequence number changes while we're calculating stuff we should
+ # not update the cache with it.
+ self.sequence = 0
+
+ # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+ # owned by AS's, or remote users, etc. (I.e. users we will never need to
+ # calculate push for)
+ # These never need to be invalidated as we will never set up push for
+ # them.
+ self.uninteresting_user_set = set()
+
+ # We need to be careful with the cache invalidation callbacks, as
+ # otherwise the invalidation callback holds a reference to the object,
+ # potentially causing it to leak.
+ # To get around this we pass a function that, on invalidation, looks up
+ # the RulesForRoom entry in the cache, rather than keeping a reference
+ # to self around in the callback.
+ self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
+
+ @defer.inlineCallbacks
+ def get_rules(self, event, context):
+ """Given an event context return the rules for all users who are
+ currently in the room.
+ """
+ state_group = context.state_group
+
+ if state_group and self.state_group == state_group:
+ logger.debug("Using cached rules for %r", self.room_id)
+ self.room_push_rule_cache_metrics.inc_hits()
+ defer.returnValue(self.rules_by_user)
+
+ with (yield self.linearizer.queue(())):
+ if state_group and self.state_group == state_group:
+ logger.debug("Using cached rules for %r", self.room_id)
+ self.room_push_rule_cache_metrics.inc_hits()
+ defer.returnValue(self.rules_by_user)
+
+ self.room_push_rule_cache_metrics.inc_misses()
+
+ ret_rules_by_user = {}
+ missing_member_event_ids = {}
+ if state_group and self.state_group == context.prev_group:
+ # If we have a simple delta then we can reuse most of the previous
+ # results.
+ ret_rules_by_user = self.rules_by_user
+ current_state_ids = context.delta_ids
+
+ push_rules_delta_state_cache_metric.inc_hits()
+ else:
+ current_state_ids = context.current_state_ids
+ push_rules_delta_state_cache_metric.inc_misses()
+
+ push_rules_state_size_counter.inc_by(len(current_state_ids))
+
+ logger.debug(
+ "Looking for member changes in %r %r", state_group, current_state_ids
+ )
+
+ # Loop through to see which member events we've seen and have rules
+ # for and which we need to fetch
+ for key in current_state_ids:
+ typ, user_id = key
+ if typ != EventTypes.Member:
+ continue
+
+ if user_id in self.uninteresting_user_set:
+ continue
+
+ if not self.is_mine_id(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ if self.store.get_if_app_services_interested_in_user(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ event_id = current_state_ids[key]
+
+ res = self.member_map.get(event_id, None)
+ if res:
+ user_id, state = res
+ if state == Membership.JOIN:
+ rules = self.rules_by_user.get(user_id, None)
+ if rules:
+ ret_rules_by_user[user_id] = rules
+ continue
+
+ # If a user has left a room we remove their push rule. If they
+ # joined then we re-add it later in _update_rules_with_member_event_ids
+ ret_rules_by_user.pop(user_id, None)
+ missing_member_event_ids[user_id] = event_id
+
+ if missing_member_event_ids:
+ # If we have some member events we haven't seen, look them up
+ # and fetch push rules for them if appropriate.
+ logger.debug("Found new member events %r", missing_member_event_ids)
+ yield self._update_rules_with_member_event_ids(
+ ret_rules_by_user, missing_member_event_ids, state_group, event
+ )
+ else:
+ # The push rules didn't change but let's update the cache anyway
+ self.update_cache(
+ self.sequence,
+ members={}, # There were no membership changes
+ rules_by_user=ret_rules_by_user,
+ state_group=state_group
+ )
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "Returning push rules for %r %r",
+ self.room_id, ret_rules_by_user.keys(),
+ )
+ defer.returnValue(ret_rules_by_user)
+
+ @defer.inlineCallbacks
+ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
+ state_group, event):
+ """Update the partially filled rules_by_user dict by fetching rules for
+ any newly joined users in the `member_event_ids` list.
+
+ Args:
+ ret_rules_by_user (dict): Partially filled dict of push rules. Gets
+ updated with any new rules.
+ member_event_ids (dict): Dict of user id to event id for membership events
+ that have happened since the last time we filled rules_by_user
+ state_group: The state group we are currently computing push rules
+ for. Used when updating the cache.
+ """
+ sequence = self.sequence
+
+ rows = yield self.store._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids.values(),
+ retcols=('user_id', 'membership', 'event_id'),
+ keyvalues={},
+ batch_size=500,
+ desc="_get_rules_for_member_event_ids",
+ )
+
+ members = {
+ row["event_id"]: (row["user_id"], row["membership"])
+ for row in rows
+ }
+
+ # If the event is a join event then it will be in the current state events
+ # map but not in the DB, so we have to explicitly insert it.
+ if event.type == EventTypes.Member:
+ for event_id in member_event_ids.itervalues():
+ if event_id == event.event_id:
+ members[event_id] = (event.state_key, event.membership)
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Found members %r: %r", self.room_id, members.values())
+
+ interested_in_user_ids = set(
+ user_id for user_id, membership in members.itervalues()
+ if membership == Membership.JOIN
+ )
+
+ logger.debug("Joined: %r", interested_in_user_ids)
+
+ if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+ interested_in_user_ids,
+ on_invalidate=self.invalidate_all_cb,
+ )
+
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ )
+
+ logger.debug("With pushers: %r", user_ids)
+
+ users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+ self.room_id, on_invalidate=self.invalidate_all_cb,
+ )
+
+ logger.debug("With receipts: %r", users_with_receipts)
+
+ # any users with read receipts need push rules run too: they might be notified
+ for uid in users_with_receipts:
+ if uid in interested_in_user_ids:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.store.bulk_get_push_rules(
+ user_ids, on_invalidate=self.invalidate_all_cb,
+ )
+
+ ret_rules_by_user.update(
+ item for item in rules_by_user.iteritems() if item[0] is not None
+ )
+
+ self.update_cache(sequence, members, ret_rules_by_user, state_group)
+
+ def invalidate_all(self):
+ # Note: Don't hand this function directly to an invalidation callback
+ # as it keeps a reference to self and will stop this instance from being
+ # GC'd if it gets dropped from the rules_for_room cache. Instead use
+ # `self.invalidate_all_cb`
+ logger.debug("Invalidating RulesForRoom for %r", self.room_id)
+ self.sequence += 1
+ self.state_group = object()
+ self.member_map = {}
+ self.rules_by_user = {}
+ push_rules_invalidation_counter.inc()
+
+ def update_cache(self, sequence, members, rules_by_user, state_group):
+ if sequence == self.sequence:
+ self.member_map.update(members)
+ self.rules_by_user = rules_by_user
+ self.state_group = state_group
+
+
+class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
+ # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
+ # which namedtuple does for us (i.e. two _CacheContext are the same if
+ # their caches and keys match). This is important in particular to
+ # dedupe when we add callbacks to lru cache nodes, otherwise the number
+ # of callbacks would grow.
+ def __call__(self):
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
+ if rules:
+ rules.invalidate_all()
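
The sequence counter RulesForRoom carries around is a general defence for caches that are filled asynchronously: capture the counter before the slow lookup, bump it on every invalidation, and only write the result back if the counter still matches. A stripped-down sketch of the same idea (SequencedCache is a hypothetical class, not part of synapse):

    class SequencedCache(object):
        """Only accepts results computed against the current generation."""

        def __init__(self):
            self.sequence = 0
            self.value = None

        def invalidate(self):
            # Anything computed before this point is now stale.
            self.sequence += 1
            self.value = None

        def update(self, sequence, value):
            # Drop the result if an invalidation happened while it was
            # being computed.
            if sequence == self.sequence:
                self.value = value


    cache = SequencedCache()
    seq = cache.sequence        # captured before the slow computation starts
    cache.invalidate()          # e.g. the user edits their push rules meanwhile
    cache.update(seq, "stale result")
    assert cache.value is None  # the stale result was discarded
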
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 2eb325c7c7..ba7286cb72 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext
-from mailer import Mailer
logger = logging.getLogger(__name__)
@@ -56,8 +55,10 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
- def __init__(self, hs, pusherdict):
+ def __init__(self, hs, pusherdict, mailer):
self.hs = hs
+ self.mailer = mailer
+
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id']
@@ -73,23 +74,16 @@ class EmailPusher(object):
self.processing = False
- if self.hs.config.email_enable_notifs:
- if 'data' in pusherdict and 'brand' in pusherdict['data']:
- app_name = pusherdict['data']['brand']
- else:
- app_name = self.hs.config.email_app_name
-
- self.mailer = Mailer(self.hs, app_name)
- else:
- self.mailer = None
-
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
- self.throttle_params = yield self.store.get_throttle_params_by_room(
- self.pusher_id
- )
- yield self._process()
+ try:
+ self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.pusher_id
+ )
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting email pusher")
def on_stop(self):
if self.timed_call:
@@ -130,7 +124,7 @@ class EmailPusher(object):
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
- except:
+ except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
@@ -218,7 +212,8 @@ class EmailPusher(object):
)
def seconds_until(self, ts_msec):
- return (ts_msec - self.clock.time_msec()) / 1000
+ secs = (ts_msec - self.clock.time_msec()) / 1000
+ return max(secs, 0)
def get_room_throttle_ms(self, room_id):
if room_id in self.throttle_params:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c0f8176e3d..b077e1a446 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,21 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from synapse.push import PusherConfigException
+import logging
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-import logging
-import push_rule_evaluator
-import push_tools
-
+from . import push_rule_evaluator
+from . import push_tools
+import synapse
+from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+ "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+ "http_pushes_failed",
+)
+
class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@@ -84,7 +94,10 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_started(self):
- yield self._process()
+ try:
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting http pusher")
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -131,7 +144,7 @@ class HttpPusher(object):
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
- except:
+ except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
@@ -151,9 +164,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
+ logger.info(
+ "Processing %i unprocessed push actions for %s starting at "
+ "stream_ordering %s",
+ len(unprocessed), self.name, self.last_stream_ordering,
+ )
+
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
+ http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -168,6 +188,7 @@ class HttpPusher(object):
self.failing_since
)
else:
+ http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
@@ -244,6 +265,26 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
+ if self.data.get('format') == 'event_id_only':
+ d = {
+ 'notification': {
+ 'event_id': event.event_id,
+ 'room_id': event.room_id,
+ 'counts': {
+ 'unread': badge,
+ },
+ 'devices': [
+ {
+ 'app_id': self.app_id,
+ 'pushkey': self.pushkey,
+ 'pushkey_ts': long(self.pushkey_ts / 1000),
+ 'data': self.data_minus_url,
+ }
+ ]
+ }
+ }
+ defer.returnValue(d)
+
ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id
)
@@ -275,7 +316,7 @@ class HttpPusher(object):
if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id
- if 'content' in event:
+ if self.hs.config.push_include_content and 'content' in event:
d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human
@@ -294,8 +335,11 @@ class HttpPusher(object):
defer.returnValue([])
try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
- except:
- logger.warn("Failed to push %s ", self.url)
+ except Exception:
+ logger.warn(
+ "Failed to push event %s to %s",
+ event.event_id, self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
@@ -304,7 +348,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _send_badge(self, badge):
- logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+ logger.info("Sending updated badge count %d to %s", badge, self.name)
d = {
'notification': {
'id': '',
@@ -325,8 +369,11 @@ class HttpPusher(object):
}
try:
resp = yield self.http_client.post_json_get_json(self.url, d)
- except:
- logger.exception("Failed to push %s ", self.url)
+ except Exception:
+ logger.warn(
+ "Failed to send badge count to %s",
+ self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
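
For reference, with the 'event_id_only' format the gateway gets no event content at all, just enough for the client to wake up and fetch the event itself. A representative payload built by the branch above (all identifiers are made up):

    notification = {
        'notification': {
            'event_id': '$1510937658297abcde:example.org',
            'room_id': '!qporfwt:example.org',
            'counts': {
                'unread': 2,
            },
            'devices': [
                {
                    'app_id': 'org.example.app.ios',
                    'pushkey': 'APA91bHPRgkF3JUikC4ENAHEeMrd41Zxv3hVZjC9KtT8',
                    'pushkey_ts': 1510937658,
                    # data_minus_url: the pusher's data dict with the
                    # gateway 'url' key stripped out
                    'data': {'format': 'event_id_only'},
                },
            ],
        },
    }
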
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 62d794f22b..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
class Mailer(object):
- def __init__(self, hs, app_name):
+ def __init__(self, hs, app_name, notif_template_html, notif_template_text):
self.hs = hs
+ self.notif_template_html = notif_template_html
+ self.notif_template_text = notif_template_text
+
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
- loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
+
logger.info("Created Mailer for app_name %s" % app_name)
- env = jinja2.Environment(loader=loader)
- env.filters["format_ts"] = format_ts_filter
- env.filters["mxc_to_http"] = self.mxc_to_http_filter
- self.notif_template_html = env.get_template(
- self.hs.config.email_notif_template_html
- )
- self.notif_template_text = env.get_template(
- self.hs.config.email_notif_template_text
- )
@defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address,
@@ -139,7 +133,7 @@ class Mailer(object):
@defer.inlineCallbacks
def _fetch_room_state(room_id):
- room_state = yield self.state_handler.get_current_state_ids(room_id)
+ room_state = yield self.store.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
@@ -200,7 +194,11 @@ class Mailer(object):
yield sendmail(
self.hs.config.email_smtp_host,
raw_from, raw_to, multipart_msg.as_string(),
- port=self.hs.config.email_smtp_port
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security
)
@defer.inlineCallbacks
@@ -477,28 +475,6 @@ class Mailer(object):
urllib.urlencode(params),
)
- def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- serverAndMediaId = value[6:]
- fragment = None
- if '#' in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
- fragment = "#" + fragment
-
- params = {
- "width": width,
- "height": height,
- "method": resize_method,
- }
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- self.hs.config.public_baseurl,
- serverAndMediaId,
- urllib.urlencode(params),
- fragment or "",
- )
-
def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -539,3 +515,52 @@ def string_ordinal_total(s):
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+ """Load the jinja2 email templates from disk
+
+ Returns:
+ (notif_template_html, notif_template_text)
+ """
+ logger.info("loading jinja2")
+
+ loader = jinja2.FileSystemLoader(config.email_template_dir)
+ env = jinja2.Environment(loader=loader)
+ env.filters["format_ts"] = format_ts_filter
+ env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+ notif_template_html = env.get_template(
+ config.email_notif_template_html
+ )
+ notif_template_text = env.get_template(
+ config.email_notif_template_text
+ )
+
+ return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+ def mxc_to_http_filter(value, width, height, resize_method="crop"):
+ if value[0:6] != "mxc://":
+ return ""
+
+ serverAndMediaId = value[6:]
+ fragment = None
+ if '#' in serverAndMediaId:
+ (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+ fragment = "#" + fragment
+
+ params = {
+ "width": width,
+ "height": height,
+ "method": resize_method,
+ }
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ config.public_baseurl,
+ serverAndMediaId,
+ urllib.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
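
As a concrete example of what the templates get out of this filter, assuming a config whose public_baseurl is "https://matrix.example.org/" (the mxc URI is made up, and query-parameter order depends on urlencode's dict iteration):

    mxc_to_http = _create_mxc_to_http_filter(config)

    mxc_to_http("mxc://example.org/SEsfnsuifSDFSSEF", 32, 32)
    # => "https://matrix.example.org/_matrix/media/v1/thumbnail/"
    #    "example.org/SEsfnsuifSDFSSEF?width=32&height=32&method=crop"

    mxc_to_http("not-an-mxc-uri", 32, 32)
    # => "" (anything that doesn't start with mxc:// is dropped)
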
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18bd..3601f2d365 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +18,7 @@ import logging
import re
from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -28,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count):
+ return _test_ineq_condition(condition, room_member_count)
+
+
+def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+ notif_level_key = condition.get('key')
+ if notif_level_key is None:
+ return False
+
+ notif_levels = power_levels.get('notifications', {})
+ room_notif_level = notif_levels.get(notif_level_key, 50)
+
+ return sender_power_level >= room_notif_level
+
+
+def _test_ineq_condition(condition, number):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
@@ -40,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
rhs = int(rhs)
if ineq == '' or ineq == '==':
- return room_member_count == rhs
+ return number == rhs
elif ineq == '<':
- return room_member_count < rhs
+ return number < rhs
elif ineq == '>':
- return room_member_count > rhs
+ return number > rhs
elif ineq == '>=':
- return room_member_count >= rhs
+ return number >= rhs
elif ineq == '<=':
- return room_member_count <= rhs
+ return number <= rhs
else:
return False
@@ -64,9 +81,11 @@ def tweaks_for_actions(actions):
class PushRuleEvaluatorForEvent(object):
- def __init__(self, event, room_member_count):
+ def __init__(self, event, room_member_count, sender_power_level, power_levels):
self._event = event
self._room_member_count = room_member_count
+ self._sender_power_level = sender_power_level
+ self._power_levels = power_levels
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
@@ -80,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
return _room_member_count(
self._event, condition, self._room_member_count
)
+ elif condition['kind'] == 'sender_notification_permission':
+ return _sender_notification_permission(
+ self._event, condition, self._sender_power_level, self._power_levels,
+ )
else:
return True
@@ -125,6 +148,11 @@ class PushRuleEvaluatorForEvent(object):
return self._value_cache.get(dotted_key, None)
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
@@ -137,47 +165,78 @@ def _glob_matches(glob, value, word_boundary=False):
Returns:
bool
"""
+
try:
- if IS_GLOB.search(glob):
- r = re.escape(glob)
-
- r = r.replace(r'\*', '.*?')
- r = r.replace(r'\?', '.')
-
- # handle [abc], [a-z] and [!a-z] style ranges.
- r = GLOB_REGEX.sub(
- lambda x: (
- '[%s%s]' % (
- x.group(1) and '^' or '',
- x.group(2).replace(r'\\\-', '-')
- )
- ),
- r,
- )
- if word_boundary:
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
-
- return r.search(value)
- else:
- r = r + "$"
- r = _compile_regex(r)
-
- return r.match(value)
- elif word_boundary:
- r = re.escape(glob)
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
-
- return r.search(value)
- else:
- return value.lower() == glob.lower()
+ r = regex_cache.get((glob, word_boundary), None)
+ if not r:
+ r = _glob_to_re(glob, word_boundary)
+ regex_cache[(glob, word_boundary)] = r
+ return r.search(value)
except re.error:
logger.warn("Failed to parse glob to regex: %r", glob)
return False
-def _flatten_dict(d, prefix=[], result={}):
+def _glob_to_re(glob, word_boundary):
+ """Generates regex for a given glob.
+
+ Args:
+ glob (string)
+ word_boundary (bool): Whether to match against word boundaries or entire
+ string. Defaults to False.
+
+ Returns:
+ regex object
+ """
+ if IS_GLOB.search(glob):
+ r = re.escape(glob)
+
+ r = r.replace(r'\*', '.*?')
+ r = r.replace(r'\?', '.')
+
+ # handle [abc], [a-z] and [!a-z] style ranges.
+ r = GLOB_REGEX.sub(
+ lambda x: (
+ '[%s%s]' % (
+ x.group(1) and '^' or '',
+ x.group(2).replace(r'\\\-', '-')
+ )
+ ),
+ r,
+ )
+ if word_boundary:
+ r = _re_word_boundary(r)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + r + "$"
+
+ return re.compile(r, flags=re.IGNORECASE)
+ elif word_boundary:
+ r = re.escape(glob)
+ r = _re_word_boundary(r)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + re.escape(glob) + "$"
+ return re.compile(r, flags=re.IGNORECASE)
+
+
+def _re_word_boundary(r):
+ """
+ Adds word boundary characters to the start and end of an
+ expression to require that the match occur as a whole word,
+ but do so respecting the fact that strings starting or ending
+ with non-word characters will change word boundaries.
+ """
+ # we can't use \b as it chokes on unicode. however \W seems to be okay
+ # as shorthand for [^0-9A-Za-z_].
+ return r"(^|\W)%s(\W|$)" % (r,)
+
+
+def _flatten_dict(d, prefix=[], result=None):
+ if result is None:
+ result = {}
for key, value in d.items():
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
@@ -185,16 +244,3 @@ def _flatten_dict(d, prefix=[], result={}):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
- r = regex_cache.get(regex_str, None)
- if r:
- return r
-
- r = re.compile(regex_str, flags=re.IGNORECASE)
- regex_cache[regex_str] = r
- return r
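
A couple of concrete cases for the rewritten glob matching, assuming the helpers above are importable from synapse.push.push_rule_evaluator (globs and strings are made up):

    from synapse.push.push_rule_evaluator import _glob_matches, _glob_to_re

    # content.body style patterns use word-boundary matching: the glob has
    # to appear as a "word" somewhere in the value.
    r = _glob_to_re("cake*lie", word_boundary=True)
    assert r.search("the cake is a lie")
    assert not r.search("pancakes underlie nothing")

    # sender / room_id style patterns match the whole value instead.
    assert _glob_matches("@bob:*", "@bob:example.org", word_boundary=False)
    assert not _glob_matches("@bob:*", "hi from @bob:example.org", word_boundary=False)
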
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a27476bbad..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event
)
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(store.get_invited_rooms_for_user)(user_id),
- preserve_fn(store.get_rooms_for_user)(user_id),
- ], consumeErrors=True))
+ invites = yield store.get_invited_rooms_for_user(user_id)
+ joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
@@ -33,13 +30,13 @@ def get_badge_count(store, user_id):
badge = len(invites)
- for r in joins:
- if r.room_id in my_receipts_by_room:
- last_unread_event_id = my_receipts_by_room[r.room_id]
+ for room_id in joins:
+ if room_id in my_receipts_by_room:
+ last_unread_event_id = my_receipts_by_room[room_id]
notifs = yield (
store.get_unread_event_push_actions_by_room_for_user(
- r.room_id, user_id, last_unread_event_id
+ room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..5aa6667e91 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
-except:
+ from synapse.push.mailer import Mailer, load_jinja2_templates
+except Exception:
pass
-def create_pusher(hs, pusherdict):
- logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+ def __init__(self, hs):
+ self.hs = hs
- PUSHER_TYPES = {
- "http": HttpPusher,
- }
+ self.pusher_types = {
+ "http": HttpPusher,
+ }
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
- PUSHER_TYPES["email"] = EmailPusher
- logger.info("defined email pusher type")
+ logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+ if hs.config.email_enable_notifs:
+ self.mailers = {} # app_name -> Mailer
- if pusherdict['kind'] in PUSHER_TYPES:
- logger.info("found pusher")
- return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+ templates = load_jinja2_templates(hs.config)
+ self.notif_template_html, self.notif_template_text = templates
+
+ self.pusher_types["email"] = self._create_email_pusher
+
+ logger.info("defined email pusher type")
+
+ def create_pusher(self, pusherdict):
+ logger.info("trying to create_pusher for %r", pusherdict)
+
+ if pusherdict['kind'] in self.pusher_types:
+ logger.info("found pusher")
+ return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+ def _create_email_pusher(self, _hs, pusherdict):
+ app_name = self._app_name_from_pusherdict(pusherdict)
+ mailer = self.mailers.get(app_name)
+ if not mailer:
+ mailer = Mailer(
+ hs=self.hs,
+ app_name=app_name,
+ notif_template_html=self.notif_template_html,
+ notif_template_text=self.notif_template_text,
+ )
+ self.mailers[app_name] = mailer
+ return EmailPusher(self.hs, pusherdict, mailer)
+
+ def _app_name_from_pusherdict(self, pusherdict):
+ if 'data' in pusherdict and 'brand' in pusherdict['data']:
+ app_name = pusherdict['data']['brand']
+ else:
+ app_name = self.hs.config.email_app_name
+
+ return app_name
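
One effect of this refactor is that the jinja2 templates are loaded once and a single Mailer is shared by every email pusher with the same app_name, which the factory derives from the pusher's 'brand' data. A hedged sketch of that derivation (hs is assumed to be a HomeServer with email notifications enabled):

    factory = PusherFactory(hs)

    # The brand in the pusher data wins...
    assert factory._app_name_from_pusherdict(
        {"data": {"brand": "Example Chat"}}
    ) == "Example Chat"

    # ...otherwise the configured email_app_name is used, and all such
    # pushers end up sharing one Mailer keyed on that name.
    assert factory._app_name_from_pusherdict(
        {"data": {}}
    ) == hs.config.email_app_name
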
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..750d11ca38 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-import pusher
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.push.pusher import PusherFactory
from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
+ self.pusher_factory = PusherFactory(_hs)
self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
- pusher.create_pusher(self.hs, {
+ self.pusher_factory.create_pusher({
"id": None,
"user_name": user_id,
"kind": kind,
@@ -102,19 +103,25 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
- def remove_pushers_by_user(self, user_id, except_access_token_id=None):
- all = yield self.store.get_all_pushers()
- logger.info(
- "Removing all pushers for user %s except access tokens id %r",
- user_id, except_access_token_id
- )
- for p in all:
- if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+ def remove_pushers_by_access_token(self, user_id, access_tokens):
+ """Remove the pushers for a given user corresponding to a set of
+ access_tokens.
+
+ Args:
+ user_id (str): user to remove pushers for
+ access_tokens (Iterable[int]): access token *ids* to remove pushers
+ for
+ """
+ tokens = set(access_tokens)
+ for p in (yield self.store.get_pushers_by_user_id(user_id)):
+ if p['access_token'] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
- yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(
+ p['app_id'], p['pushkey'], p['user_name'],
+ )
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
@@ -130,13 +137,16 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_notifications)(
- min_stream_id, max_stream_id
+ run_in_background(
+ p.on_new_notifications,
+ min_stream_id, max_stream_id,
)
)
- yield preserve_context_over_deferred(defer.gatherResults(deferreds))
- except:
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
+ except Exception:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
@@ -157,11 +167,16 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+ run_in_background(
+ p.on_new_receipts,
+ min_stream_id, max_stream_id,
+ )
)
- yield preserve_context_over_deferred(defer.gatherResults(deferreds))
- except:
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
+ except Exception:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
@@ -186,8 +201,8 @@ class PusherPool:
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
try:
- p = pusher.create_pusher(self.hs, pusherdict)
- except:
+ p = self.pusher_factory.create_pusher(pusherdict)
+ except Exception:
logger.exception("Couldn't start a pusher: caught Exception")
continue
if p:
@@ -200,7 +215,7 @@ class PusherPool:
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
- preserve_fn(p.on_started)()
+ run_in_background(p.on_started)
logger.info("Started pushers")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 7817b0cd91..216db4d164 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,4 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,26 +19,43 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
+# this dict maps from python package name to a list of modules we expect it to
+# provide.
+#
+# the key is a "requirement specifier", as used as a parameter to `pip
+# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+#
+# the value is a sequence of strings; each entry should be the name of the
+# python module, optionally followed by a version assertion which can be either
+# ">=<ver>" or "==<ver>".
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
REQUIREMENTS = {
+ "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
- "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
+ "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
- "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
+ "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"],
- "pyopenssl>=0.14": ["OpenSSL>=0.14"],
+
+ # We use crypto.get_elliptic_curve which is only supported in >=0.15
+ "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"daemonize": ["daemonize"],
- "py-bcrypt": ["bcrypt"],
+ "bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
- "ujson": ["ujson"],
"blist": ["blist"],
- "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
+ "pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
+ "phonenumbers>=8.2.0": ["phonenumbers"],
+ "six": ["six"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
@@ -55,6 +74,9 @@ CONDITIONAL_REQUIREMENTS = {
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
+ "affinity": {
+ "affinity": ["affinity"],
+ },
}
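
As a rough illustration of how one of those entries can be verified at startup (this is a hedged sketch, not the actual check_requirements code; it only handles the ">=" form of assertion):

    import importlib
    from distutils.version import LooseVersion


    def check_module_assertion(assertion):
        """Check a single value entry such as "unpaddedbase64>=1.1.0" or "yaml"."""
        if ">=" in assertion:
            module_name, required = assertion.split(">=", 1)
        else:
            module_name, required = assertion, None

        module = importlib.import_module(module_name)  # ImportError if missing

        if required is not None:
            version = getattr(module, "__version__", None)
            if version is None or LooseVersion(version) < LooseVersion(required):
                raise Exception(
                    "Need %s>=%s, but found %r" % (module_name, required, version)
                )


    check_module_assertion("yaml")
    check_module_assertion("twisted>=16.0.0")
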
diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py
deleted file mode 100644
index c05a50d7a6..0000000000
--- a/synapse/replication/expire_cache.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-
-
-class ExpireCacheResource(Resource):
- """
- HTTP endpoint for expiring storage caches.
-
- POST /_synapse/replication/expire_cache HTTP/1.1
- Content-Type: application/json
-
- {
- "invalidate": [
- {
- "name": "func_name",
- "keys": ["key1", "key2"]
- }
- ]
- }
- """
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
- self.clock = hs.get_clock()
-
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
-
- @request_handler()
- def _async_render_POST(self, request):
- content = parse_json_object_from_request(request)
-
- for row in content["invalidate"]:
- name = row["name"]
- keys = tuple(row["keys"])
-
- getattr(self.store, name).invalidate(keys)
-
- respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
new file mode 100644
index 0000000000..1d7a607529
--- /dev/null
+++ b/synapse/replication/http/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import JsonResource
+from synapse.replication.http import membership, send_event
+
+
+REPLICATION_PREFIX = "/_synapse/replication"
+
+
+class ReplicationRestResource(JsonResource):
+ def __init__(self, hs):
+ JsonResource.__init__(self, hs, canonical_json=False)
+ self.register_servlets(hs)
+
+ def register_servlets(self, hs):
+ send_event.register_servlets(hs, self)
+ membership.register_servlets(hs, self)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
new file mode 100644
index 0000000000..e66c4e881f
--- /dev/null
+++ b/synapse/replication/http/membership.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, MatrixCodeMessageException
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import Requester, UserID
+from synapse.util.distributor import user_left_room, user_joined_room
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def remote_join(client, host, port, requester, remote_room_hosts,
+ room_id, user_id, content):
+ """Ask the master to do a remote join for the given user to the given room
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ remote_room_hosts (list[str]): Servers to try and join via
+ room_id (str)
+ user_id (str)
+ content (dict): The event content to use for the join event
+
+ Returns:
+ Deferred
+ """
+ uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "room_id": room_id,
+ "user_id": user_id,
+ "content": content,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def remote_reject_invite(client, host, port, requester, remote_room_hosts,
+ room_id, user_id):
+ """Ask master to reject the invite for the user and room.
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ remote_room_hosts (list[str]): Servers to try and reject via
+ room_id (str)
+ user_id (str)
+
+ Returns:
+ Deferred
+ """
+ uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "room_id": room_id,
+ "user_id": user_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def get_or_register_3pid_guest(client, host, port, requester,
+ medium, address, inviter_user_id):
+ """Ask the master to get/create a guest account for given 3PID.
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
+
+ uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "medium": medium,
+ "address": address,
+ "inviter_user_id": inviter_user_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def notify_user_membership_change(client, host, port, user_id, room_id, change):
+ """Notify master that a user has joined or left the room
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication.
+ user_id (str)
+ room_id (str)
+ change (str): Either "joined" or "left"
+
+ Returns:
+ Deferred
+ """
+ assert change in ("joined", "left")
+
+ uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
+
+ payload = {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+class ReplicationRemoteJoinRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
+
+ def __init__(self, hs):
+ super(ReplicationRemoteJoinRestServlet, self).__init__()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ room_id = content["room_id"]
+ user_id = content["user_id"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "remote_join: %s into room: %s",
+ user_id, room_id,
+ )
+
+ yield self.federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user_id,
+ event_content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class ReplicationRemoteRejectInviteRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
+
+ def __init__(self, hs):
+ super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ room_id = content["room_id"]
+ user_id = content["user_id"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "remote_reject_invite: %s out of room: %s",
+ user_id, room_id,
+ )
+
+ try:
+ event = yield self.federation_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ user_id,
+ )
+ ret = event.get_pdu_json()
+ except Exception as e:
+ # if we were unable to reject the invite, just mark
+ # it as rejected on our end and plough ahead.
+ #
+ # The 'except' clause is very broad, but we need to
+ # capture everything from DNS failures upwards
+ #
+ logger.warn("Failed to reject invite: %s", e)
+
+ yield self.store.locally_reject_invite(
+ user_id, room_id
+ )
+ ret = {}
+
+ defer.returnValue((200, ret))
+
+
+class ReplicationRegister3PIDGuestRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
+
+ def __init__(self, hs):
+ super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
+
+ self.registeration_handler = hs.get_handlers().registration_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ medium = content["medium"]
+ address = content["address"]
+ inviter_user_id = content["inviter_user_id"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info("get_or_register_3pid_guest: %r", content)
+
+ ret = yield self.registeration_handler.get_or_register_3pid_guest(
+ medium, address, inviter_user_id,
+ )
+
+ defer.returnValue((200, ret))
+
+
+class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
+
+ def __init__(self, hs):
+ super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
+
+ self.registeration_handler = hs.get_handlers().registration_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.distributor = hs.get_distributor()
+
+ def on_POST(self, request, change):
+ content = parse_json_object_from_request(request)
+
+ user_id = content["user_id"]
+ room_id = content["room_id"]
+
+ logger.info("user membership change: %s in %s", user_id, room_id)
+
+ user = UserID.from_string(user_id)
+
+ if change == "joined":
+ user_joined_room(self.distributor, user, room_id)
+ elif change == "left":
+ user_left_room(self.distributor, user, room_id)
+ else:
+            raise Exception("Unrecognized change: %r" % (change,))
+
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ ReplicationRemoteJoinRestServlet(hs).register(http_server)
+ ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
+ ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
+ ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
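
The servlets above only define the master-side endpoints; a worker drives them with plain HTTP POSTs carrying the same keys each on_POST parses. A minimal sketch of the remote_join call follows, assuming a SimpleHttpClient; the helper name and URI construction are illustrative and may differ from the actual client helpers in this change:

from twisted.internet import defer


@defer.inlineCallbacks
def remote_join_on_master(client, host, port, requester, remote_room_hosts,
                          room_id, user_id, content):
    """Illustrative worker-side caller for ReplicationRemoteJoinRestServlet."""
    uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
    payload = {
        "requester": requester.serialize(),
        "remote_room_hosts": remote_room_hosts,
        "room_id": room_id,
        "user_id": user_id,
        "content": content,
    }
    # The servlet replies (200, {}) once do_invite_join has completed on
    # the master.
    result = yield client.post_json_get_json(uri, payload)
    defer.returnValue(result)
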
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
new file mode 100644
index 0000000000..a9baa2c1c3
--- /dev/null
+++ b/synapse/replication/http/send_event.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import (
+ SynapseError, MatrixCodeMessageException, CodeMessageException,
+)
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.util.async import sleep
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.metrics import Measure
+from synapse.types import Requester, UserID
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def send_event_to_master(client, host, port, requester, event, context,
+ ratelimit, extra_users):
+ """Send event to be handled on the master
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
+ uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
+ host, port, event.event_id,
+ )
+
+ payload = {
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": context.serialize(event),
+ "requester": requester.serialize(),
+ "ratelimit": ratelimit,
+ "extra_users": [u.to_string() for u in extra_users],
+ }
+
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed on
+ # the master, and so whether we should clean up or not.
+ while True:
+ try:
+ result = yield client.put_json(uri, payload)
+ break
+ except CodeMessageException as e:
+ if e.code != 504:
+ raise
+
+ logger.warn("send_event request timed out")
+
+ # If we timed out we probably don't need to worry about backing
+                # off too much, but let's just wait a little anyway.
+ yield sleep(1)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+class ReplicationSendEventRestServlet(RestServlet):
+ """Handles events newly created on workers, including persisting and
+ notifying.
+
+ The API looks like:
+
+ POST /_synapse/replication/send_event/:event_id
+
+ {
+ "event": { .. serialized event .. },
+ "internal_metadata": { .. serialized internal_metadata .. },
+ "rejected_reason": .., // The event.rejected_reason field
+ "context": { .. serialized event context .. },
+ "requester": { .. serialized requester .. },
+ "ratelimit": true,
+ "extra_users": [],
+ }
+ """
+ PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
+
+ def __init__(self, hs):
+ super(ReplicationSendEventRestServlet, self).__init__()
+
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ # The responses are tiny, so we may as well cache them for a while
+ self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
+
+ def on_PUT(self, request, event_id):
+ return self.response_cache.wrap(
+ event_id,
+ self._handle_request,
+ request
+ )
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request):
+ with Measure(self.clock, "repl_send_event_parse"):
+ content = parse_json_object_from_request(request)
+
+ event_dict = content["event"]
+ internal_metadata = content["internal_metadata"]
+ rejected_reason = content["rejected_reason"]
+ event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ requester = Requester.deserialize(self.store, content["requester"])
+ context = yield EventContext.deserialize(self.store, content["context"])
+
+ ratelimit = content["ratelimit"]
+ extra_users = [UserID.from_string(u) for u in content["extra_users"]]
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "Got event to send with ID: %s into room: %s",
+ event.event_id, event.room_id,
+ )
+
+ yield self.event_creation_handler.persist_and_notify_client_event(
+ requester, event, context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReplicationSendEventRestServlet(hs).register(http_server)
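
On the worker side, send_event_to_master above is the only entry point needed to hand a freshly created event to the master for persistence and notification. A sketch of how a worker's event-creation path might call it; the config attribute names for the master's host and replication port are assumptions:

from twisted.internet import defer

from synapse.replication.http.send_event import send_event_to_master


@defer.inlineCallbacks
def persist_on_master(hs, requester, event, context):
    # Sketch only: delegate persistence/notification of a worker-created
    # event to the master. Retries on 504 are handled inside
    # send_event_to_master, and the master deduplicates them because
    # ReplicationSendEventRestServlet keys its ResponseCache on event_id.
    yield send_event_to_master(
        client=hs.get_simple_http_client(),
        host=hs.config.worker_replication_host,        # assumed config name
        port=hs.config.worker_replication_http_port,   # assumed config name
        requester=requester,
        event=event,
        context=context,
        ratelimit=True,
        extra_users=[],
    )
    defer.returnValue(event)
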
diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py
deleted file mode 100644
index fc18130ab4..0000000000
--- a/synapse/replication/presence_resource.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-
-class PresenceResource(Resource):
- """
- HTTP endpoint for marking users as syncing.
-
- POST /_synapse/replication/presence HTTP/1.1
- Content-Type: application/json
-
- {
- "process_id": "<process_id>",
- "syncing_users": ["<user_id>"]
- }
- """
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.version_string = hs.version_string
- self.clock = hs.get_clock()
- self.presence_handler = hs.get_presence_handler()
-
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
-
- @request_handler()
- @defer.inlineCallbacks
- def _async_render_POST(self, request):
- content = parse_json_object_from_request(request)
-
- process_id = content["process_id"]
- syncing_user_ids = content["syncing_users"]
-
- yield self.presence_handler.update_external_syncs(
- process_id, set(syncing_user_ids)
- )
-
- respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py
deleted file mode 100644
index 9b01ab3c13..0000000000
--- a/synapse/replication/pusher_resource.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-
-class PusherResource(Resource):
- """
- HTTP endpoint for deleting rejected pushers
- """
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.version_string = hs.version_string
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self.clock = hs.get_clock()
-
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
-
- @request_handler()
- @defer.inlineCallbacks
- def _async_render_POST(self, request):
- content = parse_json_object_from_request(request)
-
- for remove in content["remove"]:
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
- remove["app_id"],
- remove["push_key"],
- remove["user_id"],
- )
-
- self.notifier.on_new_replication_data()
-
- respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
deleted file mode 100644
index d8eb14592b..0000000000
--- a/synapse/replication/resource.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.servlet import parse_integer, parse_string
-from synapse.http.server import request_handler, finish_request
-from synapse.replication.pusher_resource import PusherResource
-from synapse.replication.presence_resource import PresenceResource
-from synapse.replication.expire_cache import ExpireCacheResource
-from synapse.api.errors import SynapseError
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-import ujson as json
-
-import collections
-import logging
-
-logger = logging.getLogger(__name__)
-
-REPLICATION_PREFIX = "/_synapse/replication"
-
-STREAM_NAMES = (
- ("events",),
- ("presence",),
- ("typing",),
- ("receipts",),
- ("user_account_data", "room_account_data", "tag_account_data",),
- ("backfill",),
- ("push_rules",),
- ("pushers",),
- ("caches",),
- ("to_device",),
- ("public_rooms",),
- ("federation",),
- ("device_lists",),
-)
-
-
-class ReplicationResource(Resource):
- """
- HTTP endpoint for extracting data from synapse.
-
- The streams of data returned by the endpoint are controlled by the
- parameters given to the API. To return a given stream pass a query
- parameter with a position in the stream to return data from or the
- special value "-1" to return data from the start of the stream.
-
- If there is no data for any of the supplied streams after the given
- position then the request will block until there is data for one
- of the streams. This allows clients to long-poll this API.
-
- The possible streams are:
-
- * "streams": A special stream returing the positions of other streams.
- * "events": The new events seen on the server.
- * "presence": Presence updates.
- * "typing": Typing updates.
- * "receipts": Receipt updates.
- * "user_account_data": Top-level per user account data.
- * "room_account_data: Per room per user account data.
- * "tag_account_data": Per room per user tags.
- * "backfill": Old events that have been backfilled from other servers.
- * "push_rules": Per user changes to push rules.
- * "pushers": Per user changes to their pushers.
- * "caches": Cache invalidations.
-
- The API takes two additional query parameters:
-
- * "timeout": How long to wait before returning an empty response.
- * "limit": The maximum number of rows to return for the selected streams.
-
- The response is a JSON object with keys for each stream with updates. Under
- each key is a JSON object with:
-
- * "position": The current position of the stream.
- * "field_names": The names of the fields in each row.
- * "rows": The updates as an array of arrays.
-
- There are a number of ways this API could be used:
-
- 1) To replicate the contents of the backing database to another database.
- 2) To be notified when the contents of a shared backing database changes.
- 3) To "tail" the activity happening on a server for debugging.
-
- In the first case the client would track all of the streams and store it's
- own copy of the data.
-
- In the second case the client might theoretically just be able to follow
- the "streams" stream to track where the other streams are. However in
- practise it will probably need to get the contents of the streams in
- order to expire the any in-memory caches. Whether it gets the contents
- of the streams from this replication API or directly from the backing
- store is a matter of taste.
-
- In the third case the client would use the "streams" stream to find what
- streams are available and their current positions. Then it can start
- long-polling this replication API for new data on those streams.
- """
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.version_string = hs.version_string
- self.store = hs.get_datastore()
- self.sources = hs.get_event_sources()
- self.presence_handler = hs.get_presence_handler()
- self.typing_handler = hs.get_typing_handler()
- self.federation_sender = hs.get_federation_sender()
- self.notifier = hs.notifier
- self.clock = hs.get_clock()
- self.config = hs.get_config()
-
- self.putChild("remove_pushers", PusherResource(hs))
- self.putChild("syncing_users", PresenceResource(hs))
- self.putChild("expire_cache", ExpireCacheResource(hs))
-
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
-
- @defer.inlineCallbacks
- def current_replication_token(self):
- stream_token = yield self.sources.get_current_token()
- backfill_token = yield self.store.get_current_backfill_token()
- push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
- pushers_token = self.store.get_pushers_stream_token()
- caches_token = self.store.get_cache_stream_token()
- public_rooms_token = self.store.get_current_public_room_stream_id()
- federation_token = self.federation_sender.get_current_token()
- device_list_token = self.store.get_device_stream_token()
-
- defer.returnValue(_ReplicationToken(
- room_stream_token,
- int(stream_token.presence_key),
- int(stream_token.typing_key),
- int(stream_token.receipt_key),
- int(stream_token.account_data_key),
- backfill_token,
- push_rules_token,
- pushers_token,
- 0, # State stream is no longer a thing
- caches_token,
- int(stream_token.to_device_key),
- int(public_rooms_token),
- int(federation_token),
- int(device_list_token),
- ))
-
- @request_handler()
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
- limit = parse_integer(request, "limit", 100)
- timeout = parse_integer(request, "timeout", 10 * 1000)
-
- request.setHeader(b"Content-Type", b"application/json")
-
- request_streams = {
- name: parse_integer(request, name)
- for names in STREAM_NAMES for name in names
- }
- request_streams["streams"] = parse_string(request, "streams")
-
- federation_ack = parse_integer(request, "federation_ack", None)
-
- def replicate():
- return self.replicate(
- request_streams, limit,
- federation_ack=federation_ack
- )
-
- writer = yield self.notifier.wait_for_replication(replicate, timeout)
- result = writer.finish()
-
- for stream_name, stream_content in result.items():
- logger.info(
- "Replicating %d rows of %s from %s -> %s",
- len(stream_content["rows"]),
- stream_name,
- request_streams.get(stream_name),
- stream_content["position"],
- )
-
- request.write(json.dumps(result, ensure_ascii=False))
- finish_request(request)
-
- @defer.inlineCallbacks
- def replicate(self, request_streams, limit, federation_ack=None):
- writer = _Writer()
- current_token = yield self.current_replication_token()
- logger.debug("Replicating up to %r", current_token)
-
- if limit == 0:
- raise SynapseError(400, "Limit cannot be 0")
-
- yield self.account_data(writer, current_token, limit, request_streams)
- yield self.events(writer, current_token, limit, request_streams)
- # TODO: implement limit
- yield self.presence(writer, current_token, request_streams)
- yield self.typing(writer, current_token, request_streams)
- yield self.receipts(writer, current_token, limit, request_streams)
- yield self.push_rules(writer, current_token, limit, request_streams)
- yield self.pushers(writer, current_token, limit, request_streams)
- yield self.caches(writer, current_token, limit, request_streams)
- yield self.to_device(writer, current_token, limit, request_streams)
- yield self.public_rooms(writer, current_token, limit, request_streams)
- yield self.device_lists(writer, current_token, limit, request_streams)
- self.federation(writer, current_token, limit, request_streams, federation_ack)
- self.streams(writer, current_token, request_streams)
-
- logger.debug("Replicated %d rows", writer.total)
- defer.returnValue(writer)
-
- def streams(self, writer, current_token, request_streams):
- request_token = request_streams.get("streams")
-
- streams = []
-
- if request_token is not None:
- if request_token == "-1":
- for names, position in zip(STREAM_NAMES, current_token):
- streams.extend((name, position) for name in names)
- else:
- items = zip(
- STREAM_NAMES,
- current_token,
- _ReplicationToken(request_token)
- )
- for names, current_id, last_id in items:
- if last_id < current_id:
- streams.extend((name, current_id) for name in names)
-
- if streams:
- writer.write_header_and_rows(
- "streams", streams, ("name", "position"),
- position=str(current_token)
- )
-
- @defer.inlineCallbacks
- def events(self, writer, current_token, limit, request_streams):
- request_events = request_streams.get("events")
- request_backfill = request_streams.get("backfill")
-
- if request_events is not None or request_backfill is not None:
- if request_events is None:
- request_events = current_token.events
- if request_backfill is None:
- request_backfill = current_token.backfill
-
- no_new_tokens = (
- request_events == current_token.events
- and request_backfill == current_token.backfill
- )
- if no_new_tokens:
- return
-
- res = yield self.store.get_all_new_events(
- request_backfill, request_events,
- current_token.backfill, current_token.events,
- limit
- )
-
- upto_events_token = _position_from_rows(
- res.new_forward_events, current_token.events
- )
-
- upto_backfill_token = _position_from_rows(
- res.new_backfill_events, current_token.backfill
- )
-
- if request_events != upto_events_token:
- writer.write_header_and_rows("events", res.new_forward_events, (
- "position", "internal", "json", "state_group"
- ), position=upto_events_token)
-
- if request_backfill != upto_backfill_token:
- writer.write_header_and_rows("backfill", res.new_backfill_events, (
- "position", "internal", "json", "state_group",
- ), position=upto_backfill_token)
-
- writer.write_header_and_rows(
- "forward_ex_outliers", res.forward_ex_outliers,
- ("position", "event_id", "state_group"),
- )
- writer.write_header_and_rows(
- "backward_ex_outliers", res.backward_ex_outliers,
- ("position", "event_id", "state_group"),
- )
-
- @defer.inlineCallbacks
- def presence(self, writer, current_token, request_streams):
- current_position = current_token.presence
-
- request_presence = request_streams.get("presence")
-
- if request_presence is not None and request_presence != current_position:
- presence_rows = yield self.presence_handler.get_all_presence_updates(
- request_presence, current_position
- )
- upto_token = _position_from_rows(presence_rows, current_position)
- writer.write_header_and_rows("presence", presence_rows, (
- "position", "user_id", "state", "last_active_ts",
- "last_federation_update_ts", "last_user_sync_ts",
- "status_msg", "currently_active",
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def typing(self, writer, current_token, request_streams):
- current_position = current_token.typing
-
- request_typing = request_streams.get("typing")
-
- if request_typing is not None and request_typing != current_position:
- # If they have a higher token than current max, we can assume that
- # they had been talking to a previous instance of the master. Since
- # we reset the token on restart, the best (but hacky) thing we can
- # do is to simply resend down all the typing notifications.
- if request_typing > current_position:
- request_typing = 0
-
- typing_rows = yield self.typing_handler.get_all_typing_updates(
- request_typing, current_position
- )
- upto_token = _position_from_rows(typing_rows, current_position)
- writer.write_header_and_rows("typing", typing_rows, (
- "position", "room_id", "typing"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def receipts(self, writer, current_token, limit, request_streams):
- current_position = current_token.receipts
-
- request_receipts = request_streams.get("receipts")
-
- if request_receipts is not None and request_receipts != current_position:
- receipts_rows = yield self.store.get_all_updated_receipts(
- request_receipts, current_position, limit
- )
- upto_token = _position_from_rows(receipts_rows, current_position)
- writer.write_header_and_rows("receipts", receipts_rows, (
- "position", "room_id", "receipt_type", "user_id", "event_id", "data"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def account_data(self, writer, current_token, limit, request_streams):
- current_position = current_token.account_data
-
- user_account_data = request_streams.get("user_account_data")
- room_account_data = request_streams.get("room_account_data")
- tag_account_data = request_streams.get("tag_account_data")
-
- if user_account_data is not None or room_account_data is not None:
- if user_account_data is None:
- user_account_data = current_position
- if room_account_data is None:
- room_account_data = current_position
-
- no_new_tokens = (
- user_account_data == current_position
- and room_account_data == current_position
- )
- if no_new_tokens:
- return
-
- user_rows, room_rows = yield self.store.get_all_updated_account_data(
- user_account_data, room_account_data, current_position, limit
- )
-
- upto_users_token = _position_from_rows(user_rows, current_position)
- upto_rooms_token = _position_from_rows(room_rows, current_position)
-
- writer.write_header_and_rows("user_account_data", user_rows, (
- "position", "user_id", "type", "content"
- ), position=upto_users_token)
- writer.write_header_and_rows("room_account_data", room_rows, (
- "position", "user_id", "room_id", "type", "content"
- ), position=upto_rooms_token)
-
- if tag_account_data is not None:
- tag_rows = yield self.store.get_all_updated_tags(
- tag_account_data, current_position, limit
- )
- upto_tag_token = _position_from_rows(tag_rows, current_position)
- writer.write_header_and_rows("tag_account_data", tag_rows, (
- "position", "user_id", "room_id", "tags"
- ), position=upto_tag_token)
-
- @defer.inlineCallbacks
- def push_rules(self, writer, current_token, limit, request_streams):
- current_position = current_token.push_rules
-
- push_rules = request_streams.get("push_rules")
-
- if push_rules is not None and push_rules != current_position:
- rows = yield self.store.get_all_push_rule_updates(
- push_rules, current_position, limit
- )
- upto_token = _position_from_rows(rows, current_position)
- writer.write_header_and_rows("push_rules", rows, (
- "position", "event_stream_ordering", "user_id", "rule_id", "op",
- "priority_class", "priority", "conditions", "actions"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def pushers(self, writer, current_token, limit, request_streams):
- current_position = current_token.pushers
-
- pushers = request_streams.get("pushers")
-
- if pushers is not None and pushers != current_position:
- updated, deleted = yield self.store.get_all_updated_pushers(
- pushers, current_position, limit
- )
- upto_token = _position_from_rows(updated, current_position)
- writer.write_header_and_rows("pushers", updated, (
- "position", "user_id", "access_token", "profile_tag", "kind",
- "app_id", "app_display_name", "device_display_name", "pushkey",
- "ts", "lang", "data"
- ), position=upto_token)
- writer.write_header_and_rows("deleted_pushers", deleted, (
- "position", "user_id", "app_id", "pushkey"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def caches(self, writer, current_token, limit, request_streams):
- current_position = current_token.caches
-
- caches = request_streams.get("caches")
-
- if caches is not None and caches != current_position:
- updated_caches = yield self.store.get_all_updated_caches(
- caches, current_position, limit
- )
- upto_token = _position_from_rows(updated_caches, current_position)
- writer.write_header_and_rows("caches", updated_caches, (
- "position", "cache_func", "keys", "invalidation_ts"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def to_device(self, writer, current_token, limit, request_streams):
- current_position = current_token.to_device
-
- to_device = request_streams.get("to_device")
-
- if to_device is not None and to_device != current_position:
- to_device_rows = yield self.store.get_all_new_device_messages(
- to_device, current_position, limit
- )
- upto_token = _position_from_rows(to_device_rows, current_position)
- writer.write_header_and_rows("to_device", to_device_rows, (
- "position", "user_id", "device_id", "message_json"
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def public_rooms(self, writer, current_token, limit, request_streams):
- current_position = current_token.public_rooms
-
- public_rooms = request_streams.get("public_rooms")
-
- if public_rooms is not None and public_rooms != current_position:
- public_rooms_rows = yield self.store.get_all_new_public_rooms(
- public_rooms, current_position, limit
- )
- upto_token = _position_from_rows(public_rooms_rows, current_position)
- writer.write_header_and_rows("public_rooms", public_rooms_rows, (
- "position", "room_id", "visibility", "appservice_id", "network_id",
- ), position=upto_token)
-
- def federation(self, writer, current_token, limit, request_streams, federation_ack):
- if self.config.send_federation:
- return
-
- current_position = current_token.federation
-
- federation = request_streams.get("federation")
-
- if federation is not None and federation != current_position:
- federation_rows = self.federation_sender.get_replication_rows(
- federation, limit, federation_ack=federation_ack,
- )
- upto_token = _position_from_rows(federation_rows, current_position)
- writer.write_header_and_rows("federation", federation_rows, (
- "position", "type", "content",
- ), position=upto_token)
-
- @defer.inlineCallbacks
- def device_lists(self, writer, current_token, limit, request_streams):
- current_position = current_token.device_lists
-
- device_lists = request_streams.get("device_lists")
-
- if device_lists is not None and device_lists != current_position:
- changes = yield self.store.get_all_device_list_changes_for_remotes(
- device_lists,
- )
- writer.write_header_and_rows("device_lists", changes, (
- "position", "user_id", "destination",
- ), position=current_position)
-
-
-class _Writer(object):
- """Writes the streams as a JSON object as the response to the request"""
- def __init__(self):
- self.streams = {}
- self.total = 0
-
- def write_header_and_rows(self, name, rows, fields, position=None):
- if position is None:
- if rows:
- position = rows[-1][0]
- else:
- return
-
- self.streams[name] = {
- "position": position if type(position) is int else str(position),
- "field_names": fields,
- "rows": rows,
- }
-
- self.total += len(rows)
-
- def __nonzero__(self):
- return bool(self.total)
-
- def finish(self):
- return self.streams
-
-
-class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
- "events", "presence", "typing", "receipts", "account_data", "backfill",
- "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
- "federation", "device_lists",
-))):
- __slots__ = []
-
- def __new__(cls, *args):
- if len(args) == 1:
- streams = [int(value) for value in args[0].split("_")]
- if len(streams) < len(cls._fields):
- streams.extend([0] * (len(cls._fields) - len(streams)))
- return cls(*streams)
- else:
- return super(_ReplicationToken, cls).__new__(cls, *args)
-
- def __str__(self):
- return "_".join(str(value) for value in self)
-
-
-def _position_from_rows(rows, current_position):
- """Calculates a position to return for a stream. Ideally we want to return the
- position of the last row, as that will be the most correct. However, if there
- are no rows we fall back to using the current position to stop us from
- repeatedly hitting the storage layer unncessarily thinking there are updates.
- (Not all advances of the token correspond to an actual update)
-
- We can't just always return the current position, as we often limit the
- number of rows we replicate, and so the stream may lag. The assumption is
- that if the storage layer returns no new rows then we are not lagging and
- we are at the `current_position`.
- """
- if rows:
- return rows[-1][0]
- return current_position
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 18076e0f3b..61f5590c53 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
- super(BaseSlavedStore, self).__init__(hs)
+ super(BaseSlavedStore, self).__init__(db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
else:
self._cache_id_gen = None
- self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
- self.http_client = hs.get_simple_http_client()
+ self.hs = hs
def stream_positions(self):
pos = {}
@@ -43,33 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
- def process_replication(self, result):
- stream = result.get("caches")
- if stream:
- for row in stream["rows"]:
- (
- position, cache_func, keys, invalidation_ts,
- ) = row
-
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "caches":
+ self._cache_id_gen.advance(token)
+ for row in rows:
try:
- getattr(self, cache_func).invalidate(tuple(keys))
+ getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError:
- logger.info("Got unexpected cache_func: %r", cache_func)
- self._cache_id_gen.advance(int(stream["position"]))
- return defer.succeed(None)
+ # We probably haven't pulled in the cache in this worker,
+ # which is fine.
+ pass
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
- @defer.inlineCallbacks
def _send_invalidation_poke(self, cache_func, keys):
- try:
- yield self.http_client.post_json_get_json(self.expire_cache_url, {
- "invalidate": [{
- "name": cache_func.__name__,
- "keys": list(keys),
- }]
- })
- except:
- logger.exception("Failed to poke on expire_cache")
+ self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
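
process_replication_rows replaces the old process_replication(result) dict-walking: the replication command handler tells each store which stream the rows belong to and the token they advance to, and each slaved store invalidates only its own caches before chaining up. A minimal sketch of the pattern the stores below follow (table, stream, and cache names are illustrative):

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker


class ExampleSlavedStore(BaseSlavedStore):
    # Illustrative only: the shape every slaved store in this change follows
    # when consuming rows from the TCP replication stream.
    def __init__(self, db_conn, hs):
        super(ExampleSlavedStore, self).__init__(db_conn, hs)
        self._example_id_gen = SlavedIdTracker(
            db_conn, "example_table", "stream_id",
        )

    def stream_positions(self):
        result = super(ExampleSlavedStore, self).stream_positions()
        result["example"] = self._example_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "example":
            self._example_id_gen.advance(token)
            for row in rows:
                # Invalidate whatever caches this stream affects.
                pass
        return super(ExampleSlavedStore, self).process_replication_rows(
            stream_name, token, rows
        )
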
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 24b5c79d4a..9d1d173b2f 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -27,4 +27,9 @@ class SlavedIdTracker(object):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
+ """
+
+ Returns:
+ int
+ """
return self._current
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..d9ba6d69b1 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.account_data import AccountDataStore
-from synapse.storage.tags import TagsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.tags import TagsWorkerStore
-class SlavedAccountDataStore(BaseSlavedStore):
+class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache",
- self._account_data_id_gen.get_current_token(),
- )
-
- get_account_data_for_user = (
- AccountDataStore.__dict__["get_account_data_for_user"]
- )
-
- get_global_account_data_by_type_for_users = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
- )
- get_global_account_data_by_type_for_user = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
- )
-
- get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
- get_tags_for_room = (
- DataStore.get_tags_for_room.__func__
- )
- get_account_data_for_room = (
- DataStore.get_account_data_for_room.__func__
- )
-
- get_updated_tags = DataStore.get_updated_tags.__func__
- get_updated_account_data_for_user = (
- DataStore.get_updated_account_data_for_user.__func__
- )
+ super(SlavedAccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
@@ -69,38 +40,29 @@ class SlavedAccountDataStore(BaseSlavedStore):
result["tag_account_data"] = position
return result
- def process_replication(self, result):
- stream = result.get("user_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id, data_type = row[:3]
- self.get_global_account_data_by_type_for_user.invalidate(
- (data_type, user_id,)
- )
- self.get_account_data_for_user.invalidate((user_id,))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "tag_account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- stream = result.get("room_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_account_data_for_user.invalidate((user_id,))
- self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ elif stream_name == "account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id,)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type,),
)
-
- stream = result.get("tag_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- return super(SlavedAccountDataStore, self).process_replication(result)
+ return super(SlavedAccountDataStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a374f2f1a2..8cae3076f4 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.config.appservice import load_appservices
+from synapse.storage.appservice import (
+ ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
+)
-class SlavedApplicationServiceStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
- self.services_cache = load_appservices(
- hs.config.server_name,
- hs.config.app_service_config_files
- )
-
- get_app_service_by_token = DataStore.get_app_service_by_token.__func__
- get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
- get_app_services = DataStore.get_app_services.__func__
- get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
- create_appservice_txn = DataStore.create_appservice_txn.__func__
- get_appservices_by_state = DataStore.get_appservices_by_state.__func__
- get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
- _get_last_txn = DataStore._get_last_txn.__func__
- complete_appservice_txn = DataStore.complete_appservice_txn.__func__
- get_appservice_state = DataStore.get_appservice_state.__func__
- set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
- set_appservice_state = DataStore.set_appservice_state.__func__
+class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore):
+ pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
new file mode 100644
index 0000000000..352c9a2aa8
--- /dev/null
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import Cache
+
+
+class SlavedClientIpStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedClientIpStore, self).__init__(db_conn, hs)
+
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
+ )
+
+ def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+ now = int(self._clock.time_msec())
+ key = (user_id, access_token, ip)
+
+ try:
+ last_seen = self.client_ip_last_seen.get(key)
+ except KeyError:
+ last_seen = None
+
+ # Rate-limited inserts
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+            return
+
+        self.client_ip_last_seen.prefill(key, now)
+
+ self.hs.get_tcp_replication().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
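
The same throttling idea, stripped of the synapse plumbing, looks like the sketch below; the two-minute window is illustrative rather than the exact LAST_SEEN_GRANULARITY value:

import time

# Standalone sketch of the rate-limiting used by insert_client_ip above:
# remember when each key was last forwarded and drop repeats seen again
# within the granularity window.
GRANULARITY_MS = 120 * 1000
_last_forwarded = {}


def should_forward(key, now_ms=None):
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    last = _last_forwarded.get(key)
    if last is not None and (now_ms - last) < GRANULARITY_MS:
        return False
    _last_forwarded[key] = now_ms
    return True
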
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index cc860f9f9b..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -17,6 +17,7 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.expiringcache import ExpiringCache
class SlavedDeviceInboxStore(BaseSlavedStore):
@@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.get_current_token()
)
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
@@ -45,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("to_device")
- if stream:
- self._device_inbox_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- stream_id = row[0]
- entity = row[1]
-
- if entity.startswith("@"):
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "to_device":
+ self._device_inbox_id_gen.advance(token)
+ for row in rows:
+ if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
- entity, stream_id
+ row.entity, token
)
else:
self._device_federation_outbox_stream_cache.entity_has_changed(
- entity, stream_id
+ row.entity, token
)
-
- return super(SlavedDeviceInboxStore, self).process_replication(result)
+ return super(SlavedDeviceInboxStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..7687867aee 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,6 +16,7 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
+from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -45,28 +46,25 @@ class SlavedDeviceStore(BaseSlavedStore):
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
+ count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("device_lists")
- if stream:
- self._device_list_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- stream_id = row[0]
- user_id = row[1]
- destination = row[2]
-
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "device_lists":
+ self._device_list_id_gen.advance(token)
+ for row in rows:
self._device_list_stream_cache.entity_has_changed(
- user_id, stream_id
+ row.user_id, token
)
- if destination:
+ if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
- destination, stream_id
+ row.destination, token
)
-
- return super(SlavedDeviceStore, self).process_replication(result)
+ return super(SlavedDeviceStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 7301d885f2..6deecd3963 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -14,10 +14,8 @@
# limitations under the License.
from ._base import BaseSlavedStore
-from synapse.storage.directory import DirectoryStore
+from synapse.storage.directory import DirectoryWorkerStore
-class DirectoryStore(BaseSlavedStore):
- get_aliases_for_room = DirectoryStore.__dict__[
- "get_aliases_for_room"
- ]
+class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d72ff6055c..b1f64ef0d8 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,22 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-
-from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
-from synapse.storage import DataStore
-from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.event_push_actions import EventPushActionsStore
-from synapse.storage.state import StateStore
-from synapse.storage.stream import StreamStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-import ujson as json
import logging
+from synapse.api.constants import EventTypes
+from synapse.storage.event_federation import EventFederationWorkerStore
+from synapse.storage.event_push_actions import EventPushActionsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.storage.state import StateGroupWorkerStore
+from synapse.storage.stream import StreamWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -41,152 +38,33 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(BaseSlavedStore):
+class SlavedEventStore(EventFederationWorkerStore,
+ RoomMemberWorkerStore,
+ EventPushActionsWorkerStore,
+ StreamWorkerStore,
+ EventsWorkerStore,
+ StateGroupWorkerStore,
+ SignatureWorkerStore,
+ BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
- events_max = self._stream_id_gen.get_current_token()
- event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
- entity_column="room_id",
- stream_column="stream_ordering",
- max_value=events_max,
- )
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
- prefilled_cache=event_cache_prefill,
- )
- self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
- )
- self.stream_ordering_month_ago = 0
- self._stream_order_on_start = self.get_room_max_stream_ordering()
+ super(SlavedEventStore, self).__init__(db_conn, hs)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
- get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
- get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
- get_users_who_share_room_with_user = (
- RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
- )
- get_latest_event_ids_in_room = EventFederationStore.__dict__[
- "get_latest_event_ids_in_room"
- ]
- get_invited_rooms_for_user = RoomMemberStore.__dict__[
- "get_invited_rooms_for_user"
- ]
- get_unread_event_push_actions_by_room_for_user = (
- EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
- )
- _get_state_group_for_events = (
- StateStore.__dict__["_get_state_group_for_events"]
- )
- _get_state_group_for_event = (
- StateStore.__dict__["_get_state_group_for_event"]
- )
- _get_state_groups_from_groups = (
- StateStore.__dict__["_get_state_groups_from_groups"]
- )
- _get_state_groups_from_groups_txn = (
- DataStore._get_state_groups_from_groups_txn.__func__
- )
- _get_state_group_from_group = (
- StateStore.__dict__["_get_state_group_from_group"]
- )
- get_recent_event_ids_for_room = (
- StreamStore.__dict__["get_recent_event_ids_for_room"]
- )
-
- get_unread_push_actions_for_user_in_range_for_http = (
- DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
- )
- get_unread_push_actions_for_user_in_range_for_email = (
- DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
- )
- get_push_action_users_in_range = (
- DataStore.get_push_action_users_in_range.__func__
- )
- get_event = DataStore.get_event.__func__
- get_events = DataStore.get_events.__func__
- get_rooms_for_user_where_membership_is = (
- DataStore.get_rooms_for_user_where_membership_is.__func__
- )
- get_membership_changes_for_user = (
- DataStore.get_membership_changes_for_user.__func__
- )
- get_room_events_max_id = DataStore.get_room_events_max_id.__func__
- get_room_events_stream_for_room = (
- DataStore.get_room_events_stream_for_room.__func__
- )
- get_events_around = DataStore.get_events_around.__func__
- get_state_for_event = DataStore.get_state_for_event.__func__
- get_state_for_events = DataStore.get_state_for_events.__func__
- get_state_groups = DataStore.get_state_groups.__func__
- get_state_groups_ids = DataStore.get_state_groups_ids.__func__
- get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
- get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
- get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
- get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
- _get_joined_users_from_context = (
- RoomMemberStore.__dict__["_get_joined_users_from_context"]
- )
-
- get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
- get_room_events_stream_for_rooms = (
- DataStore.get_room_events_stream_for_rooms.__func__
- )
- is_host_joined = DataStore.is_host_joined.__func__
- _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
- get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
-
- _set_before_and_after = staticmethod(DataStore._set_before_and_after)
-
- _get_events = DataStore._get_events.__func__
- _get_events_from_cache = DataStore._get_events_from_cache.__func__
- _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
- _enqueue_events = DataStore._enqueue_events.__func__
- _do_fetch = DataStore._do_fetch.__func__
- _fetch_event_rows = DataStore._fetch_event_rows.__func__
- _get_event_from_row = DataStore._get_event_from_row.__func__
- _get_rooms_for_user_where_membership_is_txn = (
- DataStore._get_rooms_for_user_where_membership_is_txn.__func__
- )
- _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
- _get_state_for_groups = DataStore._get_state_for_groups.__func__
- _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
- _get_events_around_txn = DataStore._get_events_around_txn.__func__
- _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
- get_backfill_events = DataStore.get_backfill_events.__func__
- _get_backfill_events = DataStore._get_backfill_events.__func__
- get_missing_events = DataStore.get_missing_events.__func__
- _get_missing_events = DataStore._get_missing_events.__func__
-
- get_auth_chain = DataStore.get_auth_chain.__func__
- get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
- _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
-
- get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
-
- get_forward_extremeties_for_room = (
- DataStore.get_forward_extremeties_for_room.__func__
- )
- _get_forward_extremeties_for_room = (
- EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
- )
-
- get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
-
- get_federation_out_pos = DataStore.get_federation_out_pos.__func__
- update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
@@ -194,84 +72,47 @@ class SlavedEventStore(BaseSlavedStore):
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("events")
- if stream:
- self._stream_id_gen.advance(int(stream["position"]))
-
- if stream["rows"]:
- logger.info("Got %d event rows", len(stream["rows"]))
-
- for row in stream["rows"]:
- self._process_replication_row(
- row, backfilled=False,
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "events":
+ self._stream_id_gen.advance(token)
+ for row in rows:
+ self.invalidate_caches_for_event(
+ token, row.event_id, row.room_id, row.type, row.state_key,
+ row.redacts,
+ backfilled=False,
)
-
- stream = result.get("backfill")
- if stream:
- self._backfill_id_gen.advance(-int(stream["position"]))
- for row in stream["rows"]:
- self._process_replication_row(
- row, backfilled=True,
+ elif stream_name == "backfill":
+ self._backfill_id_gen.advance(-token)
+ for row in rows:
+ self.invalidate_caches_for_event(
+ -token, row.event_id, row.room_id, row.type, row.state_key,
+ row.redacts,
+ backfilled=True,
)
-
- stream = result.get("forward_ex_outliers")
- if stream:
- self._stream_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- event_id = row[1]
- self._invalidate_get_event_cache(event_id)
-
- stream = result.get("backward_ex_outliers")
- if stream:
- self._backfill_id_gen.advance(-int(stream["position"]))
- for row in stream["rows"]:
- event_id = row[1]
- self._invalidate_get_event_cache(event_id)
-
- return super(SlavedEventStore, self).process_replication(result)
-
- def _process_replication_row(self, row, backfilled):
- internal = json.loads(row[1])
- event_json = json.loads(row[2])
- event = FrozenEvent(event_json, internal_metadata_dict=internal)
- self.invalidate_caches_for_event(
- event, backfilled,
+ return super(SlavedEventStore, self).process_replication_rows(
+ stream_name, token, rows
)
- def invalidate_caches_for_event(self, event, backfilled):
- self._invalidate_get_event_cache(event.event_id)
+ def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
+ etype, state_key, redacts, backfilled):
+ self._invalidate_get_event_cache(event_id)
- self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+ self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- (event.room_id,)
+ (room_id,)
)
if not backfilled:
self._events_stream_cache.entity_has_changed(
- event.room_id, event.internal_metadata.stream_ordering
+ room_id, stream_ordering
)
- # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- # (event.room_id,)
- # )
+ if redacts:
+ self._invalidate_get_event_cache(redacts)
- if event.type == EventTypes.Redaction:
- self._invalidate_get_event_cache(event.redacts)
-
- if event.type == EventTypes.Member:
+ if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(
- event.state_key, event.internal_metadata.stream_ordering
+ state_key, stream_ordering
)
- self.get_invited_rooms_for_user.invalidate((event.state_key,))
-
- if not event.is_state():
- return
-
- if backfilled:
- return
-
- if (not event.internal_metadata.is_invite_from_remote()
- and event.internal_metadata.is_outlier()):
- return
+ self.get_invited_rooms_for_user.invalidate((state_key,))
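
As an illustrative sketch (not part of the patch), the contract that the slaved
stores above now implement looks roughly like the following; the stream name
"example" and the row attribute `entity_id` are invented for illustration:

    class ExampleSlavedStore(object):
        def __init__(self):
            self._current_token = 0

        def stream_positions(self):
            # Advertise the token we have caught up to for each stream we follow.
            return {"example": self._current_token}

        def process_replication_rows(self, stream_name, token, rows):
            # Called with already-parsed row objects and the new stream token,
            # instead of the old process_replication(result) dict of raw rows.
            if stream_name == "example":
                self._current_token = token
                for row in rows:
                    pass  # e.g. self.some_cache.invalidate((row.entity_id,))
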
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
new file mode 100644
index 0000000000..0bc4bce5b0
--- /dev/null
+++ b/synapse/replication/slave/storage/groups.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedGroupServerStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+
+ self.hs = hs
+
+ self._group_updates_id_gen = SlavedIdTracker(
+ db_conn, "local_group_updates", "stream_id",
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+ )
+
+ get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
+ get_group_stream_token = DataStore.get_group_stream_token.__func__
+ get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
+
+ def stream_positions(self):
+ result = super(SlavedGroupServerStore, self).stream_positions()
+ result["groups"] = self._group_updates_id_gen.get_current_token()
+ return result
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "groups":
+ self._group_updates_id_gen.advance(token)
+ for row in rows:
+ self._group_updates_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
+
+ return super(SlavedGroupServerStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 40f6c9a386..cfb9280181 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -39,6 +39,16 @@ class SlavedPresenceStore(BaseSlavedStore):
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
+ # XXX: This is a bit broken because we don't persist the accepted list in a
+ # way that can be replicated. This means that we don't have a way to
+ # invalidate the cache correctly.
+ get_presence_list_accepted = PresenceStore.__dict__[
+ "get_presence_list_accepted"
+ ]
+ get_presence_list_observers_accepted = PresenceStore.__dict__[
+ "get_presence_list_observers_accepted"
+ ]
+
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
@@ -48,14 +58,14 @@ class SlavedPresenceStore(BaseSlavedStore):
result["presence"] = position
return result
- def process_replication(self, result):
- stream = result.get("presence")
- if stream:
- self._presence_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "presence":
+ self._presence_id_gen.advance(token)
+ for row in rows:
self.presence_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- return super(SlavedPresenceStore, self).process_replication(result)
+ self._get_presence_for_user.invalidate((row.user_id,))
+ return super(SlavedPresenceStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
new file mode 100644
index 0000000000..46c28d4171
--- /dev/null
+++ b/synapse/replication/slave/storage/profile.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.storage.profile import ProfileWorkerStore
+
+
+class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..bb2c40b6e3 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,29 +16,15 @@
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.push_rule import PushRulesWorkerStore
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
- super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache",
- self._push_rules_stream_id_gen.get_current_token(),
- )
-
- get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
- get_push_rules_enabled_for_user = (
- PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
- )
- have_push_rules_changed_for_user = (
- DataStore.have_push_rules_changed_for_user.__func__
- )
+ super(SlavedPushRuleStore, self).__init__(db_conn, hs)
def get_push_rules_stream_token(self):
return (
@@ -45,23 +32,23 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(),
)
+ def get_max_push_rules_stream_id(self):
+ return self._push_rules_stream_id_gen.get_current_token()
+
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("push_rules")
- if stream:
- for row in stream["rows"]:
- position = row[0]
- user_id = row[2]
- self.get_push_rules_for_user.invalidate((user_id,))
- self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "push_rules":
+ self._push_rules_stream_id_gen.advance(token)
+ for row in rows:
+ self.get_push_rules_for_user.invalidate((row.user_id,))
+ self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
- return super(SlavedPushRuleStore, self).process_replication(result)
+ return super(SlavedPushRuleStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..a7cd5a7291 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +17,10 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
+from synapse.storage.pusher import PusherWorkerStore
-class SlavedPusherStore(BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
@@ -28,25 +29,14 @@ class SlavedPusherStore(BaseSlavedStore):
extra_tables=[("deleted_pushers", "stream_id")],
)
- get_all_pushers = DataStore.get_all_pushers.__func__
- get_pushers_by = DataStore.get_pushers_by.__func__
- get_pushers_by_app_id_and_pushkey = (
- DataStore.get_pushers_by_app_id_and_pushkey.__func__
- )
- _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
-
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("pushers")
- if stream:
- self._pushers_id_gen.advance(int(stream["position"]))
-
- stream = result.get("deleted_pushers")
- if stream:
- self._pushers_id_gen.advance(int(stream["position"]))
-
- return super(SlavedPusherStore, self).process_replication(result)
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "pushers":
+ self._pushers_id_gen.advance(token)
+ return super(SlavedPusherStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..1647072f65 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +17,7 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.receipts import ReceiptsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.receipts import ReceiptsWorkerStore
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -29,56 +28,43 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedReceiptsStore(BaseSlavedStore):
+class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedReceiptsStore, self).__init__(db_conn, hs)
-
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
- self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
- )
-
- get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
- get_linearized_receipts_for_room = (
- ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
- )
- _get_linearized_receipts_for_rooms = (
- ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
- )
- get_last_receipt_event_id_for_user = (
- ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
- )
-
- get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
- get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+ super(SlavedReceiptsStore, self).__init__(db_conn, hs)
- get_linearized_receipts_for_rooms = (
- DataStore.get_linearized_receipts_for_rooms.__func__
- )
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("receipts")
- if stream:
- self._receipts_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, room_id, receipt_type, user_id = row[:4]
- self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
- self._receipts_stream_cache.entity_has_changed(room_id, position)
-
- return super(SlavedReceiptsStore, self).process_replication(result)
-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "receipts":
+ self._receipts_id_gen.advance(token)
+ for row in rows:
+ self.invalidate_caches_for_receipt(
+ row.room_id, row.receipt_type, row.user_id
+ )
+ self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+ return super(SlavedReceiptsStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
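
The "chuck the method descriptor into our class" trick mentioned in the comment
above (and used heavily by the code this patch deletes) is, as a rough
illustrative sketch with invented class names:

    class Source(object):
        def get_thing(self):
            return "thing"

    class Borrower(object):
        # Reading the attribute out of Source.__dict__ returns whatever is
        # stored there (a plain function, or a descriptor such as a @cached
        # wrapper) without binding it, so it works on Borrower unchanged.
        get_thing = Source.__dict__["get_thing"]

    print(Borrower().get_thing())  # -> "thing"
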
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index e27c7332d2..7323bf0f1e 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -14,20 +14,8 @@
# limitations under the License.
from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.registration import RegistrationStore
+from synapse.storage.registration import RegistrationWorkerStore
-class SlavedRegistrationStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedRegistrationStore, self).__init__(db_conn, hs)
-
- # TODO: use the cached version and invalidate deleted tokens
- get_user_by_access_token = RegistrationStore.__dict__[
- "get_user_by_access_token"
- ]
-
- _query_for_auth = DataStore._query_for_auth.__func__
- get_user_by_id = RegistrationStore.__dict__[
- "get_user_by_id"
- ]
+class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..5ae1670157 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,41 +14,29 @@
# limitations under the License.
from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
+from synapse.storage.room import RoomWorkerStore
from ._slaved_id_tracker import SlavedIdTracker
-class RoomStore(BaseSlavedStore):
+class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
- get_public_room_ids = DataStore.get_public_room_ids.__func__
- get_current_public_room_stream_id = (
- DataStore.get_current_public_room_stream_id.__func__
- )
- get_public_room_ids_at_stream_id = (
- RoomStore.__dict__["get_public_room_ids_at_stream_id"]
- )
- get_public_room_ids_at_stream_id_txn = (
- DataStore.get_public_room_ids_at_stream_id_txn.__func__
- )
- get_published_at_stream_id_txn = (
- DataStore.get_published_at_stream_id_txn.__func__
- )
- get_public_room_changes = DataStore.get_public_room_changes.__func__
+ def get_current_public_room_stream_id(self):
+ return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("public_rooms")
- if stream:
- self._public_room_id_gen.advance(int(stream["position"]))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "public_rooms":
+ self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication(result)
+ return super(RoomStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
new file mode 100644
index 0000000000..81c2ea7ee9
--- /dev/null
+++ b/synapse/replication/tcp/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module implements the TCP replication protocol used by synapse to
+communicate between the master process and its workers (when they're enabled).
+
+Further details can be found in docs/tcp_replication.rst
+
+
+Structure of the module:
+ * client.py - the client classes used for workers to connect to master
+ * commands.py - the definitions of all the valid commands
+ * protocol.py - contains both the client and server protocol implementations,
+ these should not be used directly
+ * resource.py - the server classes that accept and handle client connections
+ * streams.py - the definitions of all the valid streams
+
+"""
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
new file mode 100644
index 0000000000..6d2513c4e2
--- /dev/null
+++ b/synapse/replication/tcp/client.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A replication client for use by synapse workers.
+"""
+
+from twisted.internet import reactor, defer
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from .commands import (
+ FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ UserIpCommand,
+)
+from .protocol import ClientReplicationStreamProtocol
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationClientFactory(ReconnectingClientFactory):
+ """Factory for building connections to the master. Will reconnect if the
+ connection is lost.
+
+ Accepts a handler that will be called when new data is available or data
+ is required.
+ """
+ maxDelay = 5 # Try at least once every N seconds
+
+ def __init__(self, hs, client_name, handler):
+ self.client_name = client_name
+ self.handler = handler
+ self.server_name = hs.config.server_name
+ self._clock = hs.get_clock() # As self.clock is defined in super class
+
+ reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+
+ def startedConnecting(self, connector):
+ logger.info("Connecting to replication: %r", connector.getDestination())
+
+ def buildProtocol(self, addr):
+ logger.info("Connected to replication: %r", addr)
+ self.resetDelay()
+ return ClientReplicationStreamProtocol(
+ self.client_name, self.server_name, self._clock, self.handler
+ )
+
+ def clientConnectionLost(self, connector, reason):
+ logger.error("Lost replication conn: %r", reason)
+ ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.error("Failed to connect to replication: %r", reason)
+ ReconnectingClientFactory.clientConnectionFailed(
+ self, connector, reason
+ )
+
+
+class ReplicationClientHandler(object):
+ """A base handler that can be passed to the ReplicationClientFactory.
+
+ By default proxies incoming replication data to the SlaveStore.
+ """
+ def __init__(self, store):
+ self.store = store
+
+ # The current connection. None if we are currently (re)connecting
+ self.connection = None
+
+ # Any pending commands to be sent once a new connection has been
+ # established
+ self.pending_commands = []
+
+ # Map from string -> deferred, to wake up when receiving a SYNC with
+ # the given string.
+ # Used for tests.
+ self.awaiting_syncs = {}
+
+ def start_replication(self, hs):
+ """Helper method to start a replication connection to the remote server
+ using TCP.
+ """
+ client_name = hs.config.worker_name
+ factory = ReplicationClientFactory(hs, client_name, self)
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_port
+ reactor.connectTCP(host, port, factory)
+
+ def on_rdata(self, stream_name, token, rows):
+ """Called when we get new replication data. By default this just pokes
+ the slave store.
+
+ Can be overridden in subclasses to handle more.
+ """
+ logger.info("Received rdata %s -> %s", stream_name, token)
+ self.store.process_replication_rows(stream_name, token, rows)
+
+ def on_position(self, stream_name, token):
+ """Called when we get new position data. By default this just pokes
+ the slave store.
+
+ Can be overridden in subclasses to handle more.
+ """
+ self.store.process_replication_rows(stream_name, token, [])
+
+ def on_sync(self, data):
+ """When we received a SYNC we wake up any deferreds that were waiting
+ for the sync with the given data.
+
+ Used by tests.
+ """
+ d = self.awaiting_syncs.pop(data, None)
+ if d:
+ d.callback(data)
+
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns a dictionary of stream name to token.
+ """
+ args = self.store.stream_positions()
+ user_account_data = args.pop("user_account_data", None)
+ room_account_data = args.pop("room_account_data", None)
+ if user_account_data:
+ args["account_data"] = user_account_data
+ elif room_account_data:
+ args["account_data"] = room_account_data
+ return args
+
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users. (Overridden only by the synchrotrons.)
+ """
+ return []
+
+ def send_command(self, cmd):
+ """Send a command to master (when we get establish a connection if we
+ don't have one already.)
+ """
+ if self.connection:
+ self.connection.send_command(cmd)
+ else:
+ logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ self.pending_commands.append(cmd)
+
+ def send_federation_ack(self, token):
+ """Ack data for the federation stream. This allows the master to drop
+ data stored purely in memory.
+ """
+ self.send_command(FederationAckCommand(token))
+
+ def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ """Poke the master that a user has started/stopped syncing.
+ """
+ self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+
+ def send_remove_pusher(self, app_id, push_key, user_id):
+ """Poke the master to remove a pusher for a user
+ """
+ cmd = RemovePusherCommand(app_id, push_key, user_id)
+ self.send_command(cmd)
+
+ def send_invalidate_cache(self, cache_func, keys):
+ """Poke the master to invalidate a cache.
+ """
+ cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+ self.send_command(cmd)
+
+ def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ """Tell the master that the user made a request.
+ """
+ cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+ self.send_command(cmd)
+
+ def await_sync(self, data):
+ """Returns a deferred that is resolved when we receive a SYNC command
+ with given data.
+
+ Used by tests.
+ """
+ return self.awaiting_syncs.setdefault(data, defer.Deferred())
+
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ self.connection = connection
+ if connection:
+ for cmd in self.pending_commands:
+ connection.send_command(cmd)
+ self.pending_commands = []
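
Since on_rdata/on_position are designed to be overridden, a worker that needs
to do more than update its slaved store might subclass the handler roughly as
follows (an illustrative sketch; the `notifier` object and its `notify` method
are invented here):

    from synapse.replication.tcp.client import ReplicationClientHandler

    class NotifyingReplicationHandler(ReplicationClientHandler):
        def __init__(self, store, notifier):
            super(NotifyingReplicationHandler, self).__init__(store)
            self.notifier = notifier

        def on_rdata(self, stream_name, token, rows):
            # Let the slaved store update its caches and id trackers first.
            super(NotifyingReplicationHandler, self).on_rdata(stream_name, token, rows)
            # Then do any worker-specific work, e.g. wake up listeners.
            self.notifier.notify(stream_name, token)
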
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
new file mode 100644
index 0000000000..12aac3cc6b
--- /dev/null
+++ b/synapse/replication/tcp/commands.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Defines the various valid commands
+
+The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
+allowed to be sent by which side.
+"""
+
+import logging
+import simplejson
+
+
+logger = logging.getLogger(__name__)
+
+_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
+
+
+class Command(object):
+ """The base command class.
+
+ All subclasses must set the NAME variable which equates to the name of the
+ command on the wire.
+
+ A full command line on the wire is constructed from `NAME + " " + to_line()`
+
+ The default implementation creates a command of form `<NAME> <data>`
+ """
+ NAME = None
+
+ def __init__(self, data):
+ self.data = data
+
+ @classmethod
+ def from_line(cls, line):
+ """Deserialises a line from the wire into this command. `line` does not
+ include the command.
+ """
+ return cls(line)
+
+ def to_line(self):
+ """Serialises the comamnd for the wire. Does not include the command
+ prefix.
+ """
+ return self.data
+
+
+class ServerCommand(Command):
+ """Sent by the server on new connection and includes the server_name.
+
+ Format::
+
+ SERVER <server_name>
+ """
+ NAME = "SERVER"
+
+
+class RdataCommand(Command):
+ """Sent by server when a subscribed stream has an update.
+
+ Format::
+
+ RDATA <stream_name> <token> <row_json>
+
+ The `<token>` may either be a numeric stream id OR "batch". The latter case
+ is used to support sending multiple updates with the same stream ID. This
+ is done by sending an RDATA for each row, with all but the last RDATA having
+ a token of "batch" and the last having the final stream ID.
+
+ The client should batch all incoming RDATA with a token of "batch" (per
+ stream_name) until it sees an RDATA with a numeric stream ID.
+
+ `<token>` of "batch" maps to the instance variable `token` being None.
+
+ An example of a batched series of RDATA::
+
+ RDATA presence batch ["@foo:example.com", "online", ...]
+ RDATA presence batch ["@bar:example.com", "online", ...]
+ RDATA presence 59 ["@baz:example.com", "online", ...]
+ """
+ NAME = "RDATA"
+
+ def __init__(self, stream_name, token, row):
+ self.stream_name = stream_name
+ self.token = token
+ self.row = row
+
+ @classmethod
+ def from_line(cls, line):
+ stream_name, token, row_json = line.split(" ", 2)
+ return cls(
+ stream_name,
+ None if token == "batch" else int(token),
+ simplejson.loads(row_json)
+ )
+
+ def to_line(self):
+ return " ".join((
+ self.stream_name,
+ str(self.token) if self.token is not None else "batch",
+ _json_encoder.encode(self.row),
+ ))
+
+
+class PositionCommand(Command):
+ """Sent by the client to tell the client the stream postition without
+ needing to send an RDATA.
+ """
+ NAME = "POSITION"
+
+ def __init__(self, stream_name, token):
+ self.stream_name = stream_name
+ self.token = token
+
+ @classmethod
+ def from_line(cls, line):
+ stream_name, token = line.split(" ", 1)
+ return cls(stream_name, int(token))
+
+ def to_line(self):
+ return " ".join((self.stream_name, str(self.token),))
+
+
+class ErrorCommand(Command):
+ """Sent by either side if there was an ERROR. The data is a string describing
+ the error.
+ """
+ NAME = "ERROR"
+
+
+class PingCommand(Command):
+ """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+ """
+ NAME = "PING"
+
+
+class NameCommand(Command):
+ """Sent by client to inform the server of the client's identity. The data
+ is the name
+ """
+ NAME = "NAME"
+
+
+class ReplicateCommand(Command):
+ """Sent by the client to subscribe to the stream.
+
+ Format::
+
+ REPLICATE <stream_name> <token>
+
+ Where <token> may be either:
+ * a numeric stream_id to stream updates from
+ * "NOW" to stream all subsequent updates.
+
+ The <stream_name> can be "ALL" to subscribe to all known streams, in which
+ case the <token> must be set to "NOW", i.e.::
+
+ REPLICATE ALL NOW
+ """
+ NAME = "REPLICATE"
+
+ def __init__(self, stream_name, token):
+ self.stream_name = stream_name
+ self.token = token
+
+ @classmethod
+ def from_line(cls, line):
+ stream_name, token = line.split(" ", 1)
+ if token in ("NOW", "now"):
+ token = "NOW"
+ else:
+ token = int(token)
+ return cls(stream_name, token)
+
+ def to_line(self):
+ return " ".join((self.stream_name, str(self.token),))
+
+
+class UserSyncCommand(Command):
+ """Sent by the client to inform the server that a user has started or
+ stopped syncing. Used to calculate presence on the master.
+
+ Includes a timestamp of when the last user sync was.
+
+ Format::
+
+ USER_SYNC <user_id> <state> <last_sync_ms>
+
+ Where <state> is either "start" or "end"
+ """
+ NAME = "USER_SYNC"
+
+ def __init__(self, user_id, is_syncing, last_sync_ms):
+ self.user_id = user_id
+ self.is_syncing = is_syncing
+ self.last_sync_ms = last_sync_ms
+
+ @classmethod
+ def from_line(cls, line):
+ user_id, state, last_sync_ms = line.split(" ", 2)
+
+ if state not in ("start", "end"):
+ raise Exception("Invalid USER_SYNC state %r" % (state,))
+
+ return cls(user_id, state == "start", int(last_sync_ms))
+
+ def to_line(self):
+ return " ".join((
+ self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
+ ))
+
+
+class FederationAckCommand(Command):
+ """Sent by the client when it has processed up to a given point in the
+ federation stream. This allows the master to drop in-memory caches of the
+ federation stream.
+
+ This must only be sent from one worker (i.e. the one sending federation)
+
+ Format::
+
+ FEDERATION_ACK <token>
+ """
+ NAME = "FEDERATION_ACK"
+
+ def __init__(self, token):
+ self.token = token
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(int(line))
+
+ def to_line(self):
+ return str(self.token)
+
+
+class SyncCommand(Command):
+ """Used for testing. The client protocol implementation allows waiting
+ on a SYNC command with specified data.
+ """
+ NAME = "SYNC"
+
+
+class RemovePusherCommand(Command):
+ """Sent by the client to request the master remove the given pusher.
+
+ Format::
+
+ REMOVE_PUSHER <app_id> <push_key> <user_id>
+ """
+ NAME = "REMOVE_PUSHER"
+
+ def __init__(self, app_id, push_key, user_id):
+ self.user_id = user_id
+ self.app_id = app_id
+ self.push_key = push_key
+
+ @classmethod
+ def from_line(cls, line):
+ app_id, push_key, user_id = line.split(" ", 2)
+
+ return cls(app_id, push_key, user_id)
+
+ def to_line(self):
+ return " ".join((self.app_id, self.push_key, self.user_id))
+
+
+class InvalidateCacheCommand(Command):
+ """Sent by the client to invalidate an upstream cache.
+
+ THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED EXCEPT FOR THINGS THAT ARE
+ NOT DISASTROUS IF DROPPED ON THE FLOOR.
+
+ Mainly used to invalidate destination retry timing caches.
+
+ Format::
+
+ INVALIDATE_CACHE <cache_func> <keys_json>
+
+ Where <keys_json> is a json list.
+ """
+ NAME = "INVALIDATE_CACHE"
+
+ def __init__(self, cache_func, keys):
+ self.cache_func = cache_func
+ self.keys = keys
+
+ @classmethod
+ def from_line(cls, line):
+ cache_func, keys_json = line.split(" ", 1)
+
+ return cls(cache_func, simplejson.loads(keys_json))
+
+ def to_line(self):
+ return " ".join((
+ self.cache_func, _json_encoder.encode(self.keys),
+ ))
+
+
+class UserIpCommand(Command):
+ """Sent periodically when a worker sees activity from a client.
+
+ Format::
+
+ USER_IP <user_id> <json list of [access_token, ip, user_agent, device_id, last_seen]>
+ """
+ NAME = "USER_IP"
+
+ def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ self.user_id = user_id
+ self.access_token = access_token
+ self.ip = ip
+ self.user_agent = user_agent
+ self.device_id = device_id
+ self.last_seen = last_seen
+
+ @classmethod
+ def from_line(cls, line):
+ user_id, jsn = line.split(" ", 1)
+
+ access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+
+ return cls(
+ user_id, access_token, ip, user_agent, device_id, last_seen
+ )
+
+ def to_line(self):
+ return self.user_id + " " + _json_encoder.encode((
+ self.access_token, self.ip, self.user_agent, self.device_id,
+ self.last_seen,
+ ))
+
+
+# Map of command name to command type.
+COMMAND_MAP = {
+ cmd.NAME: cmd
+ for cmd in (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+ )
+}
+
+# The commands the server is allowed to send
+VALID_SERVER_COMMANDS = (
+ ServerCommand.NAME,
+ RdataCommand.NAME,
+ PositionCommand.NAME,
+ ErrorCommand.NAME,
+ PingCommand.NAME,
+ SyncCommand.NAME,
+)
+
+# The commands the client is allowed to send
+VALID_CLIENT_COMMANDS = (
+ NameCommand.NAME,
+ ReplicateCommand.NAME,
+ PingCommand.NAME,
+ UserSyncCommand.NAME,
+ FederationAckCommand.NAME,
+ RemovePusherCommand.NAME,
+ InvalidateCacheCommand.NAME,
+ UserIpCommand.NAME,
+ ErrorCommand.NAME,
+)
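
As an illustrative round trip (not part of the patch), a raw wire line maps
onto these command classes like so, which is essentially what the protocol
implementation in protocol.py does for every line it receives or sends:

    from synapse.replication.tcp.commands import COMMAND_MAP, RdataCommand

    line = 'RDATA presence 59 ["@baz:example.com", "online"]'
    name, rest = line.split(" ", 1)

    cmd = COMMAND_MAP[name].from_line(rest)
    assert isinstance(cmd, RdataCommand)
    assert cmd.stream_name == "presence" and cmd.token == 59

    # And back again: NAME + " " + to_line() reconstructs the wire format.
    wire = "%s %s" % (cmd.NAME, cmd.to_line())
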
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
new file mode 100644
index 0000000000..d7d38464b2
--- /dev/null
+++ b/synapse/replication/tcp/protocol.py
@@ -0,0 +1,655 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains the implementation of both the client and server
+protocols.
+
+The basic structure of the protocol is line based, where the initial word of
+each line specifies the command. The rest of the line is parsed based on the
+command. For example, the `RDATA` command is defined as::
+
+ RDATA <stream_name> <token> <row_json>
+
+(Note that `<row_json>` may contain spaces, but cannot contain newlines.)
+
+Blank lines are ignored.
+
+# Example
+
+An example interaction is shown below. Each line is prefixed with '>' or '<' to
+indicate which side is sending; these prefixes are *not* included on the wire::
+
+ * connection established *
+ > SERVER localhost:8823
+ > PING 1490197665618
+ < NAME synapse.app.appservice
+ < PING 1490197665618
+ < REPLICATE events 1
+ < REPLICATE backfill 1
+ < REPLICATE caches 1
+ > POSITION events 1
+ > POSITION backfill 1
+ > POSITION caches 1
+ > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
+ > RDATA events 14 ["$149019767112vOHxz:localhost:8823",
+ "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
+ < PING 1490197675618
+ > ERROR server stopping
+ * connection closed by server *
+"""
+
+from twisted.internet import defer
+from twisted.protocols.basic import LineOnlyReceiver
+from twisted.python.failure import Failure
+
+from .commands import (
+ COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
+ ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
+ NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
+)
+from .streams import STREAMS_MAP
+
+from synapse.util.stringutils import random_string
+from synapse.metrics.metric import CounterMetric
+
+import logging
+import synapse.metrics
+import struct
+import fcntl
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+connection_close_counter = metrics.register_counter(
+ "close_reason", labels=["reason_type"],
+)
+
+
+# A list of all connected protocols. This allows us to send metrics about the
+# connections.
+connected_connections = []
+
+
+logger = logging.getLogger(__name__)
+
+
+PING_TIME = 5000
+PING_TIMEOUT_MULTIPLIER = 5
+PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
+
+
+class ConnectionStates(object):
+ CONNECTING = "connecting"
+ ESTABLISHED = "established"
+ PAUSED = "paused"
+ CLOSED = "closed"
+
+
+class BaseReplicationStreamProtocol(LineOnlyReceiver):
+ """Base replication protocol shared between client and server.
+
+ Reads lines (ignoring blank ones) and parses them into command classes,
+ asserting that they are valid for the given direction, i.e. server commands
+ are only sent by the server.
+
+ On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
+ command.
+
+ It also sends `PING` periodically, and correctly times out remote connections
+ (if they send a `PING` command)
+ """
+ delimiter = b'\n'
+
+ VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
+ VALID_OUTBOUND_COMMANDS = [] # Valid commands we can send
+
+ max_line_buffer = 10000
+
+ def __init__(self, clock):
+ self.clock = clock
+
+ self.last_received_command = self.clock.time_msec()
+ self.last_sent_command = 0
+ self.time_we_closed = None # When we requested the connection be closed
+
+ self.received_ping = False # Have we received a ping from the other side
+
+ self.state = ConnectionStates.CONNECTING
+
+ self.name = "anon" # The name sent by a client.
+ self.conn_id = random_string(5) # To dedupe in case of name clashes.
+
+ # List of pending commands to send once we've established the connection
+ self.pending_commands = []
+
+ # The LoopingCall for sending pings.
+ self._send_ping_loop = None
+
+ self.inbound_commands_counter = CounterMetric(
+ "inbound_commands", labels=["command"],
+ )
+ self.outbound_commands_counter = CounterMetric(
+ "outbound_commands", labels=["command"],
+ )
+
+ def connectionMade(self):
+ logger.info("[%s] Connection established", self.id())
+
+ self.state = ConnectionStates.ESTABLISHED
+
+ connected_connections.append(self) # Register connection for metrics
+
+ self.transport.registerProducer(self, True) # For the *Producing callbacks
+
+ self._send_pending_commands()
+
+ # Starts sending pings
+ self._send_ping_loop = self.clock.looping_call(self.send_ping, 5000)
+
+ # Always send the initial PING so that the other side knows that they
+ # can time us out.
+ self.send_command(PingCommand(self.clock.time_msec()))
+
+ def send_ping(self):
+ """Periodically sends a ping and checks if we should close the connection
+ due to the other side timing out.
+ """
+ now = self.clock.time_msec()
+
+ if self.time_we_closed:
+ if now - self.time_we_closed > PING_TIMEOUT_MS:
+ logger.info(
+ "[%s] Failed to close connection gracefully, aborting", self.id()
+ )
+ self.transport.abortConnection()
+ else:
+ if now - self.last_sent_command >= PING_TIME:
+ self.send_command(PingCommand(now))
+
+ if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+ logger.info(
+ "[%s] Connection hasn't received command in %r ms. Closing.",
+ self.id(), now - self.last_received_command
+ )
+ self.send_error("ping timeout")
+
+ def lineReceived(self, line):
+ """Called when we've received a line
+ """
+ if line.strip() == "":
+ # Ignore blank lines
+ return
+
+ line = line.decode("utf-8")
+ cmd_name, rest_of_line = line.split(" ", 1)
+
+ if cmd_name not in self.VALID_INBOUND_COMMANDS:
+ logger.error("[%s] invalid command %s", self.id(), cmd_name)
+ self.send_error("invalid command: %s", cmd_name)
+ return
+
+ self.last_received_command = self.clock.time_msec()
+
+ self.inbound_commands_counter.inc(cmd_name)
+
+ cmd_cls = COMMAND_MAP[cmd_name]
+ try:
+ cmd = cmd_cls.from_line(rest_of_line)
+ except Exception as e:
+ logger.exception(
+ "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
+ )
+ self.send_error(
+ "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
+ )
+ return
+
+ # Now let's try to call the on_<CMD_NAME> function
+ try:
+ getattr(self, "on_%s" % (cmd_name,))(cmd)
+ except Exception:
+ logger.exception("[%s] Failed to handle line: %r", self.id(), line)
+
+ def close(self):
+ logger.warn("[%s] Closing connection", self.id())
+ self.time_we_closed = self.clock.time_msec()
+ self.transport.loseConnection()
+ self.on_connection_closed()
+
+ def send_error(self, error_string, *args):
+ """Send an error to remote and close the connection.
+ """
+ self.send_command(ErrorCommand(error_string % args))
+ self.close()
+
+ def send_command(self, cmd, do_buffer=True):
+ """Send a command if connection has been established.
+
+ Args:
+ cmd (Command)
+ do_buffer (bool): Whether to buffer the message or always attempt
+ to send the command. This is mostly used to send an error
+ message if we're about to close the connection due to our buffers
+ becoming full.
+ """
+ if self.state == ConnectionStates.CLOSED:
+ logger.debug("[%s] Not sending, connection closed", self.id())
+ return
+
+ if do_buffer and self.state != ConnectionStates.ESTABLISHED:
+ self._queue_command(cmd)
+ return
+
+ self.outbound_commands_counter.inc(cmd.NAME)
+
+ string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+ if "\n" in string:
+ raise Exception("Unexpected newline in command: %r", string)
+
+ self.sendLine(string.encode("utf-8"))
+
+ self.last_sent_command = self.clock.time_msec()
+
+ def _queue_command(self, cmd):
+ """Queue the command until the connection is ready to write to again.
+ """
+ logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ self.pending_commands.append(cmd)
+
+ if len(self.pending_commands) > self.max_line_buffer:
+ # The other side is failing to keep up and our buffers are becoming
+ # full, so let's close the connection.
+ # XXX: should we squawk more loudly?
+ logger.error("[%s] Remote failed to keep up", self.id())
+ self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
+ self.close()
+
+ def _send_pending_commands(self):
+ """Send any queued commandes
+ """
+ pending = self.pending_commands
+ self.pending_commands = []
+ for cmd in pending:
+ self.send_command(cmd)
+
+ def on_PING(self, line):
+ self.received_ping = True
+
+ def on_ERROR(self, cmd):
+ logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
+
+ def pauseProducing(self):
+ """This is called when both the kernel send buffer and the twisted
+ tcp connection send buffers have become full.
+
+ We don't actually have any control over those sizes, so we buffer some
+ commands ourselves before knifing the connection due to the remote
+ failing to keep up.
+ """
+ logger.info("[%s] Pause producing", self.id())
+ self.state = ConnectionStates.PAUSED
+
+ def resumeProducing(self):
+ """The remote has caught up after we started buffering!
+ """
+ logger.info("[%s] Resume producing", self.id())
+ self.state = ConnectionStates.ESTABLISHED
+ self._send_pending_commands()
+
+ def stopProducing(self):
+ """We're never going to send any more data (normally because either
+ we or the remote has closed the connection)
+ """
+ logger.info("[%s] Stop producing", self.id())
+ self.on_connection_closed()
+
+ def connectionLost(self, reason):
+ logger.info("[%s] Replication connection closed: %r", self.id(), reason)
+ if isinstance(reason, Failure):
+ connection_close_counter.inc(reason.type.__name__)
+ else:
+ connection_close_counter.inc(reason.__class__.__name__)
+
+ try:
+ # Remove us from list of connections to be monitored
+ connected_connections.remove(self)
+ except ValueError:
+ pass
+
+ # Stop the looping call sending pings.
+ if self._send_ping_loop and self._send_ping_loop.running:
+ self._send_ping_loop.stop()
+
+ self.on_connection_closed()
+
+ def on_connection_closed(self):
+ logger.info("[%s] Connection was closed", self.id())
+
+ self.state = ConnectionStates.CLOSED
+ self.pending_commands = []
+
+ if self.transport:
+ self.transport.unregisterProducer()
+
+ def __str__(self):
+ return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
+ self.name, self.conn_id, self.addr,
+ )
+
+ def id(self):
+ return "%s-%s" % (self.name, self.conn_id)
+
+
+class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
+ VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
+ VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
+
+ def __init__(self, server_name, clock, streamer, addr):
+ BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
+
+ self.server_name = server_name
+ self.streamer = streamer
+ self.addr = addr
+
+ # The streams the client has subscribed to and is up to date with
+ self.replication_streams = set()
+
+ # The streams the client is currently subscribing to.
+ self.connecting_streams = set()
+
+ # Map from stream name to list of updates to send once we've finished
+ # subscribing the client to the stream.
+ self.pending_rdata = {}
+
+ def connectionMade(self):
+ self.send_command(ServerCommand(self.server_name))
+ BaseReplicationStreamProtocol.connectionMade(self)
+ self.streamer.new_connection(self)
+
+ def on_NAME(self, cmd):
+ logger.info("[%s] Renamed to %r", self.id(), cmd.data)
+ self.name = cmd.data
+
+ def on_USER_SYNC(self, cmd):
+ self.streamer.on_user_sync(
+ self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+ )
+
+ def on_REPLICATE(self, cmd):
+ stream_name = cmd.stream_name
+ token = cmd.token
+
+ if stream_name == "ALL":
+ # Subscribe to all streams we're publishing to.
+ for stream in self.streamer.streams_by_name.iterkeys():
+ self.subscribe_to_stream(stream, token)
+ else:
+ self.subscribe_to_stream(stream_name, token)
+
+ def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
+
+ def on_REMOVE_PUSHER(self, cmd):
+ self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+
+ def on_INVALIDATE_CACHE(self, cmd):
+ self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+
+ def on_USER_IP(self, cmd):
+ self.streamer.on_user_ip(
+ cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+ cmd.last_seen,
+ )
+
+ @defer.inlineCallbacks
+ def subscribe_to_stream(self, stream_name, token):
+ """Subscribe the remote to a streams.
+
+ This involves checking if they've missed anything and sending those
+ updates down if they have. During that time new updates for the stream
+ are queued and sent once we've sent down any missed updates.
+ """
+ self.replication_streams.discard(stream_name)
+ self.connecting_streams.add(stream_name)
+
+ try:
+ # Get missing updates
+ updates, current_token = yield self.streamer.get_stream_updates(
+ stream_name, token,
+ )
+
+ # Send all the missing updates
+ for update in updates:
+ token, row = update[0], update[1]
+ self.send_command(RdataCommand(stream_name, token, row))
+
+ # We send a POSITION command to ensure that they have an up to
+ # date token (especially useful if we didn't send any updates
+ # above)
+ self.send_command(PositionCommand(stream_name, current_token))
+
+ # Now we can send any updates that came in while we were subscribing
+ pending_rdata = self.pending_rdata.pop(stream_name, [])
+ for token, update in pending_rdata:
+ # Only send updates newer than the current token
+ if token > current_token:
+ self.send_command(RdataCommand(stream_name, token, update))
+
+ # They're now fully subscribed
+ self.replication_streams.add(stream_name)
+ except Exception as e:
+ logger.exception("[%s] Failed to handle REPLICATE command", self.id())
+ self.send_error("failed to handle replicate: %r", e)
+ finally:
+ self.connecting_streams.discard(stream_name)
+
+ def stream_update(self, stream_name, token, data):
+ """Called when a new update is available to stream to clients.
+
+ We need to check if the client is interested in the stream or not
+ """
+ if stream_name in self.replication_streams:
+ # The client is subscribed to the stream
+ self.send_command(RdataCommand(stream_name, token, data))
+ elif stream_name in self.connecting_streams:
+ # The client is being subscribed to the stream
+ logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
+ self.pending_rdata.setdefault(stream_name, []).append((token, data))
+ else:
+ # The client isn't subscribed
+ logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+
+ def send_sync(self, data):
+ self.send_command(SyncCommand(data))
+
+ def on_connection_closed(self):
+ BaseReplicationStreamProtocol.on_connection_closed(self)
+ self.streamer.lost_connection(self)
+
+
+class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
+ VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
+ VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
+
+ def __init__(self, client_name, server_name, clock, handler):
+ BaseReplicationStreamProtocol.__init__(self, clock)
+
+ self.client_name = client_name
+ self.server_name = server_name
+ self.handler = handler
+
+ # Map of stream to batched updates. See RdataCommand for info on how
+ # batching works.
+ self.pending_batches = {}
+
+ def connectionMade(self):
+ self.send_command(NameCommand(self.client_name))
+ BaseReplicationStreamProtocol.connectionMade(self)
+
+ # Once we've connected subscribe to the necessary streams
+ for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+ self.replicate(stream_name, token)
+
+ # Tell the server if we have any users currently syncing (should only
+ # happen on synchrotrons)
+ currently_syncing = self.handler.get_currently_syncing_users()
+ now = self.clock.time_msec()
+ for user_id in currently_syncing:
+ self.send_command(UserSyncCommand(user_id, True, now))
+
+ # We've now finished connecting, so inform the client handler
+ self.handler.update_connection(self)
+
+ def on_SERVER(self, cmd):
+ if cmd.data != self.server_name:
+ logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
+ self.send_error("Wrong remote")
+
+ def on_RDATA(self, cmd):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.inc(stream_name)
+
+ try:
+ row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
+ except Exception:
+ logger.exception(
+ "[%s] Failed to parse RDATA: %r %r",
+ self.id(), stream_name, cmd.row
+ )
+ raise
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream. Batch
+ # until we get an update for the stream with a non-None token
+ self.pending_batches.setdefault(stream_name, []).append(row)
+ else:
+ # Check if this is the last of a batch of updates
+ rows = self.pending_batches.pop(stream_name, [])
+ rows.append(row)
+
+ self.handler.on_rdata(stream_name, cmd.token, rows)
+
+ def on_POSITION(self, cmd):
+ self.handler.on_position(cmd.stream_name, cmd.token)
+
+ def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
+
+ def replicate(self, stream_name, token):
+ """Send the subscription request to the server
+ """
+ if stream_name not in STREAMS_MAP:
+ raise Exception("Invalid stream name %r" % (stream_name,))
+
+ logger.info(
+ "[%s] Subscribing to replication stream: %r from %r",
+ self.id(), stream_name, token
+ )
+
+ self.send_command(ReplicateCommand(stream_name, token))
+
+ def on_connection_closed(self):
+ BaseReplicationStreamProtocol.on_connection_closed(self)
+ self.handler.update_connection(None)
+
+
+# The following simply registers metrics for the replication connections
+
+metrics.register_callback(
+ "pending_commands",
+ lambda: {
+ (p.name, p.conn_id): len(p.pending_commands)
+ for p in connected_connections
+ },
+ labels=["name", "conn_id"],
+)
+
+
+def transport_buffer_size(protocol):
+ if protocol.transport:
+ size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
+ return size
+ return 0
+
+
+metrics.register_callback(
+ "transport_send_buffer",
+ lambda: {
+ (p.name, p.conn_id): transport_buffer_size(p)
+ for p in connected_connections
+ },
+ labels=["name", "conn_id"],
+)
+
+
+def transport_kernel_read_buffer_size(protocol, read=True):
+ SIOCINQ = 0x541B
+ SIOCOUTQ = 0x5411
+
+ if protocol.transport:
+ fileno = protocol.transport.getHandle().fileno()
+ if read:
+ op = SIOCINQ
+ else:
+ op = SIOCOUTQ
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+ return size
+ return 0
+
+
+metrics.register_callback(
+ "transport_kernel_send_buffer",
+ lambda: {
+ (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
+ for p in connected_connections
+ },
+ labels=["name", "conn_id"],
+)
+
+
+metrics.register_callback(
+ "transport_kernel_read_buffer",
+ lambda: {
+ (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
+ for p in connected_connections
+ },
+ labels=["name", "conn_id"],
+)
+
+
+metrics.register_callback(
+ "inbound_commands",
+ lambda: {
+ (k[0], p.name, p.conn_id): count
+ for p in connected_connections
+ for k, count in p.inbound_commands_counter.counts.iteritems()
+ },
+ labels=["command", "name", "conn_id"],
+)
+
+metrics.register_callback(
+ "outbound_commands",
+ lambda: {
+ (k[0], p.name, p.conn_id): count
+ for p in connected_connections
+ for k, count in p.outbound_commands_counter.counts.iteritems()
+ },
+ labels=["command", "name", "conn_id"],
+)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+ "inbound_rdata_count",
+ labels=["stream_name"],
+)
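
The on_<COMMAND> dispatch used by both protocol classes above boils down to the
pattern sketched here (illustrative only; ToyDispatcher is invented):

    from synapse.replication.tcp.commands import COMMAND_MAP

    class ToyDispatcher(object):
        def handle_line(self, line):
            # Same convention as lineReceived: "NAME rest" -> on_NAME(cmd).
            name, rest = line.split(" ", 1)
            cmd = COMMAND_MAP[name].from_line(rest)
            getattr(self, "on_%s" % (name,))(cmd)

        def on_PING(self, cmd):
            print("got PING with data %r" % (cmd.data,))

    ToyDispatcher().handle_line("PING 1490197665618")
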
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
new file mode 100644
index 0000000000..a41af4fd6c
--- /dev/null
+++ b/synapse/replication/tcp/resource.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The server side of the replication stream.
+"""
+
+from twisted.internet import defer, reactor
+from twisted.internet.protocol import Factory
+
+from .streams import STREAMS_MAP, FederationStream
+from .protocol import ServerReplicationStreamProtocol
+
+from synapse.util.metrics import Measure, measure_func
+
+import logging
+import synapse.metrics
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+stream_updates_counter = metrics.register_counter(
+ "stream_updates", labels=["stream_name"]
+)
+user_sync_counter = metrics.register_counter("user_sync")
+federation_ack_counter = metrics.register_counter("federation_ack")
+remove_pusher_counter = metrics.register_counter("remove_pusher")
+invalidate_cache_counter = metrics.register_counter("invalidate_cache")
+user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationStreamProtocolFactory(Factory):
+ """Factory for new replication connections.
+ """
+ def __init__(self, hs):
+ self.streamer = ReplicationStreamer(hs)
+ self.clock = hs.get_clock()
+ self.server_name = hs.config.server_name
+
+ def buildProtocol(self, addr):
+ return ServerReplicationStreamProtocol(
+ self.server_name,
+ self.clock,
+ self.streamer,
+ addr
+ )
+
+
+class ReplicationStreamer(object):
+ """Handles replication connections.
+
+ This needs to be poked when new replication data may be available. When new
+ data is available it will propagate to all connected clients.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+
+ # Current connections.
+ self.connections = []
+
+ metrics.register_callback("total_connections", lambda: len(self.connections))
+
+ # List of streams that clients can subscribe to.
+ # We only support the federation stream if federation sending has been
+ # disabled on the master.
+ self.streams = [
+ stream(hs) for stream in STREAMS_MAP.itervalues()
+ if stream != FederationStream or not hs.config.send_federation
+ ]
+
+ self.streams_by_name = {stream.NAME: stream for stream in self.streams}
+
+ metrics.register_callback(
+ "connections_per_stream",
+ lambda: {
+ (stream_name,): len([
+ conn for conn in self.connections
+ if stream_name in conn.replication_streams
+ ])
+ for stream_name in self.streams_by_name
+ },
+ labels=["stream_name"],
+ )
+
+ self.federation_sender = None
+ if not hs.config.send_federation:
+ self.federation_sender = hs.get_federation_sender()
+
+ self.notifier.add_replication_callback(self.on_notifier_poke)
+
+ # Keeps track of whether we are currently checking for updates
+ self.is_looping = False
+ self.pending_updates = False
+
+ reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+
+ def on_shutdown(self):
+ # close all connections on shutdown
+ for conn in self.connections:
+ conn.send_error("server shutting down")
+
+ @defer.inlineCallbacks
+ def on_notifier_poke(self):
+ """Checks if there is actually any new data and sends it to the
+ connections if there are.
+
+ This should get called each time new data is available, even if it
+ is currently being executed, so that nothing gets missed
+ """
+ if not self.connections:
+ # Don't bother if nothing is listening. We still need to advance
+ # the stream tokens otherwise they'll fall behind forever
+ for stream in self.streams:
+ stream.discard_updates_and_advance()
+ return
+
+ # If we're in the process of checking for new updates, mark that fact
+ # and return
+ if self.is_looping:
+ logger.debug("Noitifier poke loop already running")
+ self.pending_updates = True
+ return
+
+ self.pending_updates = True
+ self.is_looping = True
+
+ try:
+ # Keep looping while there have been pokes about potential updates.
+ # This protects against the race where a stream we already checked
+ # gets an update while we're handling other streams.
+ while self.pending_updates:
+ self.pending_updates = False
+
+ with Measure(self.clock, "repl.stream.get_updates"):
+ # First we tell the streams that they should update their
+ # current tokens.
+ for stream in self.streams:
+ stream.advance_current_token()
+
+ for stream in self.streams:
+ if stream.last_token == stream.upto_token:
+ continue
+
+ logger.debug(
+ "Getting stream: %s: %s -> %s",
+ stream.NAME, stream.last_token, stream.upto_token
+ )
+ try:
+ updates, current_token = yield stream.get_updates()
+ except Exception:
+ logger.info("Failed to handle stream %s", stream.NAME)
+ raise
+
+ logger.debug(
+ "Sending %d updates to %d connections",
+ len(updates), len(self.connections),
+ )
+
+ if updates:
+ logger.info(
+ "Streaming: %s -> %s", stream.NAME, updates[-1][0]
+ )
+ stream_updates_counter.inc_by(len(updates), stream.NAME)
+
+ # Some streams return multiple rows with the same stream IDs,
+ # we need to make sure they get sent out in batches. We do
+ # this by setting the current token to all but the last of
+ # a series of updates with the same token to have a None
+ # token. See RdataCommand for more details.
+ batched_updates = _batch_updates(updates)
+
+ for conn in self.connections:
+ for token, row in batched_updates:
+ try:
+ conn.stream_update(stream.NAME, token, row)
+ except Exception:
+ logger.exception("Failed to replicate")
+
+ logger.debug("No more pending updates, breaking poke loop")
+ finally:
+ self.pending_updates = False
+ self.is_looping = False
+
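The pending_updates / is_looping pair above implements a coalescing guard: pokes that arrive while the loop is running simply flag another pass rather than starting a second loop. A stripped-down sketch of the same pattern (not the real streamer):

    class PokeLoop(object):
        # Coalesces overlapping pokes into a single processing loop.
        def __init__(self, process_once):
            self.process_once = process_once
            self.is_looping = False
            self.pending_updates = False

        def poke(self):
            if self.is_looping:
                # A loop is already running; ask it for one more pass.
                self.pending_updates = True
                return
            self.pending_updates = True
            self.is_looping = True
            try:
                while self.pending_updates:
                    self.pending_updates = False
                    self.process_once()
            finally:
                self.pending_updates = False
                self.is_looping = False

    passes = []
    loop = PokeLoop(lambda: passes.append(1))
    loop.poke()
    assert len(passes) == 1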
+ @measure_func("repl.get_stream_updates")
+ def get_stream_updates(self, stream_name, token):
+ """For a given stream get all updates since token. This is called when
+ a client first subscribes to a stream.
+ """
+ stream = self.streams_by_name.get(stream_name, None)
+ if not stream:
+ raise Exception("unknown stream %s", stream_name)
+
+ return stream.get_updates_since(token)
+
+ @measure_func("repl.federation_ack")
+ def federation_ack(self, token):
+ """We've received an ack for federation stream from a client.
+ """
+ federation_ack_counter.inc()
+ if self.federation_sender:
+ self.federation_sender.federation_ack(token)
+
+ @measure_func("repl.on_user_sync")
+ @defer.inlineCallbacks
+ def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ """A client has started/stopped syncing on a worker.
+ """
+ user_sync_counter.inc()
+ yield self.presence_handler.update_external_syncs_row(
+ conn_id, user_id, is_syncing, last_sync_ms,
+ )
+
+ @measure_func("repl.on_remove_pusher")
+ @defer.inlineCallbacks
+ def on_remove_pusher(self, app_id, push_key, user_id):
+ """A client has asked us to remove a pusher
+ """
+ remove_pusher_counter.inc()
+ yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id=app_id, pushkey=push_key, user_id=user_id
+ )
+
+ self.notifier.on_new_replication_data()
+
+ @measure_func("repl.on_invalidate_cache")
+ def on_invalidate_cache(self, cache_func, keys):
+ """The client has asked us to invalidate a cache
+ """
+ invalidate_cache_counter.inc()
+ getattr(self.store, cache_func).invalidate(tuple(keys))
+
+ @measure_func("repl.on_user_ip")
+ @defer.inlineCallbacks
+ def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ """The client saw a user request
+ """
+ user_ip_cache_counter.inc()
+ yield self.store.insert_client_ip(
+ user_id, access_token, ip, user_agent, device_id, last_seen,
+ )
+
+ def send_sync_to_all_connections(self, data):
+ """Sends a SYNC command to all clients.
+
+ Used in tests.
+ """
+ for conn in self.connections:
+ conn.send_sync(data)
+
+ def new_connection(self, connection):
+ """A new client connection has been established
+ """
+ self.connections.append(connection)
+
+ def lost_connection(self, connection):
+ """A client connection has been lost
+ """
+ try:
+ self.connections.remove(connection)
+ except ValueError:
+ pass
+
+ # We need to tell the presence handler that the connection has been
+ # lost so that it can handle any ongoing syncs on that connection.
+ self.presence_handler.update_external_syncs_clear(connection.conn_id)
+
+
+def _batch_updates(updates):
+ """Takes a list of updates of form [(token, row)] and sets the token to
+ None for all rows where the next row has the same token. This is used to
+ implement batching.
+
+ For example:
+
+ [(1, _), (1, _), (2, _), (3, _), (3, _)]
+
+ becomes:
+
+ [(None, _), (1, _), (2, _), (None, _), (3, _)]
+ """
+ if not updates:
+ return []
+
+ new_updates = []
+ for i, update in enumerate(updates[:-1]):
+ if update[0] == updates[i + 1][0]:
+ new_updates.append((None, update[1]))
+ else:
+ new_updates.append(update)
+
+ new_updates.append(updates[-1])
+ return new_updates
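A worked example of _batch_updates matching the docstring above; the row payloads are placeholders:

    updates = [(1, "a"), (1, "b"), (2, "c"), (3, "d"), (3, "e")]
    assert _batch_updates(updates) == [
        (None, "a"),  # same token as the following row, so batched
        (1, "b"),
        (2, "c"),
        (None, "d"),
        (3, "e"),
    ]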
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
new file mode 100644
index 0000000000..4c60bf79f9
--- /dev/null
+++ b/synapse/replication/tcp/streams.py
@@ -0,0 +1,506 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines all the valid streams that clients can subscribe to, and the format
+of the rows returned by each stream.
+
+Each stream is defined by the following information:
+
+ stream name: The name of the stream
+ row type: The type that is used to serialise/deserialise the row
+ current_token: The function that returns the current token for the stream
+ update_function: The function that returns a list of updates between two tokens
+"""
+
+from twisted.internet import defer
+from collections import namedtuple
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+MAX_EVENTS_BEHIND = 10000
+
+
+EventStreamRow = namedtuple("EventStreamRow", (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+))
+BackfillStreamRow = namedtuple("BackfillStreamRow", (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+))
+PresenceStreamRow = namedtuple("PresenceStreamRow", (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+))
+TypingStreamRow = namedtuple("TypingStreamRow", (
+ "room_id", # str
+ "user_ids", # list(str)
+))
+ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+))
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
+ "user_id", # str
+))
+PushersStreamRow = namedtuple("PushersStreamRow", (
+ "user_id", # str
+ "app_id", # str
+ "pushkey", # str
+ "deleted", # bool
+))
+CachesStreamRow = namedtuple("CachesStreamRow", (
+ "cache_func", # str
+ "keys", # list(str)
+ "invalidation_ts", # int
+))
+PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+))
+DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
+ "user_id", # str
+ "destination", # str
+))
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
+ "entity", # str
+))
+FederationStreamRow = namedtuple("FederationStreamRow", (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+))
+TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
+ "user_id", # str
+ "room_id", # str
+ "data", # dict
+))
+AccountDataStreamRow = namedtuple("AccountDataStream", (
+ "user_id", # str
+ "room_id", # str
+ "data_type", # str
+ "data", # dict
+))
+CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
+ "room_id", # str
+ "type", # str
+ "state_key", # str
+ "event_id", # str, optional
+))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+ "group_id", # str
+ "user_id", # str
+ "type", # str
+ "content", # dict
+))
+
+
+class Stream(object):
+ """Base class for the streams.
+
+ Provides a `get_updates()` function that returns new updates since the last
+ time it was called up until the point `advance_current_token` was called.
+ """
+ NAME = None # The name of the stream
+ ROW_TYPE = None # The type of the row
+ _LIMITED = True # Whether the update function takes a limit
+
+ def __init__(self, hs):
+ # The token from which we last asked for updates
+ self.last_token = self.current_token()
+
+ # The token that we will get updates up to
+ self.upto_token = self.current_token()
+
+ def advance_current_token(self):
+ """Updates `upto_token` to "now", which updates up until which point
+ get_updates[_since] will fetch rows till.
+ """
+ self.upto_token = self.current_token()
+
+ def discard_updates_and_advance(self):
+ """Called when the stream should advance but the updates would be discarded,
+ e.g. when there are no currently connected workers.
+ """
+ self.upto_token = self.current_token()
+ self.last_token = self.upto_token
+
+ @defer.inlineCallbacks
+ def get_updates(self):
+ """Gets all updates since the last time this function was called (or
+ since the stream was constructed if it hadn't been called before),
+ until the `upto_token`
+
+ Returns:
+ (list(ROW_TYPE), int): list of updates plus the token used as an
+ upper bound of the updates (i.e. the "current token")
+ """
+ updates, current_token = yield self.get_updates_since(self.last_token)
+ self.last_token = current_token
+
+ defer.returnValue((updates, current_token))
+
+ @defer.inlineCallbacks
+ def get_updates_since(self, from_token):
+ """Like get_updates except allows specifying from when we should
+ stream updates
+
+ Returns:
+ (list(ROW_TYPE), int): list of updates plus the token used as an
+ upper bound of the updates (i.e. the "current token")
+ """
+ if from_token in ("NOW", "now"):
+ defer.returnValue(([], self.upto_token))
+
+ current_token = self.upto_token
+
+ from_token = int(from_token)
+
+ if from_token == current_token:
+ defer.returnValue(([], current_token))
+
+ if self._LIMITED:
+ rows = yield self.update_function(
+ from_token, current_token,
+ limit=MAX_EVENTS_BEHIND + 1,
+ )
+
+ if len(rows) >= MAX_EVENTS_BEHIND:
+ raise Exception("stream %s has fallen behined" % (self.NAME))
+ else:
+ rows = yield self.update_function(
+ from_token, current_token,
+ )
+
+ updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows]
+
+ defer.returnValue((updates, current_token))
+
+ def current_token(self):
+ """Gets the current token of the underlying streams. Should be provided
+ by the sub classes
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
+ def update_function(self, from_token, current_token, limit=None):
+ """Get updates between from_token and to_token. If Stream._LIMITED is
+ True then limit is provided, otherwise it's not.
+
+ Returns:
+ Deferred(list(tuple)): the first entry in the tuple is the token for
+ that update, and the rest of the tuple gets used to construct
+ a ``ROW_TYPE`` instance
+ """
+ raise NotImplementedError()
+
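To make the Stream contract concrete, a toy in-memory stream (purely illustrative, not one of the streams registered below): current_token reports the latest position, and update_function returns (token, *row fields) tuples which the base class wraps in ROW_TYPE.

    from collections import namedtuple
    from twisted.internet import defer

    CounterStreamRow = namedtuple("CounterStreamRow", ("value",))

    class CounterStream(Stream):
        NAME = "counter"            # hypothetical stream name
        ROW_TYPE = CounterStreamRow
        _LIMITED = False

        def __init__(self, hs):
            self._rows = []         # list of (token, value)
            self._current = 0
            super(CounterStream, self).__init__(hs)

        def push(self, value):
            self._current += 1
            self._rows.append((self._current, value))

        def current_token(self):
            return self._current

        def update_function(self, from_token, current_token):
            # Each tuple is (token, *ROW_TYPE fields).
            return defer.succeed([
                (token, value)
                for token, value in self._rows
                if from_token < token <= current_token
            ])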
+
+class EventsStream(Stream):
+ """We received a new event, or an event went from being an outlier to not
+ """
+ NAME = "events"
+ ROW_TYPE = EventStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+ self.current_token = store.get_current_events_token
+ self.update_function = store.get_all_new_forward_event_rows
+
+ super(EventsStream, self).__init__(hs)
+
+
+class BackfillStream(Stream):
+ """We fetched some old events and either we had never seen that event before
+ or it went from being an outlier to not.
+ """
+ NAME = "backfill"
+ ROW_TYPE = BackfillStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+ self.current_token = store.get_current_backfill_token
+ self.update_function = store.get_all_new_backfill_event_rows
+
+ super(BackfillStream, self).__init__(hs)
+
+
+class PresenceStream(Stream):
+ NAME = "presence"
+ _LIMITED = False
+ ROW_TYPE = PresenceStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+ presence_handler = hs.get_presence_handler()
+
+ self.current_token = store.get_current_presence_token
+ self.update_function = presence_handler.get_all_presence_updates
+
+ super(PresenceStream, self).__init__(hs)
+
+
+class TypingStream(Stream):
+ NAME = "typing"
+ _LIMITED = False
+ ROW_TYPE = TypingStreamRow
+
+ def __init__(self, hs):
+ typing_handler = hs.get_typing_handler()
+
+ self.current_token = typing_handler.get_current_token
+ self.update_function = typing_handler.get_all_typing_updates
+
+ super(TypingStream, self).__init__(hs)
+
+
+class ReceiptsStream(Stream):
+ NAME = "receipts"
+ ROW_TYPE = ReceiptsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_max_receipt_stream_id
+ self.update_function = store.get_all_updated_receipts
+
+ super(ReceiptsStream, self).__init__(hs)
+
+
+class PushRulesStream(Stream):
+ """A user has changed their push rules
+ """
+ NAME = "push_rules"
+ ROW_TYPE = PushRulesStreamRow
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(PushRulesStream, self).__init__(hs)
+
+ def current_token(self):
+ push_rules_token, _ = self.store.get_push_rules_stream_token()
+ return push_rules_token
+
+ @defer.inlineCallbacks
+ def update_function(self, from_token, to_token, limit):
+ rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+ defer.returnValue([(row[0], row[2]) for row in rows])
+
+
+class PushersStream(Stream):
+ """A user has added/changed/removed a pusher
+ """
+ NAME = "pushers"
+ ROW_TYPE = PushersStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_pushers_stream_token
+ self.update_function = store.get_all_updated_pushers_rows
+
+ super(PushersStream, self).__init__(hs)
+
+
+class CachesStream(Stream):
+ """A cache was invalidated on the master and no other stream would invalidate
+ the cache on the workers
+ """
+ NAME = "caches"
+ ROW_TYPE = CachesStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_cache_stream_token
+ self.update_function = store.get_all_updated_caches
+
+ super(CachesStream, self).__init__(hs)
+
+
+class PublicRoomsStream(Stream):
+ """The public rooms list changed
+ """
+ NAME = "public_rooms"
+ ROW_TYPE = PublicRoomsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_current_public_room_stream_id
+ self.update_function = store.get_all_new_public_rooms
+
+ super(PublicRoomsStream, self).__init__(hs)
+
+
+class DeviceListsStream(Stream):
+ """Someone added/changed/removed a device
+ """
+ NAME = "device_lists"
+ _LIMITED = False
+ ROW_TYPE = DeviceListsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_device_list_changes_for_remotes
+
+ super(DeviceListsStream, self).__init__(hs)
+
+
+class ToDeviceStream(Stream):
+ """New to_device messages for a client
+ """
+ NAME = "to_device"
+ ROW_TYPE = ToDeviceStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_to_device_stream_token
+ self.update_function = store.get_all_new_device_messages
+
+ super(ToDeviceStream, self).__init__(hs)
+
+
+class FederationStream(Stream):
+ """Data to be sent over federation. Only available when master has federation
+ sending disabled.
+ """
+ NAME = "federation"
+ ROW_TYPE = FederationStreamRow
+
+ def __init__(self, hs):
+ federation_sender = hs.get_federation_sender()
+
+ self.current_token = federation_sender.get_current_token
+ self.update_function = federation_sender.get_replication_rows
+
+ super(FederationStream, self).__init__(hs)
+
+
+class TagAccountDataStream(Stream):
+ """Someone added/removed a tag for a room
+ """
+ NAME = "tag_account_data"
+ ROW_TYPE = TagAccountDataStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_max_account_data_stream_id
+ self.update_function = store.get_all_updated_tags
+
+ super(TagAccountDataStream, self).__init__(hs)
+
+
+class AccountDataStream(Stream):
+ """Global or per room account data was changed
+ """
+ NAME = "account_data"
+ ROW_TYPE = AccountDataStreamRow
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ self.current_token = self.store.get_max_account_data_stream_id
+
+ super(AccountDataStream, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def update_function(self, from_token, to_token, limit):
+ global_results, room_results = yield self.store.get_all_updated_account_data(
+ from_token, from_token, to_token, limit
+ )
+
+ results = list(room_results)
+ results.extend(
+ (stream_id, user_id, None, account_data_type, content,)
+ for stream_id, user_id, account_data_type, content in global_results
+ )
+
+ defer.returnValue(results)
+
+
+class CurrentStateDeltaStream(Stream):
+ """Current state for a room was changed
+ """
+ NAME = "current_state_deltas"
+ ROW_TYPE = CurrentStateDeltaStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_max_current_state_delta_stream_id
+ self.update_function = store.get_all_updated_current_state_deltas
+
+ super(CurrentStateDeltaStream, self).__init__(hs)
+
+
+class GroupServerStream(Stream):
+ NAME = "groups"
+ ROW_TYPE = GroupsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_group_stream_token
+ self.update_function = store.get_all_groups_changes
+
+ super(GroupServerStream, self).__init__(hs)
+
+
+STREAMS_MAP = {
+ stream.NAME: stream
+ for stream in (
+ EventsStream,
+ BackfillStream,
+ PresenceStream,
+ TypingStream,
+ ReceiptsStream,
+ PushRulesStream,
+ PushersStream,
+ CachesStream,
+ PublicRoomsStream,
+ DeviceListsStream,
+ ToDeviceStream,
+ FederationStream,
+ TagAccountDataStream,
+ AccountDataStream,
+ CurrentStateDeltaStream,
+ GroupServerStream,
+ )
+}
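A brief usage sketch tying STREAMS_MAP to the client protocol's replicate() call: the map is used to validate stream names, and "NOW" asks the server to start streaming from its current token (see get_updates_since above). The connection object is a hypothetical connected client protocol instance.

    from synapse.replication.tcp.streams import STREAMS_MAP

    def subscribe_from_now(connection, stream_name):
        # Refuse unknown streams before sending anything to the server.
        if stream_name not in STREAMS_MAP:
            raise Exception("Invalid stream name %r" % (stream_name,))
        # "NOW" means: no backfill, just stream new updates from here on.
        connection.replicate(stream_name, "NOW")

    # e.g. subscribe_from_now(connection, "events")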
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index f9f5a3e077..16f5a73b95 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -40,6 +40,7 @@ from synapse.rest.client.v2_alpha import (
register,
auth,
receipts,
+ read_marker,
keys,
tokenrefresh,
tags,
@@ -50,6 +51,8 @@ from synapse.rest.client.v2_alpha import (
devices,
thirdparty,
sendtodevice,
+ user_directory,
+ groups,
)
from synapse.http.server import JsonResource
@@ -88,6 +91,7 @@ class ClientRestResource(JsonResource):
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
+ read_marker.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
@@ -98,3 +102,5 @@ class ClientRestResource(JsonResource):
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
+ user_directory.register_servlets(hs, client_resource)
+ groups.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 29fcd72375..efd5c9873d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +16,9 @@
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError
-from synapse.types import UserID
+from synapse.api.constants import Membership
+from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
+from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -112,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
- "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
@@ -127,17 +135,114 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self.handlers.message_handler.purge_history(room_id, event_id)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ delete_local_events = bool(body.get("delete_local_events", False))
+
+ # establish the topological ordering we should keep events from. The
+ # user can provide an event_id in the URL or the request body, or can
+ # provide a timestamp in the request body.
+ if event_id is None:
+ event_id = body.get('purge_up_to_event_id')
+
+ if event_id is not None:
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ depth = event.depth
+ logger.info(
+ "[purge] purging up to depth %i (event_id %s)",
+ depth, event_id,
+ )
+ elif 'purge_up_to_ts' in body:
+ ts = body['purge_up_to_ts']
+ if not isinstance(ts, int):
+ raise SynapseError(
+ 400, "purge_up_to_ts must be an int",
+ errcode=Codes.BAD_JSON,
+ )
+
+ stream_ordering = (
+ yield self.store.find_first_stream_ordering_after_ts(ts)
+ )
+
+ room_event_after_stream_ordering = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ )
+ if room_event_after_stream_ordering:
+ (_, depth, _) = room_event_after_stream_ordering
+ else:
+ logger.warn(
+ "[purge] purging events not possible: No event found "
+ "(received_ts %i => stream_ordering %i)",
+ ts, stream_ordering,
+ )
+ raise SynapseError(
+ 404,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
+ )
+ logger.info(
+ "[purge] purging up to depth %i (received_ts %i => "
+ "stream_ordering %i)",
+ depth, ts, stream_ordering,
+ )
+ else:
+ raise SynapseError(
+ 400,
+ "must specify purge_up_to_event_id or purge_up_to_ts",
+ errcode=Codes.BAD_JSON,
+ )
+
+ purge_id = yield self.handlers.message_handler.start_purge_history(
+ room_id, depth,
+ delete_local_events=delete_local_events,
+ )
- defer.returnValue((200, {}))
+ defer.returnValue((200, {
+ "purge_id": purge_id,
+ }))
+
+
+class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ super(PurgeHistoryStatusRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, purge_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ purge_status = self.handlers.message_handler.get_purge_status(purge_id)
+ if purge_status is None:
+ raise NotFoundError("purge id '%s' not found" % purge_id)
+
+ defer.returnValue((200, purge_status.asdict()))
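For illustration, a hedged client-side sketch of the new purge API: kick off a purge up to a timestamp, then poll the status endpoint with the returned purge_id. The homeserver URL, admin token, and use of the requests library are assumptions, not part of the patch.

    import time
    import requests

    BASE = "http://localhost:8008/_matrix/client/api/v1"
    PARAMS = {"access_token": "ADMIN_ACCESS_TOKEN"}  # placeholder token
    room_id = "!someroom:example.com"

    # Purge everything received more than 30 days ago, keeping local events.
    resp = requests.post(
        "%s/admin/purge_history/%s" % (BASE, room_id),
        params=PARAMS,
        json={
            "purge_up_to_ts": int(time.time() * 1000) - 30 * 24 * 3600 * 1000,
            "delete_local_events": False,
        },
    )
    purge_id = resp.json()["purge_id"]

    # Poll the status of the background purge.
    status = requests.get(
        "%s/admin/purge_history_status/%s" % (BASE, purge_id),
        params=PARAMS,
    ).json()
    print(status)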
class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- self.store = hs.get_datastore()
super(DeactivateAccountRestServlet, self).__init__(hs)
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@@ -148,18 +253,171 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(target_user_id)
- yield self.store.user_delete_threepids(target_user_id)
- yield self.store.user_set_password_hash(target_user_id, None)
-
+ yield self._deactivate_account_handler.deactivate_account(target_user_id)
defer.returnValue((200, {}))
+class ShutdownRoomRestServlet(ClientV1RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id`, and kicked users will be
+ automatically joined to the new room.
+ """
+ PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+
+ def __init__(self, hs):
+ super(ShutdownRoomRestServlet, self).__init__(hs)
+ self.store = hs.get_datastore()
+ self.handlers = hs.get_handlers()
+ self.state = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ content = parse_json_object_from_request(request)
+
+ new_room_user_id = content.get("new_room_user_id")
+ if not new_room_user_id:
+ raise SynapseError(400, "Please provide field `new_room_user_id`")
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info = yield self.handlers.room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {
+ "users_default": -10,
+ },
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ yield self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info("Shutting down room %r", room_id)
+
+ yield self.store.block_room(room_id, requester_user_id)
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ kicked_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ target_requester = create_requester(user_id)
+ yield self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False
+ )
+
+ yield self.room_member_handler.forget(target_requester.user, room_id)
+
+ yield self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False
+ )
+
+ kicked_users.append(user_id)
+
+ aliases_for_room = yield self.store.get_aliases_for_room(room_id)
+
+ yield self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ defer.returnValue((200, {
+ "kicked_users": kicked_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ }))
+
+
+class QuarantineMediaInRoom(ClientV1RestServlet):
+ """Quarantines all media in a room so that no one can download it via
+ this server.
+ """
+ PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
+
+ def __init__(self, hs):
+ super(QuarantineMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ num_quarantined = yield self.store.quarantine_media_ids_in_room(
+ room_id, requester.user.to_string(),
+ )
+
+ defer.returnValue((200, {"num_quarantined": num_quarantined}))
+
+
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
@user:to_reset_password?access_token=admin_access_token
@@ -177,12 +435,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.auth_handler = hs.get_auth_handler()
+ self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
"""
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -198,7 +456,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
logger.info("new_password: %r", new_password)
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))
@@ -206,7 +464,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
@@ -225,7 +483,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -258,7 +516,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token
@@ -296,7 +554,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
class SearchUsersRestServlet(ClientV1RestServlet):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/search_users/
@admin:user?access_token=admin_access_token&term=alice
@@ -316,7 +574,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs the user to have administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -347,9 +605,13 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
+ PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
+ ShutdownRoomRestServlet(hs).register(http_server)
+ QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+ # This subclass was presumably created to allow the auth for the v1
+ # protocol version to be different, however this behaviour was removed.
+ # It may no longer be necessary.
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8930f1826f..1c3933380f 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -39,6 +39,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
+ self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
@@ -70,7 +71,10 @@ class ClientDirectoryServer(ClientV1RestServlet):
logger.debug("Got servers: %s", servers)
# TODO(erikj): Check types.
- # TODO(erikj): Check that room exists
+
+ room = yield self.store.get_room(room_id)
+ if room is None:
+ raise SynapseError(400, "Room does not exist")
dir_handler = self.handlers.directory_handler
@@ -89,7 +93,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
)
except SynapseError as e:
raise e
- except:
+ except Exception:
logger.exception("Failed to create association")
raise
except AuthError:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 72057f1b0c..34df5be4e9 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,12 +19,13 @@ from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID
from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
import logging
from saml2 import BINDING_HTTP_POST
@@ -33,13 +34,57 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
+from twisted.web.client import PartialDownloadError
+
logger = logging.getLogger(__name__)
+def login_submission_legacy_convert(submission):
+ """
+ If the input login submission is an old-style object
+ (i.e. with top-level user / medium / address), convert it
+ to a typed object.
+ """
+ if "user" in submission:
+ submission["identifier"] = {
+ "type": "m.id.user",
+ "user": submission["user"],
+ }
+ del submission["user"]
+
+ if "medium" in submission and "address" in submission:
+ submission["identifier"] = {
+ "type": "m.id.thirdparty",
+ "medium": submission["medium"],
+ "address": submission["address"],
+ }
+ del submission["medium"]
+ del submission["address"]
+
+
+def login_id_thirdparty_from_phone(identifier):
+ """
+ Convert a phone login identifier type to a generic threepid identifier
+ Args:
+ identifier(dict): Login identifier dict of type 'm.id.phone'
+
+ Returns: Login identifier dict of type 'm.id.thirdparty'
+ """
+ if "country" not in identifier or "number" not in identifier:
+ raise SynapseError(400, "Invalid phone-type identifier")
+
+ msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+
+ return {
+ "type": "m.id.thirdparty",
+ "medium": "msisdn",
+ "address": msisdn,
+ }
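A small standalone check of the two helpers above (not part of the patch): a legacy password submission is normalised to an m.id.user identifier, and a phone identifier becomes a generic msisdn threepid.

    submission = {"type": "m.login.password", "user": "alice", "password": "secret"}
    login_submission_legacy_convert(submission)
    assert submission["identifier"] == {"type": "m.id.user", "user": "alice"}
    assert "user" not in submission  # the legacy key is removed

    # login_id_thirdparty_from_phone turns, for example,
    #   {"type": "m.id.phone", "country": "GB", "number": "07700 900123"}
    # into
    #   {"type": "m.id.thirdparty", "medium": "msisdn", "address": "447700900123"}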
+
+
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@@ -48,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@@ -75,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.password_enabled:
- flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+ flows.extend((
+ {"type": t} for t in self.auth_handler.get_supported_login_types()
+ ))
return (200, {"flows": flows})
@@ -87,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if login_submission["type"] == LoginRestServlet.PASS_TYPE:
- if not self.password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
-
- result = yield self.do_password_login(login_submission)
- defer.returnValue(result)
- elif self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
+ if self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@@ -111,49 +151,102 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
- raise SynapseError(400, "Bad login type.")
+ result = yield self._do_other_login(login_submission)
+ defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
- def do_password_login(self, login_submission):
- if 'medium' in login_submission and 'address' in login_submission:
- address = login_submission['address']
- if login_submission['medium'] == 'email':
+ def _do_other_login(self, login_submission):
+ """Handle non-token/saml/jwt logins
+
+ Args:
+ login_submission (dict): the parsed JSON body of the login request
+
+ Returns:
+ (int, object): HTTP code/response
+ """
+ # Log the request we got, but only certain fields to minimise the chance of
+ # logging someone's password (even if they accidentally put it in the wrong
+ # field)
+ logger.info(
+ "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
+ login_submission.get('identifier'),
+ login_submission.get('medium'),
+ login_submission.get('address'),
+ login_submission.get('user'),
+ )
+ login_submission_legacy_convert(login_submission)
+
+ if "identifier" not in login_submission:
+ raise SynapseError(400, "Missing param: identifier")
+
+ identifier = login_submission["identifier"]
+ if "type" not in identifier:
+ raise SynapseError(400, "Login identifier has no type")
+
+ # convert phone type identifiers to generic threepids
+ if identifier["type"] == "m.id.phone":
+ identifier = login_id_thirdparty_from_phone(identifier)
+
+ # convert threepid identifiers to user IDs
+ if identifier["type"] == "m.id.thirdparty":
+ address = identifier.get('address')
+ medium = identifier.get('medium')
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ if medium == 'email':
# For emails, transform the address to lowercase.
# We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- login_submission['medium'], address
+ medium, address,
)
if not user_id:
+ logger.warn(
+ "unknown 3pid identifier medium %s, address %r",
+ medium, address,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- else:
- user_id = login_submission['user']
- if not user_id.startswith('@'):
- user_id = UserID.create(
- user_id, self.hs.hostname
- ).to_string()
+ identifier = {
+ "type": "m.id.user",
+ "user": user_id,
+ }
+
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+ if "user" not in identifier:
+ raise SynapseError(400, "User identifier is missing 'user' key")
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_password_login(
- user_id=user_id,
- password=login_submission["password"],
+ canonical_user_id, callback = yield auth_handler.validate_login(
+ identifier["user"],
+ login_submission,
+ )
+
+ device_id = yield self._register_device(
+ canonical_user_id, login_submission,
)
- device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name"),
+ canonical_user_id, device_id,
)
+
result = {
- "user_id": user_id, # may have changed
+ "user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
+ if callback is not None:
+ yield callback(result)
+
defer.returnValue((200, result))
@defer.inlineCallbacks
@@ -166,7 +259,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
@@ -200,7 +292,7 @@ class LoginRestServlet(ClientV1RestServlet):
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
@@ -209,7 +301,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
@@ -341,7 +432,12 @@ class CasTicketServlet(ClientV1RestServlet):
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
- body = yield http_client.get_raw(uri, args)
+ try:
+ body = yield http_client.get_raw(uri, args)
+ except PartialDownloadError as pde:
+ # Twisted raises this error if the connection is closed,
+ # even if that's being used old-http style to signal end-of-data
+ body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
@@ -361,7 +457,7 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..e092158cb7 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@@ -30,15 +31,33 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
- access_token = get_access_token_from_request(request)
- yield self.store.delete_access_token(access_token)
+ try:
+ requester = yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # this implies the access token has already been deleted.
+ defer.returnValue((401, {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired"
+ }))
+ else:
+ if requester.device_id is None:
+ # the access token wasn't associated with a device.
+ # Just delete the access token
+ access_token = get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
+ else:
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
+
defer.returnValue((200, {}))
@@ -47,8 +66,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,7 +77,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.store.user_delete_access_tokens(user_id)
+
+ # first delete all of the user's devices
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+
+ # .. and then delete any access tokens which weren't associated with
+ # devices.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eafdce865e..4a73813c58 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
+from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user)
+ state = format_user_presence_state(state, self.clock.time_msec())
defer.returnValue((200, state))
@@ -75,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise KeyError()
except SynapseError as e:
raise e
- except:
+ except Exception:
raise SynapseError(400, "Unable to parse state")
yield self.presence_handler.set_state(user, state)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1a5045c9ec..e4e3611a14 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
@@ -52,10 +52,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
try:
new_name = content["displayname"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
@@ -94,10 +94,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
try:
new_name = content["avatar_url"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_avatar_url(
+ yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9a2ed6ed88..0206e664c1 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -73,6 +73,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -81,12 +82,10 @@ class PushersSetRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- pusher_pool = self.hs.get_pusherpool()
-
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
content['app_id'], content['pushkey'], user_id=user.to_string()
)
defer.returnValue((200, {}))
@@ -109,14 +108,14 @@ class PushersSetRestServlet(ClientV1RestServlet):
append = content['append']
if not append:
- yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try:
- yield pusher_pool.add_pusher(
+ yield self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content['kind'],
@@ -151,7 +150,8 @@ class PushersRemoveRestServlet(RestServlet):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -161,10 +161,8 @@ class PushersRemoveRestServlet(RestServlet):
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
- pusher_pool = self.hs.get_pusherpool()
-
try:
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
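
Same refactor for the pusher servlets: the pusher pool is looked up once in the constructor, and a POST body whose `kind` is null removes the pusher rather than adding one. A hedged sketch of that dispatch, with a made-up `FakePusherPool` standing in for Synapse's real pool.

```python
class FakePusherPool(object):
    def __init__(self):
        self.pushers = {}   # (app_id, pushkey) -> config

    def remove_pusher(self, app_id, pushkey, user_id):
        self.pushers.pop((app_id, pushkey), None)

    def add_pusher(self, user_id, app_id, pushkey, kind, **config):
        self.pushers[(app_id, pushkey)] = dict(kind=kind, user_id=user_id, **config)


def on_pushers_set(pool, user_id, content):
    # a null "kind" means "remove this pusher"
    if ("pushkey" in content and "app_id" in content
            and "kind" in content and content["kind"] is None):
        pool.remove_pusher(content["app_id"], content["pushkey"], user_id=user_id)
        return 200, {}
    pool.add_pusher(
        user_id=user_id,
        app_id=content["app_id"],
        pushkey=content["pushkey"],
        kind=content["kind"],
    )
    return 200, {}


pool = FakePusherPool()
on_pushers_set(pool, "@bob:example.org",
               {"app_id": "m.http", "pushkey": "abc", "kind": "http"})
on_pushers_set(pool, "@bob:example.org",
               {"app_id": "m.http", "pushkey": "abc", "kind": None})
print(pool.pushers)   # -> {} (added, then removed again)
```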
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index ecf7e311a9..9b3022e0b0 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -30,6 +30,8 @@ from hashlib import sha1
import hmac
import logging
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -70,10 +72,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +89,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
@@ -321,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
- if not isinstance(register_json.get("mac", None), basestring):
+ if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), basestring):
+ if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), basestring):
+ if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
@@ -336,9 +350,9 @@ class RegisterRestServlet(ClientV1RestServlet):
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
- if "\x00" in user:
+ if b"\x00" in user:
raise SynapseError(400, "Invalid user")
- if "\x00" in password:
+ if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
@@ -346,20 +360,20 @@ class RegisterRestServlet(ClientV1RestServlet):
got_mac = str(register_json["mac"])
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
- want_mac.update("\x00")
+ want_mac.update(b"\x00")
want_mac.update(password)
- want_mac.update("\x00")
- want_mac.update("admin" if admin else "notadmin")
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
- localpart=user,
+ localpart=user.lower(),
password=password,
admin=bool(admin),
)
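
The shared-secret registration hunk switches every HMAC input to bytes and lower-cases the localpart before registering. A runnable sketch of the MAC scheme as shown above (SHA-1 HMAC over user, password and an admin flag, NUL-separated, compared in constant time); the secret and credentials are invented.

```python
import hmac
from hashlib import sha1


def registration_mac(shared_secret, user, password, admin=False):
    """Compute the shared-secret MAC over NUL-separated byte fields."""
    mac = hmac.new(key=shared_secret.encode(), digestmod=sha1)
    mac.update(user)
    mac.update(b"\x00")
    mac.update(password)
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()


shared_secret = "not-a-real-secret"          # invented
got_mac = registration_mac(shared_secret, b"Alice", b"hunter2")

# server side: recompute, compare in constant time, then lower-case the
# localpart before registering, as the hunk above now does
want_mac = registration_mac(shared_secret, b"Alice", b"hunter2")
assert hmac.compare_digest(want_mac, got_mac)
localpart = b"Alice".lower()                 # -> b"alice"
print(localpart)
```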
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 728e3df0e3..fcf9c9ab44 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,9 +28,10 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
)
+from six.moves.urllib import parse as urlparse
+
import logging
-import urllib
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -82,6 +84,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -154,7 +158,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.handlers.room_member_handler.update_membership(
+ event = yield self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -162,15 +166,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- msg_handler = self.handlers.message_handler
- event, context = yield msg_handler.create_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
event_dict,
- token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield msg_handler.send_nonmember_event(requester, event, context)
-
ret = {}
if event:
ret = {"event_id": event.event_id}
@@ -182,7 +183,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -194,15 +195,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -221,7 +226,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -237,7 +242,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
@@ -246,10 +251,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
- except:
+ except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
- handler = self.handlers.room_member_handler
+ handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
@@ -258,7 +263,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier,
))
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -397,16 +402,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
- self.state = hs.get_state_handler()
+ self.message_handler = hs.get_handlers().message_handler
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.message_handler.get_joined_members(
+ requester, room_id,
+ )
defer.returnValue((200, {
- "joined": users_with_profile
+ "joined": users_with_profile,
}))
@@ -427,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
- filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -484,13 +491,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(RoomEventServlet, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ event = yield self.event_handler.get_event(requester.user, event_id)
+
+ time_now = self.clock.time_msec()
+ if event:
+ defer.returnValue((200, serialize_event(event, time_now)))
+ else:
+ defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
- super(RoomEventContext, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@@ -505,7 +534,6 @@ class RoomEventContext(ClientV1RestServlet):
room_id,
event_id,
limit,
- requester.is_guest,
)
if not results:
@@ -531,7 +559,7 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -544,7 +572,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False,
)
- yield self.handlers.room_member_handler.forget(
+ yield self.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
@@ -562,12 +590,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@@ -585,13 +613,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.handlers.room_member_handler.do_3pid_invite(
+ yield self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -613,7 +641,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']}
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -623,7 +651,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content=event_content,
)
- defer.returnValue((200, {}))
+ return_value = {}
+
+ if membership_action == "join":
+ return_value["room_id"] = room_id
+
+ defer.returnValue((200, return_value))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -641,6 +674,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -651,8 +685,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -686,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
- room_id = urllib.unquote(room_id)
- target_user = UserID.from_string(urllib.unquote(user_id))
+ room_id = urlparse.unquote(room_id)
+ target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
@@ -749,8 +782,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
- room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
+ room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
@@ -802,4 +834,5 @@ def register_servlets(hs, http_server):
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventContext(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
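
In RoomSendEventRestServlet the request is now turned into an explicit event dict before being handed to the event-creation handler, and application services may pin `origin_server_ts` through a `?ts=` query parameter. A small illustrative helper showing that construction; the function name and arguments below are made up.

```python
def build_event_dict(event_type, room_id, sender, content,
                     ts_arg=None, is_appservice=False):
    """Illustrative only: assemble the dict passed to create_and_send_nonmember_event."""
    event_dict = {
        "type": event_type,
        "content": content,
        "room_id": room_id,
        "sender": sender,
    }
    # only application-service senders may pin the event timestamp
    if ts_arg is not None and is_appservice:
        event_dict["origin_server_ts"] = int(ts_arg)
    return event_dict


print(build_event_dict(
    "m.room.message",
    "!room:example.org",                 # made-up room
    "@bridge_bot:example.org",           # made-up appservice user
    {"msgtype": "m.text", "body": "hello"},
    ts_arg="1514764800000",
    is_appservice=True,
))
```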
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 03141c623c..c43b30b73a 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(
+ request,
+ self.hs.config.turn_allow_guests
+ )
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
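
The voip change simply threads a config flag into the auth check, so whether guests may fetch TURN credentials becomes a deployment decision. A toy sketch of that gate; `FakeAuth`, the config dict and the returned credential shape are all invented.

```python
class FakeAuth(object):
    def get_user_by_req(self, request, allow_guest=False):
        if request["is_guest"] and not allow_guest:
            raise PermissionError("guest access forbidden")
        return request["user"]


def on_get_turn_server(auth, config, request):
    user = auth.get_user_by_req(request, allow_guest=config["turn_allow_guests"])
    # real responses also carry a TTL and derived credentials; omitted here
    return 200, {"uris": config["turn_uris"], "username": user}


config = {
    "turn_allow_guests": True,
    "turn_uris": ["turn:turn.example.org:3478?transport=udp"],
}
print(on_get_turn_server(FakeAuth(), config,
                         {"user": "@guest123:example.org", "is_guest": True}))
```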
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 20e765f48f..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
import re
-import logging
+from twisted.internet import defer
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@@ -47,3 +48,47 @@ def client_v2_patterns(path_regex, releases=(0,),
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
+
+
+def set_timeline_upper_limit(filter_json, filter_timeline_limit):
+ if filter_timeline_limit < 0:
+ return # no upper limits
+ timeline = filter_json.get('room', {}).get('timeline', {})
+ if 'limit' in timeline:
+ filter_json['room']['timeline']["limit"] = min(
+ filter_json['room']['timeline']['limit'],
+ filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+ """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+ Takes a on_POST method which returns a deferred (errcode, body) response
+ and adds exception handling to turn a InteractiveAuthIncompleteError into
+ a 401 response.
+
+ Normal usage is:
+
+ @interactive_auth_handler
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ # ...
+ yield self.auth_handler.check_auth
+ """
+ def wrapped(*args, **kwargs):
+ res = defer.maybeDeferred(orig, *args, **kwargs)
+ res.addErrback(_catch_incomplete_interactive_auth)
+ return res
+ return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+ """helper for interactive_auth_handler
+
+ Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+ Args:
+ f (failure.Failure):
+ """
+ f.trap(InteractiveAuthIncompleteError)
+ return 401, f.value.result
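
The new `interactive_auth_handler` decorator converts an `InteractiveAuthIncompleteError` raised anywhere in an `on_POST` into a 401 whose body carries the remaining UIA flows. A self-contained sketch of the same errback pattern, using a stand-in exception class instead of the real `InteractiveAuthIncompleteError`; it runs without starting the reactor because `maybeDeferred` fires synchronously here.

```python
from twisted.internet import defer


class FakeInteractiveAuthIncompleteError(Exception):
    """Stand-in for synapse.api.errors.InteractiveAuthIncompleteError."""
    def __init__(self, result):
        super(FakeInteractiveAuthIncompleteError, self).__init__("auth incomplete")
        self.result = result   # the flows/params dict to hand back to the client


def interactive_auth_handler(orig):
    def wrapped(*args, **kwargs):
        res = defer.maybeDeferred(orig, *args, **kwargs)
        res.addErrback(_catch_incomplete)
        return res
    return wrapped


def _catch_incomplete(failure):
    # anything else propagates; only the UIA error becomes a 401 body
    failure.trap(FakeInteractiveAuthIncompleteError)
    return 401, failure.value.result


@interactive_auth_handler
def on_POST(request):
    # pretend the caller has not completed user-interactive auth yet
    raise FakeInteractiveAuthIncompleteError(
        {"flows": [{"stages": ["m.login.password"]}], "session": "abc"}
    )


d = on_POST({})
d.addCallback(print)   # -> (401, {'flows': [...], 'session': 'abc'})
```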
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 398e7f5eb0..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,27 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from twisted.internet import defer
+from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
-from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet, assert_params_in_request,
+ parse_json_object_from_request,
+)
from synapse.util.async import run_on_reactor
-
-from ._base import client_v2_patterns
-
-import logging
-
+from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
-class PasswordRequestTokenRestServlet(RestServlet):
+class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
- super(PasswordRequestTokenRestServlet, self).__init__()
+ super(EmailPasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -40,14 +44,14 @@ class PasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret', 'email', 'send_attempt'
+ ])
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
@@ -60,6 +64,42 @@ class PasswordRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnPasswordRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.datastore = self.hs.get_datastore()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number', 'send_attempt',
+ ])
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
+ existingUid = yield self.datastore.get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is None:
+ raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$")
@@ -68,55 +108,62 @@ class PasswordRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.datastore = self.hs.get_datastore()
+ self._set_password_handler = hs.get_set_password_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- [LoginType.EMAIL_IDENTITY]
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
+ # there are two possibilities here. Either the user does not have an
+ # access token, and needs to do a password reset; or they have one and
+ # need to validate their identity.
+ #
+ # In the first case, we offer a couple of means of identifying
+ # themselves (email and msisdn, though it's unclear if msisdn actually
+ # works).
+ #
+ # In the second case, we require a password to confirm their identity.
+
+ if has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- elif LoginType.EMAIL_IDENTITY in result:
- threepid = result[LoginType.EMAIL_IDENTITY]
- if 'medium' not in threepid or 'address' not in threepid:
- raise SynapseError(500, "Malformed threepid")
- if threepid['medium'] == 'email':
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- threepid['address'] = threepid['address'].lower()
- # if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- threepid['medium'], threepid['address']
+ params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
)
- if not threepid_user_id:
- raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
- user_id = threepid_user_id
+ user_id = requester.user.to_string()
else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = None
+ result, params, _ = yield self.auth_handler.check_auth(
+ [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+ body, self.hs.get_ip_from_request(request),
+ )
+
+ if LoginType.EMAIL_IDENTITY in result:
+ threepid = result[LoginType.EMAIL_IDENTITY]
+ if 'medium' not in threepid or 'address' not in threepid:
+ raise SynapseError(500, "Malformed threepid")
+ if threepid['medium'] == 'email':
+ # For emails, transform the address to lowercase.
+ # We store all email addresses as lowercase in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ threepid['address'] = threepid['address'].lower()
+ # if using email, we must know about the email they're authing with!
+ threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+ threepid['medium'], threepid['address']
+ )
+ if not threepid_user_id:
+ raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+ user_id = threepid_user_id
+ else:
+ logger.error("Auth succeeded but no known type!", result.keys())
+ raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
user_id, new_password, requester
)
@@ -130,52 +177,43 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
+ super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- super(DeactivateAccountRestServlet, self).__init__()
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
- requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = yield self.auth.get_user_by_req(request)
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(user_id)
- yield self.store.user_delete_threepids(user_id)
- yield self.store.user_set_password_hash(user_id, None)
+ # allow ASes to deactivate their own users
+ if requester.app_service:
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string()
+ )
+ defer.returnValue((200, {}))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(),
+ )
defer.returnValue((200, {}))
-class ThreepidRequestTokenRestServlet(RestServlet):
+class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
- super(ThreepidRequestTokenRestServlet, self).__init__()
+ super(EmailThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
+ self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -190,7 +228,12 @@ class ThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
+ existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -201,6 +244,49 @@ class ThreepidRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnThreepidRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.datastore = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ required = [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number', 'send_attempt',
+ ]
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if absent:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
+ existingUid = yield self.datastore.get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is not None:
+ raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$")
@@ -210,6 +296,7 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -217,7 +304,7 @@ class ThreepidRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
- threepids = yield self.hs.get_datastore().user_get_threepids(
+ threepids = yield self.datastore.user_get_threepids(
requester.user.to_string()
)
@@ -258,7 +345,7 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']:
logger.debug(
- "Binding emails %s to %s",
+ "Binding threepid %s to %s",
threepid, user_id
)
yield self.identity_handler.bind_threepid(
@@ -301,10 +388,27 @@ class ThreepidDeleteRestServlet(RestServlet):
defer.returnValue((200, {}))
+class WhoamiRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/whoami$")
+
+ def __init__(self, hs):
+ super(WhoamiRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
+ defer.returnValue((200, {'user_id': requester.user.to_string()}))
+
+
def register_servlets(hs, http_server):
- PasswordRequestTokenRestServlet(hs).register(http_server)
+ EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+ MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
- ThreepidRequestTokenRestServlet(hs).register(http_server)
+ EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+ MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ WhoamiRestServlet(hs).register(http_server)
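
The new msisdn requestToken servlets normalise `country` plus `phone_number` into an MSISDN before running `check_3pid_allowed` and the existing-threepid lookup. Synapse's `phone_number_to_msisdn` lives in `synapse.util.msisdn`; the re-implementation below with the third-party `phonenumbers` package is an assumption about its behaviour, shown only to illustrate the expected input and output.

```python
import phonenumbers   # assumed helper dependency; not part of the diff above


def phone_number_to_msisdn(country, number):
    """Normalise a national number to E.164 digits without the leading '+'."""
    parsed = phonenumbers.parse(number, country)
    e164 = phonenumbers.format_number(parsed, phonenumbers.PhoneNumberFormat.E164)
    return e164.lstrip("+")


# e.g. the body of an /account/3pid/msisdn/requestToken request
body = {"country": "GB", "phone_number": "020 7946 0999"}
print(phone_number_to_msisdn(body["country"], body["phone_number"]))
# -> 442079460999
```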
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index b16079cece..0e0a187efd 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -16,7 +16,7 @@
from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
@@ -82,6 +82,13 @@ class RoomAccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "m.fully_read":
+ raise SynapseError(
+ 405,
+ "Cannot set m.fully_read through this API."
+ " Use /rooms/!roomId:server.name/read_markers"
+ )
+
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
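
With `m.fully_read` rejected on the room account-data path, clients are expected to use the read-markers endpoint instead. A hedged client-side illustration: the homeserver, token, room and event IDs are all made up, and the URL and body shape follow the usual Matrix client-server conventions.

```python
from urllib.parse import quote

import requests

HS = "https://matrix.example.org"          # made-up homeserver
TOKEN = "MDAxyz_access_token"              # made-up access token
ROOM_ID = "!abc123:example.org"            # made-up room

resp = requests.post(
    "%s/_matrix/client/r0/rooms/%s/read_markers" % (HS, quote(ROOM_ID, safe="")),
    headers={"Authorization": "Bearer %s" % TOKEN},
    json={
        # the fully-read marker that can no longer be set as room account data
        "m.fully_read": "$1512abc:example.org",
        # optionally advance the public read receipt at the same time
        "m.read": "$1512abc:example.org",
    },
)
print(resp.status_code)
```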
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index a1feaf3d54..35d58b367a 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,15 +17,15 @@ import logging
from twisted.internet import defer
-from synapse.api import constants, errors
+from synapse.api import errors
from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -46,9 +46,53 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
+class DeleteDevicesRestServlet(servlet.RestServlet):
+ """
+ API for bulk deletion of devices. Accepts a JSON object with a devices
+ key which lists the device_ids to delete. Requires user interactive auth.
+ """
+ PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
+
+ def __init__(self, hs):
+ super(DeleteDevicesRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.auth_handler = hs.get_auth_handler()
+
+ @interactive_auth_handler
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
+ try:
+ body = servlet.parse_json_object_from_request(request)
+ except errors.SynapseError as e:
+ if e.errcode == errors.Codes.NOT_JSON:
+ # deal with older clients which didn't pass a JSON dict
+ # the same as those that pass an empty dict
+ body = {}
+ else:
+ raise e
+
+ if 'devices' not in body:
+ raise errors.SynapseError(
+ 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
+ )
+
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
+
+ yield self.device_handler.delete_devices(
+ requester.user.to_string(),
+ body['devices'],
+ )
+ defer.returnValue((200, {}))
+
+
class DeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
- releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -70,8 +114,11 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
body = servlet.parse_json_object_from_request(request)
@@ -83,17 +130,12 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [constants.LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
- requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
- requester.user.to_string(),
- device_id,
+ requester.user.to_string(), device_id,
)
defer.returnValue((200, {}))
@@ -111,5 +153,6 @@ class DeviceRestServlet(servlet.RestServlet):
def register_servlets(hs, http_server):
+ DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
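
The new `/delete_devices` endpoint is a bulk delete guarded by user-interactive auth: the first POST normally comes back as a 401 carrying the UIA session, and the client retries with an `auth` dict attached. A hedged example of driving it from a client; the credentials and device IDs are invented.

```python
import requests

HS = "https://matrix.example.org"          # made-up homeserver
TOKEN = "MDAxyz_access_token"              # made-up access token
URL = "%s/_matrix/client/r0/delete_devices" % HS
HEADERS = {"Authorization": "Bearer %s" % TOKEN}

body = {"devices": ["OLDPHONE", "OLDLAPTOP"]}

# first attempt: expect 401 with the available UIA flows and a session id
first = requests.post(URL, headers=HEADERS, json=body)

if first.status_code == 401:
    session = first.json()["session"]
    # retry with the completed auth stage attached
    body["auth"] = {
        "type": "m.login.password",
        "user": "@alice:example.org",
        "password": "correct horse battery staple",
        "session": session,
    }
    second = requests.post(URL, headers=HEADERS, json=body)
    print(second.status_code)   # 200 once the devices are gone
```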
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b4084fec62..1b9dc4528d 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -20,6 +20,7 @@ from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from ._base import client_v2_patterns
+from ._base import set_timeline_upper_limit
import logging
@@ -49,7 +50,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_id = int(filter_id)
- except:
+ except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
@@ -85,6 +86,11 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
+ set_timeline_upper_limit(
+ content,
+ self.hs.config.filter_timeline_limit
+ )
+
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
user_filter=content,
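
Filter creation now clamps any client-supplied `room.timeline.limit` via `set_timeline_upper_limit` from `_base.py`. A lightly restructured, pure-Python copy of that clamp showing its effect on two sample filters; a negative `filter_timeline_limit` means no cap is applied.

```python
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
    if filter_timeline_limit < 0:
        return  # no upper limit configured
    timeline = filter_json.get("room", {}).get("timeline", {})
    if "limit" in timeline:
        timeline["limit"] = min(timeline["limit"], filter_timeline_limit)


f = {"room": {"timeline": {"limit": 5000}}}
set_timeline_upper_limit(f, 100)
print(f)   # {'room': {'timeline': {'limit': 100}}}

g = {"room": {"timeline": {}}}       # no explicit limit: left untouched
set_timeline_upper_limit(g, 100)
print(g)   # {'room': {'timeline': {}}}
```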
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
new file mode 100644
index 0000000000..3bb1ec2af6
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import GroupID
+
+from ._base import client_v2_patterns
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GroupServlet(RestServlet):
+ """Get the group profile
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+
+ def __init__(self, hs):
+ super(GroupServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ group_description = yield self.groups_handler.get_group_profile(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, group_description))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ yield self.groups_handler.update_group_profile(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class GroupSummaryServlet(RestServlet):
+ """Get the full group summary
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+
+ def __init__(self, hs):
+ super(GroupSummaryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ get_group_summary = yield self.groups_handler.get_group_summary(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, get_group_summary))
+
+
+class GroupSummaryRoomsCatServlet(RestServlet):
+ """Update/delete a rooms entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryRoomsCatServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoryServlet(RestServlet):
+ """Get/add/update/delete a group category
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoriesServlet(RestServlet):
+ """Get all group categories
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoriesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupRoleServlet(RestServlet):
+ """Get/add/update/delete a group role
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRolesServlet(RestServlet):
+ """Get all group roles
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRolesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+ """Update/delete a user's entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/users/:room_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryUsersRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRoomServlet(RestServlet):
+ """Get all rooms in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+
+ def __init__(self, hs):
+ super(GroupRoomServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupUsersServlet(RestServlet):
+ """Get all users in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+
+ def __init__(self, hs):
+ super(GroupUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupInvitedUsersServlet(RestServlet):
+ """Get users invited to a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+
+ def __init__(self, hs):
+ super(GroupInvitedUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_invited_users_in_group(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSettingJoinPolicyServlet(RestServlet):
+ """Set group join policy
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+
+ def __init__(self, hs):
+ super(GroupSettingJoinPolicyServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+
+ result = yield self.groups_handler.set_group_join_policy(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupCreateServlet(RestServlet):
+ """Create a group
+ """
+ PATTERNS = client_v2_patterns("/create_group$")
+
+ def __init__(self, hs):
+ super(GroupCreateServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.server_name = hs.hostname
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ # TODO: Create group on remote server
+ content = parse_json_object_from_request(request)
+ localpart = content.pop("localpart")
+ group_id = GroupID(localpart, self.server_name).to_string()
+
+ result = yield self.groups_handler.create_group(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsServlet(RestServlet):
+ """Add a room to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+ """Update the config of a room in a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsConfigServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id, config_key):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+ """Invite a user to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.store = hs.get_datastore()
+ self.is_mine_id = hs.is_mine_id
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ config = content.get("config", {})
+ result = yield self.groups_handler.invite(
+ group_id, user_id, requester_user_id, config,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+ """Kick a user from the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersKickServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfLeaveServlet(RestServlet):
+ """Leave a joined group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/leave$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfLeaveServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, requester_user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfJoinServlet(RestServlet):
+ """Attempt to join a group, or knock
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/join$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfJoinServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.join_group(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfAcceptInviteServlet(RestServlet):
+ """Accept a group invite
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfAcceptInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.accept_invite(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfUpdatePublicityServlet(RestServlet):
+ """Update whether we publicise a users membership of a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfUpdatePublicityServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ publicise = content["publicise"]
+ yield self.store.update_group_publicity(
+ group_id, requester_user_id, publicise,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class PublicisedGroupsForUserServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ result = yield self.groups_handler.get_publicised_groups_for_user(
+ user_id
+ )
+
+ defer.returnValue((200, result))
+
+
+class PublicisedGroupsForUsersServlet(RestServlet):
+ """Get the publicised groups for a list of users
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ content = parse_json_object_from_request(request)
+ user_ids = content["user_ids"]
+
+ result = yield self.groups_handler.bulk_get_publicised_groups(
+ user_ids
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupsForUserServlet(RestServlet):
+ """Get all groups the logged-in user has joined
+ """
+ PATTERNS = client_v2_patterns(
+ "/joined_groups$"
+ )
+
+ def __init__(self, hs):
+ super(GroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_joined_groups(requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+ GroupServlet(hs).register(http_server)
+ GroupSummaryServlet(hs).register(http_server)
+ GroupInvitedUsersServlet(hs).register(http_server)
+ GroupUsersServlet(hs).register(http_server)
+ GroupRoomServlet(hs).register(http_server)
+ GroupSettingJoinPolicyServlet(hs).register(http_server)
+ GroupCreateServlet(hs).register(http_server)
+ GroupAdminRoomsServlet(hs).register(http_server)
+ GroupAdminRoomsConfigServlet(hs).register(http_server)
+ GroupAdminUsersInviteServlet(hs).register(http_server)
+ GroupAdminUsersKickServlet(hs).register(http_server)
+ GroupSelfLeaveServlet(hs).register(http_server)
+ GroupSelfJoinServlet(hs).register(http_server)
+ GroupSelfAcceptInviteServlet(hs).register(http_server)
+ GroupsForUserServlet(hs).register(http_server)
+ GroupCategoryServlet(hs).register(http_server)
+ GroupCategoriesServlet(hs).register(http_server)
+ GroupSummaryRoomsCatServlet(hs).register(http_server)
+ GroupRoleServlet(hs).register(http_server)
+ GroupRolesServlet(hs).register(http_server)
+ GroupSelfUpdatePublicityServlet(hs).register(http_server)
+ GroupSummaryUsersRoleServlet(hs).register(http_server)
+ PublicisedGroupsForUserServlet(hs).register(http_server)
+ PublicisedGroupsForUsersServlet(hs).register(http_server)
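Not part of the patch, but for orientation: a hedged client-side sketch of the bulk publicised-groups endpoint registered above. The homeserver URL, access token and user IDs are placeholders, and the response shape shown only roughly follows the groups handler.

```python
# Hypothetical client call to POST /publicised_groups (bulk lookup added above).
# Base URL, token and user IDs are illustrative placeholders.
import requests

HOMESERVER = "https://example.org"
ACCESS_TOKEN = "MDAx...redacted"

resp = requests.post(
    HOMESERVER + "/_matrix/client/r0/publicised_groups",
    headers={"Authorization": "Bearer " + ACCESS_TOKEN},
    json={"user_ids": ["@alice:example.org", "@bob:example.org"]},
)
resp.raise_for_status()
# Roughly: {"users": {"@alice:example.org": ["+group:example.org", ...], ...}}
print(resp.json())
```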
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 6a3cfe84f8..3cc87ea63f 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/query$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns(
- "/keys/changes$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -188,13 +181,11 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- changed = yield self.device_handler.get_user_ids_changed(
+ results = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
- defer.returnValue((200, {
- "changed": list(changed),
- }))
+ defer.returnValue((200, results))
class OneTimeKeyServlet(RestServlet):
@@ -215,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/claim$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..ec170109fe 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$", releases=())
+ PATTERNS = client_v2_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
new file mode 100644
index 0000000000..2f8784fe06
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReadMarkerRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
+
+ def __init__(self, hs):
+ super(ReadMarkerRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.receipts_handler = hs.get_receipts_handler()
+ self.read_marker_handler = hs.get_read_marker_handler()
+ self.presence_handler = hs.get_presence_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+
+ yield self.presence_handler.bump_presence_active_time(requester.user)
+
+ body = parse_json_object_from_request(request)
+
+ read_event_id = body.get("m.read", None)
+ if read_event_id:
+ yield self.receipts_handler.received_client_receipt(
+ room_id,
+ "m.read",
+ user_id=requester.user.to_string(),
+ event_id=read_event_id
+ )
+
+ read_marker_event_id = body.get("m.fully_read", None)
+ if read_marker_event_id:
+ yield self.read_marker_handler.received_client_read_marker(
+ room_id,
+ user_id=requester.user.to_string(),
+ event_id=read_marker_event_id
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReadMarkerRestServlet(hs).register(http_server)
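For context (not part of the patch): a hedged sketch of a client hitting the new `/read_markers` endpoint. Homeserver URL, token, room and event IDs are placeholders.

```python
# Hypothetical client call to the read_markers endpoint defined above.
import requests

HOMESERVER = "https://example.org"
ACCESS_TOKEN = "MDAx...redacted"
ROOM_ID = "!room:example.org"

requests.post(
    "%s/_matrix/client/r0/rooms/%s/read_markers" % (HOMESERVER, ROOM_ID),
    headers={"Authorization": "Bearer " + ACCESS_TOKEN},
    json={
        "m.fully_read": "$fully-read-event-id",  # advances the fully-read marker
        "m.read": "$read-receipt-event-id",      # also sends an m.read receipt
    },
).raise_for_status()
```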
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ccca5a12d5..5cab00aea9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,17 +17,25 @@
from twisted.internet import defer
import synapse
+import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
+)
+from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
import logging
import hmac
from hashlib import sha1
from synapse.util.async import run_on_reactor
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from six import string_types
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
@@ -43,7 +52,7 @@ else:
logger = logging.getLogger(__name__)
-class RegisterRequestTokenRestServlet(RestServlet):
+class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
@@ -51,7 +60,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegisterRequestTokenRestServlet, self).__init__()
+ super(EmailRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -59,14 +68,14 @@ class RegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret', 'email', 'send_attempt'
+ ])
- if len(absent) > 0:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
@@ -79,6 +88,86 @@ class RegisterRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnRegisterRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number',
+ 'send_attempt',
+ ])
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "Phone number is already in use", Codes.THREEPID_IN_USE
+ )
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
+class UsernameAvailabilityRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/register/available")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(UsernameAvailabilityRestServlet, self).__init__()
+ self.hs = hs
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.ratelimiter = FederationRateLimiter(
+ hs.get_clock(),
+ # Time window of 2s
+ window_size=2000,
+ # Artificially delay requests if rate > sleep_limit/window_size
+ sleep_limit=1,
+ # Amount of artificial delay to apply
+ sleep_msec=1000,
+ # Error with 429 if more than reject_limit requests are queued
+ reject_limit=1,
+ # Allow 1 request at a time
+ concurrent_requests=1,
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ ip = self.hs.get_ip_from_request(request)
+ with self.ratelimiter.ratelimit(ip) as wait_deferred:
+ yield wait_deferred
+
+ username = parse_string(request, "username", required=True)
+
+ yield self.registration_handler.check_username(username)
+
+ defer.returnValue((200, {"available": True}))
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
@@ -95,9 +184,11 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
+ self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
@@ -121,14 +212,14 @@ class RegisterRestServlet(RestServlet):
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
- if (not isinstance(body['password'], basestring) or
+ if (not isinstance(body['password'], string_types) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
- if (not isinstance(body['username'], basestring) or
+ if (not isinstance(body['username'], string_types) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
@@ -146,15 +237,30 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
+
+ # XXX we should check that desired_username is valid. Currently
+ # we give appservices carte blanche for any insanity in mxids,
+ # because the IRC bridges rely on being able to register stupid
+ # IDs.
+
access_token = get_access_token_from_request(request)
- if isinstance(desired_username, basestring):
+ if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
+ # for either shared secret or regular registration, downcase the
+ # provided username before attempting to register it. This should mean
+ # that people who try to register with upper-case in their usernames
+ # don't get a nasty surprise. (Note that we treat username
+ # case-insensitively in login, so they are free to carry on imagining
+ # that their username is CrAzYh4cKeR if that keeps them happy)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
@@ -200,32 +306,86 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
+ # Only give msisdn flows if the x_show_msisdn flag is given:
+ # this is a hack to work around the fact that clients were shipped
+ # that use fallback registration if they see any flows that they don't
+ # recognise, which means we break registration for these clients if we
+ # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
+ # Android <=0.6.9 have fallen below an acceptable threshold, this
+ # parameter should go away and we should always advertise msisdn flows.
+ show_msisdn = False
+ if 'x_show_msisdn' in body and body['x_show_msisdn']:
+ show_msisdn = True
+
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
+ if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
+ flows.extend([
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+ ])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY]
- ]
-
- authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
+ if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
+ flows.extend([
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
+ ])
+
+ auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
- if not authed:
- defer.returnValue((401, auth_result))
- return
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
registered_user_id
)
- # don't re-register the email address
+ # don't re-register the threepids
add_email = False
+ add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
@@ -236,6 +396,9 @@ class RegisterRestServlet(RestServlet):
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
@@ -250,6 +413,7 @@ class RegisterRestServlet(RestServlet):
)
add_email = True
+ add_msisdn = True
return_dict = yield self._create_registration_details(
registered_user_id, params
@@ -262,6 +426,13 @@ class RegisterRestServlet(RestServlet):
params.get("bind_email")
)
+ if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ yield self._register_msisdn_threepid(
+ registered_user_id, threepid, return_dict["access_token"],
+ params.get("bind_msisdn")
+ )
+
defer.returnValue((200, return_dict))
def on_OPTIONS(self, _):
@@ -278,15 +449,24 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
+ if not username:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
- user = username.encode("utf-8")
+ # use the username from the original request rather than the
+ # downcased one in `username` for the mac calculation
+ user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
+ # FIXME this is different to the /v1/register endpoint, which
+ # includes the password and admin flag in the hashed text. Why are
+ # these different?
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
@@ -323,8 +503,9 @@ class RegisterRestServlet(RestServlet):
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
+ # This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ return
yield self.auth_handler.add_threepid(
user_id,
@@ -372,6 +553,43 @@ class RegisterRestServlet(RestServlet):
logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
+ def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
+ """Add a phone number as a 3pid identifier
+
+ Also optionally binds msisdn to the given user_id on the identity server
+
+ Args:
+ user_id (str): id of user
+ threepid (object): m.login.msisdn auth response
+ token (str): access_token for the user
+ bind_msisdn (bool): true if the client requested the phone number to be
+ bound at the identity server
+ Returns:
+ defer.Deferred:
+ """
+ reqd = ('medium', 'address', 'validated_at')
+ if any(x not in threepid for x in reqd):
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ return
+
+ yield self.auth_handler.add_threepid(
+ user_id,
+ threepid['medium'],
+ threepid['address'],
+ threepid['validated_at'],
+ )
+
+ if bind_msisdn:
+ logger.info("bind_msisdn specified: binding")
+ logger.debug("Binding msisdn %s to %s", threepid, user_id)
+ yield self.identity_handler.bind_threepid(
+ threepid['threepid_creds'], user_id
+ )
+ else:
+ logger.info("bind_msisdn not specified: not binding msisdn")
+
+ @defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -380,25 +598,28 @@ class RegisterRestServlet(RestServlet):
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
- device_id and initial_device_name
+ device_id, initial_device_name and inhibit_login
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- device_id = yield self._register_device(user_id, params)
+ result = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ }
+ if not params.get("inhibit_login", False):
+ device_id = yield self._register_device(user_id, params)
- access_token = (
- yield self.auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
- initial_display_name=params.get("initial_device_display_name")
+ access_token = (
+ yield self.auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id,
+ )
)
- )
- defer.returnValue({
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- })
+ result.update({
+ "access_token": access_token,
+ "device_id": device_id,
+ })
+ defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
@@ -433,7 +654,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- self.device_handler.check_device_registered(
+ yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@@ -449,5 +670,7 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
- RegisterRequestTokenRestServlet(hs).register(http_server)
+ EmailRegisterRequestTokenRestServlet(hs).register(http_server)
+ MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
+ UsernameAvailabilityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
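The flow-selection branches added to `RegisterRestServlet.on_POST` are easier to reason about in isolation. Below is a standalone illustration (not code from this patch) of what gets advertised for a given combination of captcha, `x_show_msisdn` and required 3PIDs; the string constants are the standard `m.login.*` types that `LoginType` maps to.

```python
# Illustration only: mirrors the flow-building branches in RegisterRestServlet.on_POST.
def build_flows(captcha, show_msisdn, require_email, require_msisdn):
    EMAIL, MSISDN, RECAPTCHA, DUMMY = (
        "m.login.email.identity", "m.login.msisdn", "m.login.recaptcha", "m.login.dummy",
    )
    flows = []
    if captcha:
        if not require_email and not require_msisdn:
            flows.append([RECAPTCHA])
        if not require_msisdn:
            flows.append([EMAIL, RECAPTCHA])
        if show_msisdn:
            if not require_email:
                flows.append([MSISDN, RECAPTCHA])
            flows.append([MSISDN, EMAIL, RECAPTCHA])
    else:
        if not require_email and not require_msisdn:
            flows.append([DUMMY])
        if not require_msisdn:
            flows.append([EMAIL])
        if show_msisdn:
            if not require_email:
                flows.append([MSISDN])
            flows.append([MSISDN, EMAIL])
    return flows


# A captcha-enabled server that requires an email 3PID only advertises flows
# containing m.login.email.identity:
assert build_flows(True, True, require_email=True, require_msisdn=False) == [
    ["m.login.email.identity", "m.login.recaptcha"],
    ["m.login.msisdn", "m.login.email.identity", "m.login.recaptcha"],
]
```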
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..90bdb1db15 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- releases=[], v2_alpha=False
+ v2_alpha=False
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index b3d8001638..eb91c0b293 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean
)
+from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events.utils import (
@@ -27,12 +28,12 @@ from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns
+from ._base import set_timeline_upper_limit
-import copy
import itertools
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ class SyncRestServlet(RestServlet):
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
@@ -108,7 +110,7 @@ class SyncRestServlet(RestServlet):
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
- logger.info(
+ logger.debug(
"/sync: user=%r, timeout=%r, since=%r,"
" set_presence=%r, filter_id=%r, device_id=%r" % (
user, timeout, since, set_presence, filter_id, device_id
@@ -121,7 +123,9 @@ class SyncRestServlet(RestServlet):
if filter_id.startswith('{'):
try:
filter_object = json.loads(filter_id)
- except:
+ set_timeline_upper_limit(filter_object,
+ self.hs.config.filter_timeline_limit)
+ except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
@@ -160,27 +164,35 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
+ response_content = self.encode_response(
+ time_now, sync_result, requester.access_token_id, filter
+ )
+
+ defer.returnValue((200, response_content))
- joined = self.encode_joined(
- sync_result.joined, time_now, requester.access_token_id, filter.event_fields
+ @staticmethod
+ def encode_response(time_now, sync_result, access_token_id, filter):
+ joined = SyncRestServlet.encode_joined(
+ sync_result.joined, time_now, access_token_id, filter.event_fields
)
- invited = self.encode_invited(
- sync_result.invited, time_now, requester.access_token_id
+ invited = SyncRestServlet.encode_invited(
+ sync_result.invited, time_now, access_token_id,
)
- archived = self.encode_archived(
- sync_result.archived, time_now, requester.access_token_id,
+ archived = SyncRestServlet.encode_archived(
+ sync_result.archived, time_now, access_token_id,
filter.event_fields,
)
- response_content = {
+ return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
- "changed": list(sync_result.device_lists),
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
},
- "presence": self.encode_presence(
+ "presence": SyncRestServlet.encode_presence(
sync_result.presence, time_now
),
"rooms": {
@@ -188,20 +200,32 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
+ "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
- defer.returnValue((200, response_content))
-
- def encode_presence(self, events, time_now):
- formatted = []
- for event in events:
- event = copy.deepcopy(event)
- event['sender'] = event['content'].pop('user_id')
- formatted.append(event)
- return {"events": formatted}
+ @staticmethod
+ def encode_presence(events, time_now):
+ return {
+ "events": [
+ {
+ "type": "m.presence",
+ "sender": event.user_id,
+ "content": format_user_presence_state(
+ event, time_now, include_user_id=False
+ ),
+ }
+ for event in events
+ ]
+ }
- def encode_joined(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_joined(rooms, time_now, token_id, event_fields):
"""
Encode the joined rooms in a sync result
@@ -220,13 +244,14 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields
)
return joined
- def encode_invited(self, rooms, time_now, token_id):
+ @staticmethod
+ def encode_invited(rooms, time_now, token_id):
"""
Encode the invited rooms in a sync result
@@ -247,6 +272,7 @@ class SyncRestServlet(RestServlet):
invite = serialize_event(
room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
+ is_invite=True,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -258,7 +284,8 @@ class SyncRestServlet(RestServlet):
return invited
- def encode_archived(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_archived(rooms, time_now, token_id, event_fields):
"""
Encode the archived rooms in a sync result
@@ -277,7 +304,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields
)
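For context: the sync response now carries `device_lists.changed`/`left`, a `groups` section and `device_one_time_keys_count`. A hypothetical consumer-side sketch of those fields (the handling shown is illustrative, not mandated by the patch):

```python
# Hypothetical consumer of the reshaped /sync response body.
def handle_sync_response(body):
    device_lists = body.get("device_lists", {})
    for user_id in device_lists.get("changed", []):
        print("re-query device keys for", user_id)
    for user_id in device_lists.get("left", []):
        print("drop cached device keys for", user_id)

    # device_one_time_keys_count maps algorithm name -> remaining uploaded keys
    otk_counts = body.get("device_one_time_keys_count", {})
    if otk_counts.get("signed_curve25519", 0) < 10:
        print("running low on one-time keys; upload more")

    groups = body.get("groups", {})
    print("joined groups:", list(groups.get("join", {})))
```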
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 31f94bc6e9..6773b9ba60 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -36,15 +36,14 @@ class ThirdPartyProtocolsServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -54,7 +53,7 @@ class ThirdPartyProtocolServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols(
only_protocol=protocol,
@@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -77,7 +75,7 @@ class ThirdPartyUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
@@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
@@ -101,7 +98,7 @@ class ThirdPartyLocationServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
new file mode 100644
index 0000000000..2d4a43c353
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectorySearchRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/user_directory/search$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(UserDirectorySearchRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Searches for users in the user directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ user_id = requester.user.to_string()
+
+ body = parse_json_object_from_request(request)
+
+ limit = body.get("limit", 10)
+ limit = min(limit, 50)
+
+ try:
+ search_term = body["search_term"]
+ except Exception:
+ raise SynapseError(400, "`search_term` is a required field")
+
+ results = yield self.user_directory_handler.search_users(
+ user_id, search_term, limit,
+ )
+
+ defer.returnValue((200, results))
+
+
+def register_servlets(hs, http_server):
+ UserDirectorySearchRestServlet(hs).register(http_server)
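A hedged client-side sketch of the new user directory search endpoint; the homeserver URL and token are placeholders, and `limit` is clamped to 50 by the servlet above.

```python
# Hypothetical client call to POST /user_directory/search.
import requests

resp = requests.post(
    "https://example.org/_matrix/client/r0/user_directory/search",
    headers={"Authorization": "Bearer MDAx...redacted"},
    json={"search_term": "alice", "limit": 10},
)
resp.raise_for_status()
for entry in resp.json()["results"]:
    print(entry["user_id"], entry.get("display_name"))
```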
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index e984ea47db..2ecb15deee 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
"r0.0.1",
"r0.1.0",
"r0.2.0",
+ "r0.3.0",
]
})
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index ff95269ba8..be68d9a096 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -84,12 +84,11 @@ class LocalKey(Resource):
}
old_verify_keys = {}
- for key in self.config.old_signing_keys:
- key_id = "%s:%s" % (key.alg, key.version)
+ for key_id, key in self.config.old_signing_keys.items():
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
u"key": encode_base64(verify_key_bytes),
- u"expired_ts": key.expired,
+ u"expired_ts": key.expired_ts,
}
tls_fingerprints = self.config.tls_fingerprints
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9fe2013657..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ logger.debug("Federation denied with %s", server_name)
+ continue
+
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
@@ -213,7 +221,7 @@ class RemoteKey(Resource):
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
- except:
+ except Exception:
logger.exception("Failed to get key for %r", server_name)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
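The whitelist check added to `RemoteKey` simply drops key queries for servers outside `federation_domain_whitelist`. A standalone illustration of the same filter (not code from this patch):

```python
# Illustration of the whitelist filter added to RemoteKey: queries for servers
# outside the (optional) federation_domain_whitelist are silently skipped.
def filter_key_queries(query, federation_domain_whitelist):
    allowed = {}
    for server_name, key_ids in query.items():
        if (federation_domain_whitelist is not None
                and server_name not in federation_domain_whitelist):
            continue  # behave as though we know nothing about this server's keys
        allowed[server_name] = key_ids
    return allowed


# With no whitelist configured (None), everything passes through unchanged:
assert filter_key_queries({"a.example": ["ed25519:1"]}, None) == {"a.example": ["ed25519:1"]}
```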
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index b9600f2167..c0d2f06855 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,6 +17,7 @@ from synapse.http.server import respond_with_json, finish_request
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
+from synapse.util import logcontext
from twisted.internet import defer
from twisted.protocols.basic import FileSender
@@ -27,7 +28,7 @@ import os
import logging
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ def parse_media_id(request):
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
- except:
+ except Exception:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
@@ -69,42 +70,133 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
- if upload_name:
- if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
- else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
- request.setHeader(
- b"Content-Length", b"%d" % (file_size,)
- )
+ add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield FileSender().beginFileTransfer(f, request)
+ yield logcontext.make_deferred_yieldable(
+ FileSender().beginFileTransfer(f, request)
+ )
finish_request(request)
else:
respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request (twisted.web.http.Request)
+ media_type (str): The media/content type.
+ file_size (int): Size in bytes of the media, if known.
+ upload_name (str): The name of the requested file, if any.
+ """
+ request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+
+ # cache for at least a day.
+ # XXX: we might want to turn this off for data we don't want to
+ # recommend caching as it's sensitive or private - or at least
+ # select private. don't bother setting Expires as all our
+ # clients are smart enough to be happy with Cache-Control
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+ )
+
+ request.setHeader(
+ b"Content-Length", b"%d" % (file_size,)
+ )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+ """Responds to the request with the given responder. If the responder is None
+ then a 404 response is returned.
+
+ Args:
+ request (twisted.web.http.Request)
+ responder (Responder|None)
+ media_type (str): The media/content type.
+ file_size (int|None): Size in bytes of the media. If not known it should be None
+ upload_name (str|None): The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ logger.debug("Responding to media request with responder %s", responder)
+ add_file_headers(request, media_type, file_size, upload_name)
+ with responder:
+ yield responder.write_to_consumer(request)
+ finish_request(request)
+
+
+class Responder(object):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
+ """
+ def write_to_consumer(self, consumer):
+ """Stream response into consumer
+
+ Args:
+ consumer (IConsumer)
+
+ Returns:
+ Deferred: Resolves once the response has finished being written
+ """
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FileInfo(object):
+ """Details about a requested/uploaded file.
+
+ Attributes:
+ server_name (str): The server name where the media originated from,
+ or None if local.
+ file_id (str): The local ID of the file. For local files this is the
+ same as the media_id
+ url_cache (bool): If the file is for the url preview cache
+ thumbnail (bool): Whether the file is a thumbnail or not.
+ thumbnail_width (int)
+ thumbnail_height (int)
+ thumbnail_method (str)
+ thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ """
+ def __init__(self, server_name, file_id, url_cache=False,
+ thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+ thumbnail_method=None, thumbnail_type=None):
+ self.server_name = server_name
+ self.file_id = file_id
+ self.url_cache = url_cache
+ self.thumbnail = thumbnail
+ self.thumbnail_width = thumbnail_width
+ self.thumbnail_height = thumbnail_height
+ self.thumbnail_method = thumbnail_method
+ self.thumbnail_type = thumbnail_type
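To make the new `Responder` contract concrete, here is a minimal, hypothetical file-backed implementation (the real implementations live in the media storage layer, not in this file); it assumes the `Responder` base class introduced above.

```python
# A minimal, hypothetical Responder that streams an already-open file and closes
# it when the response is finished.
from twisted.protocols.basic import FileSender

from synapse.rest.media.v1._base import Responder


class FileResponder(Responder):
    def __init__(self, open_file):
        self.open_file = open_file  # file object opened in binary mode

    def write_to_consumer(self, consumer):
        # beginFileTransfer returns a Deferred that fires once the file has been sent
        return FileSender().beginFileTransfer(self.open_file, consumer)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.open_file.close()
```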
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index dfb87ffd15..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse.http.servlet
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers
@@ -31,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
+
+ # Both of these are expected by @request_handler()
self.clock = hs.get_clock()
+ self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
@@ -56,43 +57,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id, name)
+ yield self.media_repo.get_local_media(request, media_id, name)
else:
- yield self._respond_remote_file(
- request, server_name, media_id, name
- )
-
- @defer.inlineCallbacks
- def _respond_local_file(self, request, media_id, name):
- media_info = yield self.store.get_local_media(media_id)
- if not media_info:
- respond_404(request)
- return
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- file_path = self.filepaths.local_media_filepath(media_id)
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
-
- @defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id, name):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- filesystem_id = media_info["filesystem_id"]
- upload_name = name if name else media_info["upload_name"]
-
- file_path = self.filepaths.remote_media_filepath(
- server_name, filesystem_id
- )
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
+ yield self.media_repo.get_remote_media(request, server_name, media_id, name)
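A hedged sketch of the new `allow_remote` query parameter from the client side; the media prefix and media ID are placeholders (use whichever media prefix the deployment serves).

```python
# Hypothetical request that refuses to trigger a federation fetch on a cache miss.
import requests

url = "https://example.org/_matrix/media/v1/download/remote.example/someMediaId"
resp = requests.get(url, params={"allow_remote": "false"})
# A 404 here means the remote media was not already cached locally.
print(resp.status_code)
```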
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 0137458f71..d5164e47e0 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -14,60 +14,200 @@
# limitations under the License.
import os
+import re
+import functools
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
+
+
+def _wrap_in_base_path(func):
+ """Takes a function that returns a relative path and turns it into an
+ absolute path based on the location of the primary media store
+ """
+ @functools.wraps(func)
+ def _wrapped(self, *args, **kwargs):
+ path = func(self, *args, **kwargs)
+ return os.path.join(self.base_path, path)
+
+ return _wrapped
class MediaFilePaths(object):
+ """Describes where files are stored on disk.
- def __init__(self, base_path):
- self.base_path = base_path
+ Most of the functions have a `*_rel` variant which returns a file path that
+ is relative to the base media store path. This is mainly used when we want
+ to write to the backup media store (when one is configured)
+ """
- def default_thumbnail(self, default_top_level, default_sub_type, width,
- height, content_type, method):
+ def __init__(self, primary_base_path):
+ self.base_path = primary_base_path
+
+ def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
+ height, content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "default_thumbnails", default_top_level,
+ "default_thumbnails", default_top_level,
default_sub_type, file_name
)
- def local_media_filepath(self, media_id):
+ default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
+
+ def local_media_filepath_rel(self, media_id):
return os.path.join(
- self.base_path, "local_content",
+ "local_content",
media_id[0:2], media_id[2:4], media_id[4:]
)
- def local_media_thumbnail(self, media_id, width, height, content_type,
- method):
+ local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+
+ def local_media_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "local_thumbnails",
+ "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
- def remote_media_filepath(self, server_name, file_id):
+ local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+
+ def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
- self.base_path, "remote_content", server_name,
+ "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:]
)
- def remote_media_thumbnail(self, server_name, file_id, width, height,
- content_type, method):
+ remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+
+ def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
+ content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
- self.base_path, "remote_thumbnail", server_name,
+ "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
+ remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)
+
+ def url_cache_filepath_rel(self, media_id):
+ if NEW_FORMAT_ID_RE.match(media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ return os.path.join(
+ "url_cache",
+ media_id[:10], media_id[11:]
+ )
+ else:
+ return os.path.join(
+ "url_cache",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+
+ def url_cache_filepath_dirs_to_delete(self, media_id):
+ "The dirs to try to remove if we delete the media_id file"
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2],
+ ),
+ ]
+
+ def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s-%s" % (
+ width, height, top_level_type, sub_type, method
+ )
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ file_name
+ )
+ else:
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ file_name
+ )
+
+ url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+
+ def url_cache_thumbnail_directory(self, media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ "The dirs to try to remove if we delete the media_id thumbnails"
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2],
+ ),
+ ]
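A quick illustration of the `*_rel`/absolute pairing that `_wrap_in_base_path` creates; the base path and media ID are arbitrary examples.

```python
# The *_rel helpers return paths relative to the media store; the wrapped
# variants join them onto the configured base path.
from synapse.rest.media.v1.filepath import MediaFilePaths

paths = MediaFilePaths("/data/media_store")

print(paths.url_cache_filepath_rel("2017-09-28-fsdRDt24DS234dsf"))
# -> url_cache/2017-09-28/fsdRDt24DS234dsf
print(paths.url_cache_filepath("2017-09-28-fsdRDt24DS234dsf"))
# -> /data/media_store/url_cache/2017-09-28/fsdRDt24DS234dsf
```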
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 3cbeca503c..9800ce7581 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,26 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer, threads
+import twisted.internet.error
+import twisted.web.http
+from twisted.web.resource import Resource
+
+from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
-
-from twisted.web.resource import Resource
-
from .thumbnailer import Thumbnailer
+from .storage_provider import StorageProviderWrapper
+from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError
-
-from twisted.internet import defer, threads
+from synapse.api.errors import (
+ SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.retryutils import NotRetryingDestination
import os
import errno
@@ -40,12 +47,12 @@ import shutil
import cgi
import logging
-import urlparse
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -57,46 +64,90 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+ self.primary_base_path = hs.config.media_store_path
+ self.filepaths = MediaFilePaths(self.primary_base_path)
+
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
+
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+
+ # List of StorageProviders where we should search for media and
+ # potentially upload to.
+ storage_providers = []
+
+ for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ backend = clz(hs, provider_config)
+ provider = StorageProviderWrapper(
+ backend,
+ store_local=wrapper_config.store_local,
+ store_remote=wrapper_config.store_remote,
+ store_synchronous=wrapper_config.store_synchronous,
+ )
+ storage_providers.append(provider)
+
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
+ )
self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
)
@defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
+
yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
+ local_media, remote_media, self.clock.time_msec()
)
- @staticmethod
- def _makedirs(filepath):
- dirname = os.path.dirname(filepath)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
+
+ Args:
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
+ """
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
+ """Store uploaded content for a local user and return the mxc URL
+
+ Args:
+ media_type(str): The content type of the file
+ upload_name(str): The name of the file
+ content: A file like object that is the content to store
+ content_length(int): The length of the content
+ auth_user(str): The user_id of the uploader
+
+ Returns:
+ Deferred[str]: The mxc url of the stored content
+ """
media_id = random_string(24)
- fname = self.filepaths.local_media_filepath(media_id)
- self._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ )
- # This shouldn't block for very long because the content will have
- # already been uploaded at this point.
- with open(fname, "wb") as f:
- f.write(content)
+ fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
@@ -108,104 +159,275 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
- media_info = {
- "media_type": media_type,
- "media_length": content_length,
- }
- yield self._generate_local_thumbnails(media_id, media_info)
+ yield self._generate_thumbnails(
+ None, media_id, media_id, media_type,
+ )
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
- def get_remote_media(self, server_name, media_id):
+ def get_local_media(self, request, media_id, name):
+ """Responds to reqests for local media, if exists, or returns 404.
+
+ Args:
+ request(twisted.web.http.Request)
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content.)
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info or media_info["quarantined_by"]:
+ respond_404(request)
+ return
+
+ self.mark_recently_accessed(None, media_id)
+
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ url_cache = media_info["url_cache"]
+
+ file_info = FileInfo(
+ None, media_id,
+ url_cache=url_cache,
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_media(self, request, server_name, media_id, name):
+ """Respond to requests for remote media.
+
+ Args:
+ request(twisted.web.http.Request)
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ self.mark_recently_accessed(server_name, media_id)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
+ key = (server_name, media_id)
+ with (yield self.remote_media_linearizer.queue(key)):
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # We deliberately stream the file outside the lock
+ if responder:
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+ else:
+ respond_404(request)
+
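The Linearizer keyed on (server_name, media_id) above is what prevents two requests racing to download the same remote file; a stripped-down sketch of that pattern, with do_fetch standing in for _get_remote_media_impl:

from twisted.internet import defer

from synapse.util.async import Linearizer

linearizer = Linearizer(name="media_remote")


@defer.inlineCallbacks
def fetch_once(server_name, media_id, do_fetch):
    # Only one caller per key runs the body at a time; later callers queue
    # behind it and will usually find the media already cached on disk.
    key = (server_name, media_id)
    with (yield linearizer.queue(key)):
        media_info = yield do_fetch(server_name, media_id)
    defer.returnValue(media_info)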
+ @defer.inlineCallbacks
+ def get_remote_media_info(self, server_name, media_id):
+ """Gets the media info associated with the remote file, downloading
+ if necessary.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[dict]: The media_info of the file
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
- media_info = yield self._get_remote_media_impl(server_name, media_id)
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # Ensure we actually use the responder so that it releases resources
+ if responder:
+ with responder:
+ pass
+
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
+ """Looks for media in local cache, if not there then attempt to
+ download from remote server.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[(Responder, media_info)]
+ """
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
- if not media_info:
- media_info = yield self._download_remote_file(
- server_name, media_id
- )
+
+ # file_id is the ID we use to track the file locally. If we've already
+ # seen the file then we reuse the existing ID, otherwise we generate a
+ # new one.
+ if media_info:
+ file_id = media_info["filesystem_id"]
else:
- self.recently_accessed_remotes.add((server_name, media_id))
- yield self.store.update_cached_last_access_time(
- [(server_name, media_id)], self.clock.time_msec()
- )
- defer.returnValue(media_info)
+ file_id = random_string(24)
- @defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id):
- file_id = random_string(24)
+ file_info = FileInfo(server_name, file_id)
+
+ # If we have an entry in the DB, try and look for it
+ if media_info:
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ raise NotFoundError()
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ defer.returnValue((responder, media_info))
+
+ # Failed to find the file anywhere, let's download it.
- fname = self.filepaths.remote_media_filepath(
- server_name, file_id
+ media_info = yield self._download_remote_file(
+ server_name, media_id, file_id
)
- self._makedirs(fname)
- try:
- with open(fname, "wb") as f:
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ responder = yield self.media_storage.fetch_media(file_info)
+ defer.returnValue((responder, media_info))
+
+ @defer.inlineCallbacks
+ def _download_remote_file(self, server_name, media_id, file_id):
+ """Attempt to download the remote file from the given server name,
+ using the given file_id as the local id.
+
+ Args:
+ server_name (str): Originating server
+ media_id (str): The media ID of the content (as defined by the
+ remote server). This is different from the file_id, which is
+ locally generated.
+ file_id (str): Local file ID
+
+ Returns:
+ Deferred[MediaInfo]
+ """
+
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ )
+
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ request_path = "/".join((
+ "/_matrix/media/v1/download", server_name, media_id,
+ ))
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
+ )
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+ except NotRetryingDestination:
+ logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ yield finish()
+
+ media_type = headers["Content-Type"][0]
+
+ time_now_ms = self.clock.time_msec()
+
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't, check for an ASCII name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
try:
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size,
- )
- except Exception as e:
- logger.warn("Failed to fetch remoted media %r", e)
- raise SynapseError(502, "Failed to fetch remoted media")
-
- media_type = headers["Content-Type"][0]
- time_now_ms = self.clock.time_msec()
-
- content_disposition = headers.get("Content-Disposition", None)
- if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get("filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith("utf-8''"):
- upload_name = upload_name_utf8[7:]
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get("filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii
-
- if upload_name:
- upload_name = urlparse.unquote(upload_name)
- try:
- upload_name = upload_name.decode("utf-8")
- except UnicodeDecodeError:
- upload_name = None
- else:
- upload_name = None
-
- logger.info("Stored remote media in file %r", fname)
-
- yield self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
- except:
- os.remove(fname)
- raise
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
+ logger.info("Stored remote media in file %r", fname)
+
+ yield self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
media_info = {
"media_type": media_type,
@@ -215,8 +437,8 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_remote_thumbnails(
- server_name, media_id, media_info
+ yield self._generate_thumbnails(
+ server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
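A self-contained sketch of the Content-Disposition handling inside _download_remote_file above, run against a hypothetical header value; it prefers an RFC 5987 filename* parameter and falls back to a plain ASCII filename:

import cgi

from six.moves.urllib import parse as urlparse

from synapse.util.stringutils import is_ascii


def parse_upload_name(content_disposition):
    # content_disposition is a raw header value such as
    # "inline; filename*=utf-8''caf%C3%A9.png; filename=cafe.png" (hypothetical).
    _, params = cgi.parse_header(content_disposition)
    upload_name = None

    # Prefer the UTF-8 encoded filename* parameter.
    upload_name_utf8 = params.get("filename*", None)
    if upload_name_utf8 and upload_name_utf8.lower().startswith("utf-8''"):
        upload_name = upload_name_utf8[7:]

    # Otherwise fall back to a plain filename, but only if it is ASCII.
    if not upload_name:
        upload_name_ascii = params.get("filename", None)
        if upload_name_ascii and is_ascii(upload_name_ascii):
            upload_name = upload_name_ascii

    if upload_name:
        # Percent-decode; an undecodable name is dropped rather than
        # failing the download, mirroring the code above.
        upload_name = urlparse.unquote(upload_name)
        try:
            upload_name = upload_name.decode("utf-8")
        except UnicodeDecodeError:
            upload_name = None
    return upload_name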
@@ -224,9 +446,8 @@ class MediaRepository(object):
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+ def _generate_thumbnail(self, thumbnailer, t_width, t_height,
t_method, t_type):
- thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -238,69 +459,126 @@ class MediaRepository(object):
return
if t_method == "crop":
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ t_width, t_height = thumbnailer.aspect(t_width, t_height)
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
+ t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
- t_len = None
+ t_byte_source = None
- return t_len
+ return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type):
- input_path = self.filepaths.local_media_filepath(media_id)
-
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
+ t_method, t_type, url_cache):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ None, media_id, url_cache=url_cache,
+ ))
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ url_cache=url_cache,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=False,
+ ))
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_local_thumbnails(self, media_id, media_info):
- media_type = media_info["media_type"]
+ def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+ url_cache=False):
+ """Generate and store thumbnails for an image.
+
+ Args:
+ server_name (str|None): The server name if remote media, else None if local
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content)
+ file_id (str): Local file ID
+ media_type (str): The content type of the file
+ url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ used exclusively by the url previewer
+
+ Returns:
+ Deferred[dict]: Dict with "width" and "height" keys of original image
+ """
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
- input_path = self.filepaths.local_media_filepath(media_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=url_cache,
+ ))
+
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -312,125 +590,68 @@ class MediaRepository(object):
)
return
- local_thumbnails = []
-
- def generate_thumbnails():
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+ # they have the same dimensions as a scaled one.
+ thumbnails = {}
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "crop":
+ thumbnails.setdefault((r_width, r_height, r_type), r_method)
+ elif r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
+ thumbnails[(t_width, t_height, r_type)] = r_method
+
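For illustration, a worked run of the deduplication above on a hypothetical 1920x1080 source with two requirements that land on the same dimensions (assuming Thumbnailer.aspect fits the image inside the requested box, as the thumbnailer hunk further down suggests):

m_width, m_height = 1920, 1080
requirements = [
    (640, 480, "scale", "image/jpeg"),  # aspect() maps this to 640x360
    (640, 360, "crop", "image/jpeg"),   # same final dimensions as the scale
]

thumbnails = {}
for r_width, r_height, r_method, r_type in requirements:
    if r_method == "crop":
        thumbnails.setdefault((r_width, r_height, r_type), r_method)
    elif r_method == "scale":
        # 640 * 1080 < 480 * 1920, so aspect() gives (640, 640 * 1080 // 1920) == (640, 360)
        t_width, t_height = min(m_width, 640), min(m_height, 360)
        thumbnails[(t_width, t_height, r_type)] = r_method

# Result: {(640, 360, "image/jpeg"): "scale"} -- the crop is dropped because
# crops only setdefault() while scales assign directly, so a scale with the
# same final dimensions always wins.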
+ # Now we generate the thumbnails for each dimension and store them
+ for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
+ # Generate the thumbnail
+ if t_method == "crop":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.crop,
+ t_width, t_height, t_type,
))
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ elif t_method == "scale":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.scale,
+ t_width, t_height, t_type,
))
+ else:
+ logger.error("Unrecognized method: %r", t_method)
+ continue
+
+ if not t_byte_source:
+ continue
+
+ try:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for l in local_thumbnails:
- yield self.store.store_local_thumbnail(*l)
-
- defer.returnValue({
- "width": m_width,
- "height": m_height,
- })
-
- @defer.inlineCallbacks
- def _generate_remote_thumbnails(self, server_name, media_id, media_info):
- media_type = media_info["media_type"]
- file_id = media_info["filesystem_id"]
- requirements = self._get_thumbnail_requirements(media_type)
- if not requirements:
- return
-
- remote_thumbnails = []
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- thumbnailer = Thumbnailer(input_path)
- m_width = thumbnailer.width
- m_height = thumbnailer.height
+ t_len = os.path.getsize(output_path)
- def generate_thumbnails():
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
- )
- return
-
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
+ # Write to database
+ if server_name:
+ yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
- ])
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
)
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- ])
-
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for r in remote_thumbnails:
- yield self.store.store_remote_media_thumbnail(*r)
+ else:
+ yield self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
defer.returnValue({
"width": m_width,
@@ -451,6 +672,8 @@ class MediaRepository(object):
logger.info("Deleting: %r", key)
+ # TODO: Should we delete from the backup store?
+
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
@@ -525,7 +748,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+ self.putChild("thumbnail", ThumbnailResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+ self.putChild("preview_url", PreviewUrlResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..d23fe10b07
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,263 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+import six
+
+from ._base import Responder
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ local_media_directory (str): Base path where we store media on disk
+ filepaths (MediaFilePaths)
+ storage_providers ([StorageProvider]): List of StorageProvider that are
+ used to fetch and store files.
+ """
+
+ def __init__(self, local_media_directory, filepaths, storage_providers):
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+
+ @defer.inlineCallbacks
+ def store_file(self, source, file_info):
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info (FileInfo): Info about the file to store
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, f,
+ ))
+ yield finish_cb()
+
+ defer.returnValue(fname)
+
+ @contextlib.contextmanager
+ def store_into_file(self, file_info):
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns a Deferred.
+
+ fname can be used to read the contents back after upload, e.g. to
+ generate thumbnails.
+
+ finish_cb must be called and waited on after the file has been
+ successfully written to. It should not be called if there was an
+ error.
+
+ Args:
+ file_info (FileInfo): Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ yield finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ finished_called = [False]
+
+ @defer.inlineCallbacks
+ def finish():
+ for provider in self.storage_providers:
+ yield provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ try:
+ with open(fname, "wb") as f:
+ yield f, fname, finish
+ except Exception:
+ t, v, tb = sys.exc_info()
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+ six.reraise(t, v, tb)
+
+ if not finished_called:
+ raise Exception("Finished callback not called")
+
+ @defer.inlineCallbacks
+ def fetch_media(self, file_info):
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder|None]: Returns a Responder if the file was found,
+ otherwise None.
+ """
+
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(FileResponder(open(local_path, "rb")))
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ defer.returnValue(res)
+
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def ensure_media_is_in_local_cache(self, file_info):
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[str]: Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(local_path)
+
+ dirname = os.path.dirname(local_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(open(local_path, "w"))
+ yield res.write_to_consumer(consumer)
+ yield consumer.wait()
+ defer.returnValue(local_path)
+
+ raise Exception("file could not be found")
+
+ def _file_info_to_path(self, file_info):
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ str
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id,
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.local_media_filepath_rel(
+ file_info.file_id,
+ )
+
+
+def _write_file_synchronously(source, dest):
+ """Write `source` to the file like `dest` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object that's to be written
+ dest: A file like object to be written to
+ """
+ source.seek(0) # Ensure we read from the start of the file
+ shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+ open_file (file): A file like object to be streamed to the client;
+ it is closed when finished streaming.
+ """
+ def __init__(self, open_file):
+ self.open_file = open_file
+
+ def write_to_consumer(self, consumer):
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.open_file.close()
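A minimal sketch of how callers are expected to consume fetch_media and the Responder it returns; respond_404 and respond_with_responder are the helpers imported from ._base in media_repository.py above, the rest is schematic:

from twisted.internet import defer

from ._base import respond_404, respond_with_responder


@defer.inlineCallbacks
def serve_file(request, media_storage, file_info, media_type, media_length, upload_name):
    # fetch_media checks the local cache first, then each storage provider,
    # and returns a Responder or None.
    responder = yield media_storage.fetch_media(file_info)
    if not responder:
        respond_404(request)
        return
    # Responders are context managers, so resources get released even if the
    # body is never streamed; respond_with_responder streams it to the client.
    yield respond_with_responder(
        request, responder, media_type, media_length, upload_name,
    )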
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 99760d622f..9290d7946f 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,39 +12,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import simplejson as json
+import urlparse
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
+from ._base import FileInfo
+
from synapse.api.errors import (
SynapseError, Codes,
)
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
- request_handler, respond_with_json_bytes
+ request_handler, respond_with_json_bytes,
+ respond_with_json,
)
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-
-import logging
logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
@@ -56,19 +64,27 @@ class PreviewUrlResource(Resource):
self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
+ self.primary_base_path = media_repo.primary_base_path
+ self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
- # simple memory cache mapping urls to OG metadata
- self.cache = ExpiringCache(
+ # memory cache mapping urls to an ObservableDeferred returning
+ # JSON-encoded OG metadata
+ self._cache = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
- self.cache.start()
+ self._cache.start()
+
+ self._cleaner_loop = self.clock.looping_call(
+ self._expire_url_cache_data, 10 * 1000
+ )
- self.downloads = {}
+ def render_OPTIONS(self, request):
+ return respond_with_json(request, 200, {}, send_cors=True)
def render_GET(self, request):
self._async_render_GET(request)
@@ -86,6 +102,7 @@ class PreviewUrlResource(Resource):
else:
ts = self.clock.time_msec()
+ # XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
@@ -118,53 +135,62 @@ class PreviewUrlResource(Resource):
Codes.UNKNOWN
)
- # first check the memory cache - good to handle all the clients on this
- # HS thundering away to preview the same URL at the same time.
- og = self.cache.get(url)
- if og:
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
- return
+ # the in-memory cache:
+ # * ensures that only one request is active at a time
+ # * takes load off the DB for the thundering herds
+ # * also caches any failures (unlike the DB) so we don't keep
+ # requesting the same endpoint
+
+ observable = self._cache.get(url)
+
+ if not observable:
+ download = run_in_background(
+ self._do_preview,
+ url, requester.user, ts,
+ )
+ observable = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
+ self._cache[url] = observable
+ else:
+ logger.info("Returning cached response")
- # then check the URL cache in the DB (which will also provide us with
+ og = yield make_deferred_yieldable(observable.observe())
+ respond_with_json_bytes(request, 200, og, send_cors=True)
+
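The in-memory cache above shares a single in-flight preview per URL via an ObservableDeferred; a stripped-down sketch of the same pattern, with do_preview standing in for _do_preview and a plain dict instead of the ExpiringCache:

from twisted.internet import defer

from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background

_cache = {}  # url -> ObservableDeferred of JSON-encoded OG data


@defer.inlineCallbacks
def get_preview(url, do_preview):
    observable = _cache.get(url)
    if not observable:
        # Start the work in the background; every later caller for the same
        # URL observes the same deferred. consumeErrors=True means failures
        # are cached too, so a broken endpoint isn't hammered.
        download = run_in_background(do_preview, url)
        observable = ObservableDeferred(download, consumeErrors=True)
        _cache[url] = observable
    og = yield make_deferred_yieldable(observable.observe())
    defer.returnValue(og)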
+ @defer.inlineCallbacks
+ def _do_preview(self, url, user, ts):
+ """Check the db, and download the URL and build a preview
+
+ Args:
+ url (str):
+ user (str):
+ ts (int):
+
+ Returns:
+ Deferred[str]: json-encoded og data
+ """
+ # check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result and
- cache_result["download_ts"] + cache_result["expires"] > ts and
+ cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
- respond_with_json_bytes(
- request, 200, cache_result["og"].encode('utf-8'),
- send_cors=True
- )
+ defer.returnValue(cache_result["og"])
return
- # Ensure only one download for a given URL is active at a time
- download = self.downloads.get(url)
- if download is None:
- download = self._download_url(url, requester.user)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.downloads[url] = download
-
- @download.addBoth
- def callback(media_info):
- del self.downloads[url]
- return media_info
- media_info = yield download.observe()
-
- # FIXME: we should probably update our cache now anyway, so that
- # even if the OG calculation raises, we don't keep hammering on the
- # remote server. For now, leave it uncached to aid debugging OG
- # calculation problems
+ media_info = yield self._download_url(url, user)
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
- dims = yield self.media_repo._generate_local_thumbnails(
- media_info['filesystem_id'], media_info
+ file_id = media_info['filesystem_id']
+ dims = yield self.media_repo._generate_thumbnails(
+ None, file_id, file_id, media_info["media_type"],
+ url_cache=True,
)
og = {
@@ -204,13 +230,15 @@ class PreviewUrlResource(Resource):
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
- _rebase_url(og['og:image'], media_info['uri']), requester.user
+ _rebase_url(og['og:image'], media_info['uri']), user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
- dims = yield self.media_repo._generate_local_thumbnails(
- image_info['filesystem_id'], image_info
+ file_id = image_info['filesystem_id']
+ dims = yield self.media_repo._generate_thumbnails(
+ None, file_id, file_id, image_info["media_type"],
+ url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -231,21 +259,20 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- # store OG in ephemeral in-memory cache
- self.cache[url] = og
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
- media_info["expires"],
- json.dumps(og),
+ media_info["expires"] + media_info["created_ts"],
+ jsonog,
media_info["filesystem_id"],
media_info["created_ts"],
)
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+ defer.returnValue(jsonog)
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -253,21 +280,36 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- # XXX: horrible duplication with base_resource's _download_remote_file()
- file_id = random_string(24)
+ file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fname = self.filepaths.local_media_filepath(file_id)
- self.media_repo._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=file_id,
+ url_cache=True,
+ )
- try:
- with open(fname, "wb") as f:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
+ except Exception as e:
# FIXME: pass through 404s and other error messages nicely
+ logger.warn("Error downloading %s: %r", url, e)
+ raise SynapseError(
+ 500, "Failed to download content: %s" % (
+ traceback.format_exception_only(sys.exc_type, e),
+ ),
+ Codes.UNKNOWN,
+ )
+ yield finish()
- media_type = headers["Content-Type"][0]
+ try:
+ if "Content-Type" in headers:
+ media_type = headers["Content-Type"][0]
+ else:
+ media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
@@ -303,14 +345,15 @@ class PreviewUrlResource(Resource):
upload_name=download_name,
media_length=length,
user_id=user,
+ url_cache=url,
)
except Exception as e:
- os.remove(fname)
- raise SynapseError(
- 500, ("Failed to download content: %s" % e),
- Codes.UNKNOWN
- )
+ logger.error("Error handling downloaded %s: %r", url, e)
+ # TODO: we really ought to delete the downloaded file in this
+ # case, since we won't have recorded it in the db, and will
+ # therefore not expire it.
+ raise
defer.returnValue({
"media_type": media_type,
@@ -327,6 +370,95 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
+ @defer.inlineCallbacks
+ def _expire_url_cache_data(self):
+ """Clean up expired url cache content, media and thumbnails.
+ """
+ # TODO: Delete from backup media store
+
+ now = self.clock.time_msec()
+
+ logger.info("Running url preview cache expiry")
+
+ if not (yield self.store.has_completed_background_updates()):
+ logger.info("Still running DB updates; skipping expiry")
+ return
+
+ # First we delete expired url cache entries
+ media_ids = yield self.store.get_expired_url_cache(now)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ yield self.store.delete_url_cache(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d entries from url cache", len(removed_media))
+
+ # Now we delete old images associated with the url cache.
+ # These may be cached for a bit on the client (i.e., the user
+ # may still have a room open with a URL preview showing).
+ # So we wait a couple of days before deleting, just in case.
+ expire_before = now - 2 * 24 * 60 * 60 * 1000
+ media_ids = yield self.store.get_url_cache_media_before(expire_before)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
+ try:
+ shutil.rmtree(thumbnail_dir)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ yield self.store.delete_url_cache_media(removed_media)
+
+ logger.info("Deleted %d media from url cache", len(removed_media))
+
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
@@ -424,7 +556,14 @@ def _calc_og(tree, media_uri):
from lxml import etree
TAGS_TO_REMOVE = (
- "header", "nav", "aside", "footer", "script", "style", etree.Comment
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
@@ -434,6 +573,8 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
+ else:
+ og['og:description'] = summarize_paragraphs([og['og:description']])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..0252afd9d3
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.config._base import Config
+from synapse.util.logcontext import run_in_background
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+ """A storage provider is a service that can store uploaded media and
+ retrieve them.
+ """
+ def store_file(self, path, file_info):
+ """Store the file described by file_info. The actual contents can be
+ retrieved by reading the file in file_info.upload_path.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred
+ """
+ pass
+
+ def fetch(self, path, file_info):
+ """Attempt to fetch the file described by file_info and stream it
+ into writer.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred(Responder): Returns a Responder if the provider has the file,
+ otherwise returns None.
+ """
+ pass
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend (StorageProvider)
+ store_local (bool): Whether to store new local files or not.
+ store_synchronous (bool): Whether to wait for the file to be successfully
+ uploaded, or to do the upload in the background.
+ store_remote (bool): Whether remote media should be uploaded
+ """
+ def __init__(self, backend, store_local, store_synchronous, store_remote):
+ self.backend = backend
+ self.store_local = store_local
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def store_file(self, path, file_info):
+ if not file_info.server_name and not self.store_local:
+ return defer.succeed(None)
+
+ if file_info.server_name and not self.store_remote:
+ return defer.succeed(None)
+
+ if self.store_synchronous:
+ return self.backend.store_file(path, file_info)
+ else:
+ # TODO: Handle errors.
+ def store():
+ try:
+ return self.backend.store_file(path, file_info)
+ except Exception:
+ logger.exception("Error storing file")
+ run_in_background(store)
+ return defer.succeed(None)
+
+ def fetch(self, path, file_info):
+ return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ hs (HomeServer)
+ config: The config returned by `parse_config`.
+ """
+
+ def __init__(self, hs, config):
+ self.cache_directory = hs.config.media_store_path
+ self.base_directory = config
+
+ def store_file(self, path, file_info):
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ return threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
+ def fetch(self, path, file_info):
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
+
+ @staticmethod
+ def parse_config(config):
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
+ """
+ return Config.ensure_directory(config["directory"])
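To show how the pieces above compose, a hypothetical manual wiring of a FileStorageProviderBackend behind a StorageProviderWrapper; in practice this is driven by hs.config.media_storage_providers in MediaRepository.__init__, and the backup directory here is made up:

# parse_config would normally produce the directory from a config dict such
# as {"directory": "/data/media_backup"}.
backend = FileStorageProviderBackend(hs, "/data/media_backup")

provider = StorageProviderWrapper(
    backend,
    store_local=True,         # mirror locally uploaded media
    store_remote=False,       # don't mirror cached remote media
    store_synchronous=False,  # copy in the background rather than blocking uploads
)

# MediaStorage then calls provider.store_file(path, file_info) after writing
# its primary copy, and provider.fetch(path, file_info) when the primary copy
# is missing.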
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d8f54adc99..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
# limitations under the License.
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+ parse_media_id, respond_404, respond_with_file, FileInfo,
+ respond_with_responder,
+)
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
+ self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,6 +79,7 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
@@ -84,11 +89,10 @@ class ThumbnailResource(Resource):
if not media_info:
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ respond_404(request)
+ return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -96,20 +100,25 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
)
- yield respond_with_file(request, t_type, file_path)
+ t_type = file_info.thumbnail_type
+ t_length = thumbnail_info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
- )
+ logger.info("Couldn't find any generated thumbnails")
+ respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
@@ -120,11 +129,10 @@ class ThumbnailResource(Resource):
if not media_info:
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ respond_404(request)
+ return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@@ -134,37 +142,43 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type
+ media_id, desired_width, desired_height, desired_method, desired_type,
+ url_cache=media_info["url_cache"],
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -179,14 +193,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, desired_width, desired_height,
- desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -197,22 +221,16 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead.
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -222,59 +240,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_id = thumbnail_info["filesystem_id"]
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
- )
-
- @defer.inlineCallbacks
- def _respond_default_thumbnail(self, request, media_info, width, height,
- method, m_type):
- # XXX: how is this meant to work? store.get_default_thumbnails
- # appears to always return [] so won't this always 404?
- media_type = media_info["media_type"]
- top_level_type = media_type.split("/")[0]
- sub_type = media_type.split("/")[-1].split(";")[0]
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, sub_type,
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, "_default",
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- "_default", "_default",
- )
- if not thumbnail_infos:
+ logger.info("Failed to find any generated thumbnails")
respond_404(request)
- return
-
- thumbnail_info = self._select_thumbnail(
- width, height, "crop", m_type, thumbnail_infos
- )
-
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- t_length = thumbnail_info["thumbnail_length"]
-
- file_path = self.filepaths.default_thumbnail(
- top_level_type, sub_type, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
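
For orientation, here is a minimal sketch (not part of the patch) of the serving pattern these hunks converge on: build a FileInfo describing the thumbnail, ask media_storage for a Responder, and either stream it back or 404. The helper name _serve_thumbnail and the shape of `info` are hypothetical; FileInfo, fetch_media, respond_with_responder and respond_404 are the names used in the hunks above, assumed to live in synapse.rest.media.v1._base.

from twisted.internet import defer

from synapse.rest.media.v1._base import (
    FileInfo, respond_404, respond_with_responder,
)

@defer.inlineCallbacks
def _serve_thumbnail(self, request, server_name, file_id, info):
    # Describe the thumbnail we want to stream back to the client.
    file_info = FileInfo(
        server_name=server_name,
        file_id=file_id,
        thumbnail=True,
        thumbnail_width=info["thumbnail_width"],
        thumbnail_height=info["thumbnail_height"],
        thumbnail_type=info["thumbnail_type"],
        thumbnail_method=info["thumbnail_method"],
    )

    # fetch_media returns a Responder if the file exists, else None.
    responder = yield self.media_storage.fetch_media(file_info)
    if responder:
        yield respond_with_responder(
            request, responder,
            info["thumbnail_type"], info["thumbnail_length"],
        )
    else:
        respond_404(request)
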
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3868d4f65f..e1ee535b9a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -50,12 +50,16 @@ class Thumbnailer(object):
else:
return ((max_height * self.width) // self.height, max_height)
- def scale(self, output_path, width, height, output_type):
- """Rescales the image to the given dimensions"""
+ def scale(self, width, height, output_type):
+ """Rescales the image to the given dimensions.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
+ """
scaled = self.image.resize((width, height), Image.ANTIALIAS)
- return self.save_image(scaled, output_type, output_path)
+ return self._encode_image(scaled, output_type)
- def crop(self, output_path, width, height, output_type):
+ def crop(self, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -65,6 +69,9 @@ class Thumbnailer(object):
Args:
max_width: The largest possible width.
max_height: The larget possible height.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
@@ -82,13 +89,9 @@ class Thumbnailer(object):
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
- return self.save_image(cropped, output_type, output_path)
+ return self._encode_image(cropped, output_type)
- def save_image(self, output_image, output_type, output_path):
+ def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
- output_bytes = output_bytes_io.getvalue()
- with open(output_path, "wb") as output_file:
- output_file.write(output_bytes)
- logger.info("Stored thumbnail in file %r", output_path)
- return len(output_bytes)
+ return output_bytes_io
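
To illustrate the change above: scale() and crop() no longer write to an output path; they hand back a BytesIO which the caller persists. A rough sketch follows (the Thumbnailer constructor taking a source image path is an assumption based on the rest of this module):

from synapse.rest.media.v1.thumbnailer import Thumbnailer

def encode_scaled_thumbnail(input_path, width, height, output_type="image/jpeg"):
    thumbnailer = Thumbnailer(input_path)  # assumed constructor
    t_byte_source = thumbnailer.scale(width, height, output_type)  # BytesIO
    # The caller (the media repository) now decides where these bytes go,
    # e.g. handing them to media storage rather than writing a file here.
    return t_byte_source.getvalue()
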
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 4ab33f73bf..a31e75cb46 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -81,19 +81,19 @@ class UploadResource(Resource):
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders("Content-Type")[0]
+ media_type = headers.getRawHeaders(b"Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
- # if headers.hasHeader("Content-Disposition"):
- # disposition = headers.getRawHeaders("Content-Disposition")[0]
+ # if headers.hasHeader(b"Content-Disposition"):
+ # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content(
- media_type, upload_name, request.content.read(),
+ media_type, upload_name, request.content,
content_length, requester.user
)
diff --git a/synapse/server.py b/synapse/server.py
index c577032041..ebdea6b0c4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,29 +31,47 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
-from synapse.federation import initialize_http_replication
+from synapse.events.spamcheck import SpamChecker
+from synapse.federation.federation_client import FederationClient
+from synapse.federation.federation_server import FederationServer
from synapse.federation.send_queue import FederationRemoteSendQueue
+from synapse.federation.federation_server import FederationHandlerRegistry
from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
+from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room_list import RoomListHandler
+from synapse.handlers.room_member import RoomMemberMasterHandler
+from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
+from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler
+from synapse.handlers.read_marker import ReadMarkerHandler
+from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
+from synapse.groups.groups_server import GroupsServerHandler
+from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
+from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
-from synapse.rest.media.v1.media_repository import MediaRepository
-from synapse.state import StateHandler
+from synapse.rest.media.v1.media_repository import (
+ MediaRepository,
+ MediaRepositoryResource,
+)
+from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -82,18 +100,14 @@ class HomeServer(object):
"""
DEPENDENCIES = [
- 'config',
- 'clock',
'http_client',
'db_pool',
- 'persistence_service',
- 'replication_layer',
- 'datastore',
+ 'federation_client',
+ 'federation_server',
'handlers',
- 'v1auth',
'auth',
- 'rest_servlet_factory',
'state_handler',
+ 'state_resolution_handler',
'presence_handler',
'sync_handler',
'typing_handler',
@@ -108,19 +122,12 @@ class HomeServer(object):
'application_service_scheduler',
'application_service_handler',
'device_message_handler',
+ 'profile_handler',
+ 'event_creation_handler',
+ 'deactivate_account_handler',
+ 'set_password_handler',
'notifier',
- 'distributor',
- 'client_resource',
- 'resource_for_federation',
- 'resource_for_static_content',
- 'resource_for_web_client',
- 'resource_for_content_repo',
- 'resource_for_server_key',
- 'resource_for_server_key_v2',
- 'resource_for_media_repository',
- 'resource_for_metrics',
'event_sources',
- 'ratelimiter',
'keyring',
'pusherpool',
'event_builder_factory',
@@ -128,10 +135,22 @@ class HomeServer(object):
'http_client_context_factory',
'simple_http_client',
'media_repository',
+ 'media_repository_resource',
'federation_transport_client',
'federation_sender',
'receipts_handler',
'macaroon_generator',
+ 'tcp_replication',
+ 'read_marker_handler',
+ 'action_generator',
+ 'user_directory_handler',
+ 'groups_local_handler',
+ 'groups_server_handler',
+ 'groups_attestation_signing',
+ 'groups_attestation_renewer',
+ 'spam_checker',
+ 'room_member_handler',
+ 'federation_registry',
]
def __init__(self, hostname, **kwargs):
@@ -165,8 +184,26 @@ class HomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
- def build_replication_layer(self):
- return initialize_http_replication(self)
+ def get_clock(self):
+ return self.clock
+
+ def get_datastore(self):
+ return self.datastore
+
+ def get_config(self):
+ return self.config
+
+ def get_distributor(self):
+ return self.distributor
+
+ def get_ratelimiter(self):
+ return self.ratelimiter
+
+ def build_federation_client(self):
+ return FederationClient(self)
+
+ def build_federation_server(self):
+ return FederationServer(self)
def build_handlers(self):
return Handlers(self)
@@ -187,18 +224,12 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
- def build_v1auth(self):
- orf = Auth(self)
- # Matrix spec makes no reference to what HTTP status code is returned,
- # but the V1 API uses 403 where it means 401, and the webclient
- # relies on this behaviour, so V1 gets its own copy of the auth
- # with backwards compat behaviour.
- orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
- return orf
-
def build_state_handler(self):
return StateHandler(self)
+ def build_state_resolution_handler(self):
+ return StateResolutionHandler(self)
+
def build_presence_handler(self):
return PresenceHandler(self)
@@ -244,6 +275,18 @@ class HomeServer(object):
def build_initial_sync_handler(self):
return InitialSyncHandler(self)
+ def build_profile_handler(self):
+ return ProfileHandler(self)
+
+ def build_event_creation_handler(self):
+ return EventCreationHandler(self)
+
+ def build_deactivate_account_handler(self):
+ return DeactivateAccountHandler(self)
+
+ def build_set_password_handler(self):
+ return SetPasswordHandler(self)
+
def build_event_sources(self):
return EventSources(self)
@@ -273,6 +316,28 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def get_db_conn(self, run_new_connection=True):
+ """Makes a new connection to the database, skipping the db pool
+
+ Returns:
+ Connection: a connection object implementing the PEP-249 spec
+ """
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def build_media_repository_resource(self):
+ # build the media repo resource. This indirects through the HomeServer
+ # to ensure that we only have a single instance of
+ return MediaRepositoryResource(self)
+
def build_media_repository(self):
return MediaRepository(self)
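
As a worked example of the cp_* filtering in get_db_conn above (values are illustrative): connection-pool options such as cp_min/cp_max are consumed by adbapi and must not reach the database driver, while the remaining args are passed straight through.

db_args = {"database": "homeserver.db", "cp_min": 5, "cp_max": 10}
db_params = {k: v for k, v in db_args.items() if not k.startswith("cp_")}
assert db_params == {"database": "homeserver.db"}
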
@@ -290,6 +355,41 @@ class HomeServer(object):
def build_receipts_handler(self):
return ReceiptsHandler(self)
+ def build_read_marker_handler(self):
+ return ReadMarkerHandler(self)
+
+ def build_tcp_replication(self):
+ raise NotImplementedError()
+
+ def build_action_generator(self):
+ return ActionGenerator(self)
+
+ def build_user_directory_handler(self):
+ return UserDirectoryHandler(self)
+
+ def build_groups_local_handler(self):
+ return GroupsLocalHandler(self)
+
+ def build_groups_server_handler(self):
+ return GroupsServerHandler(self)
+
+ def build_groups_attestation_signing(self):
+ return GroupAttestationSigning(self)
+
+ def build_groups_attestation_renewer(self):
+ return GroupAttestionRenewer(self)
+
+ def build_spam_checker(self):
+ return SpamChecker(self)
+
+ def build_room_member_handler(self):
+ if self.config.worker_app:
+ return RoomMemberWorkerHandler(self)
+ return RoomMemberMasterHandler(self)
+
+ def build_federation_registry(self):
+ return FederationHandlerRegistry()
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
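
For readers new to this file: the DEPENDENCIES list and the build_* methods work together, with each listed name getting a cached get_*() accessor that calls the matching build_*() on first use. The helper below is a simplified, hypothetical reconstruction of that wiring, not the actual metaprogramming in HomeServer.

def _make_dependency_method(depname):
    def _get(hs):
        # return the cached instance if we have one, else build and cache it
        existing = getattr(hs, "_" + depname, None)
        if existing is not None:
            return existing
        dep = getattr(hs, "build_" + depname)()
        setattr(hs, "_" + depname, dep)
        return dep
    return _get

# e.g. roughly how HomeServer.get_federation_client() would come to exist:
# setattr(HomeServer, "get_federation_client",
#         _make_dependency_method("federation_client"))
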
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9570df5537..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,10 +1,16 @@
import synapse.api.auth
+import synapse.federation.transaction_queue
+import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
+import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
-import synapse.storage
+import synapse.handlers.set_password
+import synapse.rest.media.v1.media_repository
import synapse.state
+import synapse.storage
+
class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth:
@@ -27,3 +33,24 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
+
+ def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+ pass
+
+ def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+ pass
+
+ def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+ pass
+
+ def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
+ pass
+
+ def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+ pass
+
+ def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+ pass
+
+ def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+ pass
diff --git a/synapse/state.py b/synapse/state.py
index 383d32b163..26093c8434 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
+from synapse.util.caches import CACHE_SIZE_FACTOR
from collections import namedtuple
from frozendict import frozendict
import logging
import hashlib
-import os
logger = logging.getLogger(__name__)
@@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
@@ -61,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ # dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
+
+ # the ID of a state group if one and only one is involved.
+ # otherwise, None.
self.state_group = state_group
self.prev_group = prev_group
@@ -84,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object):
- """ Responsible for doing state conflict resolution.
+ """Fetches bits of state from the stores, and does state resolution
+ where necessary
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
-
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
- self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ self._state_resolution_handler = hs.get_state_resolution_handler()
def start_caching(self):
- logger.debug("start_caching")
-
- self._state_cache = ExpiringCache(
- cache_name="state_cache",
- clock=self.clock,
- max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
- iterable=True,
- reset_expiry_on_get=True,
- )
-
- self._state_cache.start()
+ # TODO: remove this shim
+ self._state_resolution_handler.start_caching()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
@@ -130,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
@@ -143,25 +132,33 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = {
- key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+ key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
}
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_state_ids(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state_ids(self, room_id, latest_event_ids=None):
+ """Get the current state, or the state at a set of events, for a room
+
+ Args:
+ room_id (str):
+
+ latest_event_ids (iterable[str]|None): if given, the forward
+ extremities to resolve. If None, we look them up from the
+ database (via a cache)
+
+ Returns:
+ Deferred[dict[(str, str), str]]: the state dict, mapping from
+ (event_type, state_key) -> event_id
+ """
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- if event_type:
- defer.returnValue(state.get((event_type, state_key)))
- return
-
defer.returnValue(state)
@defer.inlineCallbacks
@@ -169,54 +166,71 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_state(
- room_id, entry.state_id, entry.state
- )
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@defer.inlineCallbacks
+ def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+ defer.returnValue(joined_hosts)
+
+ @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
- """ Fills out the context with the `current state` of the graph. The
- `current state` here is defined to be the state of the event graph
- just before the event - i.e. it never includes `event`
+ """Build an EventContext structure for the event.
- If `event` has `auth_events` then this will also fill out the
- `auth_events` field on `context` from the `current_state`.
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
Args:
- event (EventBase)
+ event (synapse.events.EventBase):
+ old_state (dict|None): The state at the event if it can't be
+ calculated from existing events. This is normally only specified
+ when receiving an event from federation for which we don't have
+ the prev events, e.g. when backfilling.
Returns:
- an EventContext
+ synapse.events.snapshot.EventContext:
"""
- context = EventContext()
if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
+ context = EventContext()
if old_state:
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
- context.current_state_events = dict(context.prev_state_ids)
+ context.current_state_ids = dict(context.prev_state_ids)
key = (event.type, event.state_key)
- context.current_state_events[key] = event.event_id
+ context.current_state_ids[key] = event.event_id
else:
- context.current_state_events = context.prev_state_ids
+ context.current_state_ids = context.prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
- context.state_group = self.store.get_next_state_group()
+
+ # We don't store state for outliers, so we don't generate a state
+ # group for it.
+ context.state_group = None
+
defer.returnValue(context)
if old_state:
+ # We already have the state, so we don't need to calculate it.
+ # Let's just correctly fill out the context and create a
+ # new state group for it.
+
+ context = EventContext()
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
- context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
@@ -229,26 +243,29 @@ class StateHandler(object):
else:
context.current_state_ids = context.prev_state_ids
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=None,
+ delta_ids=None,
+ current_state_ids=context.current_state_ids,
+ )
+
context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
- if event.is_state():
- entry = yield self.resolve_state_groups(
- event.room_id, [e for e, _ in event.prev_events],
- event_type=event.type,
- state_key=event.state_key,
- )
- else:
- entry = yield self.resolve_state_groups(
- event.room_id, [e for e, _ in event.prev_events],
- )
+ entry = yield self.resolve_state_groups_for_events(
+ event.room_id, [e for e, _ in event.prev_events],
+ )
curr_state = entry.state
+ context = EventContext()
context.prev_state_ids = curr_state
if event.is_state():
- context.state_group = self.store.get_next_state_group()
+ # If this is a state event then we need to create a new state
+ # group for the state after this event.
key = (event.type, event.state_key)
if key in context.prev_state_ids:
@@ -258,58 +275,176 @@ class StateHandler(object):
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
- context.prev_group = entry.prev_group
- context.delta_ids = entry.delta_ids
- if context.delta_ids is not None:
- context.delta_ids = dict(context.delta_ids)
+ if entry.state_group:
+ # If the state at the event has a state group assigned then
+ # we can use that as the prev group
+ context.prev_group = entry.state_group
+ context.delta_ids = {
+ key: event.event_id
+ }
+ elif entry.prev_group:
+ # If the state at the event only has a prev group, then we can
+ # use that as a prev group too.
+ context.prev_group = entry.prev_group
+ context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id
+
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=context.prev_group,
+ delta_ids=context.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
else:
+ context.current_state_ids = context.prev_state_ids
+ context.prev_group = entry.prev_group
+ context.delta_ids = entry.delta_ids
+
if entry.state_group is None:
- entry.state_group = self.store.get_next_state_group()
+ entry.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=entry.prev_group,
+ delta_ids=entry.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
entry.state_id = entry.state_group
context.state_group = entry.state_group
- context.current_state_ids = context.prev_state_ids
- context.prev_group = entry.prev_group
- context.delta_ids = entry.delta_ids
context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
- @log_function
- def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
+ def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
+ Args:
+ room_id (str):
+ event_ids (list[str]):
+
Returns:
- a Deferred tuple of (`state_group`, `state`, `prev_state`).
- `state_group` is the name of a state group if one and only one is
- involved. `state` is a map from (type, state_key) to event, and
- `prev_state` is a list of event ids.
+ Deferred[_StateCacheEntry]: resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
+ # map from state group id to the state in that state group (where
+ # 'state' is a map from state key to event id)
+ # dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids
)
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
-
- group_names = frozenset(state_groups_ids.keys())
- if len(group_names) == 1:
+ if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop()
+ prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
- prev_group=name,
- delta_ids={},
+ prev_group=prev_group,
+ delta_ids=delta_ids,
))
+ result = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups_ids, None, self._state_map_factory,
+ )
+ defer.returnValue(result)
+
+ def _state_map_factory(self, ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+
+ def resolve_events(self, state_sets, event):
+ logger.info(
+ "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+ )
+ state_set_ids = [{
+ (ev.type, ev.state_key): ev.event_id
+ for ev in st
+ } for st in state_sets]
+
+ state_map = {
+ ev.event_id: ev
+ for st in state_sets
+ for ev in st
+ }
+
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+ new_state = {
+ key: state_map[ev_id] for key, ev_id in new_state.iteritems()
+ }
+
+ return new_state
+
+
+class StateResolutionHandler(object):
+ """Responsible for doing state conflict resolution.
+
+ Note that the storage layer depends on this handler, so all functions must
+ be storage-independent.
+ """
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+
+ # dict of set of event_ids -> _StateCacheEntry.
+ self._state_cache = None
+ self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+ def start_caching(self):
+ logger.debug("start_caching")
+
+ self._state_cache = ExpiringCache(
+ cache_name="state_cache",
+ clock=self.clock,
+ max_len=SIZE_OF_CACHE,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+ iterable=True,
+ reset_expiry_on_get=True,
+ )
+
+ self._state_cache.start()
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(
+ self, room_id, state_groups_ids, event_map, state_map_factory,
+ ):
+ """Resolves conflicts between a set of state groups
+
+ Always generates a new state group (unless we hit the cache), so should
+ not be called for a single state group
+
+ Args:
+ room_id (str): room we are resolving for (used for logging)
+ state_groups_ids (dict[int, dict[(str, str), str]]):
+ map from state group id to the state in that state group
+ (where 'state' is a map from state key to event id)
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ Returns:
+ Deferred[_StateCacheEntry]: resolved state
+ """
+ logger.debug(
+ "resolve_state_groups state_groups %s",
+ state_groups_ids.keys()
+ )
+
+ group_names = frozenset(state_groups_ids.keys())
+
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
@@ -320,56 +455,62 @@ class StateHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
+ # build a map from state key to the event_ids which set that state.
+ # dict[(str, str), set[str]]
state = {}
- for st in state_groups_ids.values():
- for key, e_id in st.items():
+ for st in state_groups_ids.itervalues():
+ for key, e_id in st.iteritems():
state.setdefault(key, set()).add(e_id)
+ # build a map from state key to the event_ids which set that state,
+ # including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
- for k, v in state.items()
+ for k, v in state.iteritems()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
+ new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
- state_map_factory=lambda ev_ids: self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
+ event_map=event_map,
+ state_map_factory=state_map_factory,
)
else:
new_state = {
- key: e_ids.pop() for key, e_ids in state.items()
+ key: e_ids.pop() for key, e_ids in state.iteritems()
}
- state_group = None
- new_state_event_ids = frozenset(new_state.values())
- for sg, events in state_groups_ids.items():
- if new_state_event_ids == frozenset(e_id for e_id in events):
- state_group = sg
- break
- if state_group is None:
- # Worker instances don't have access to this method, but we want
- # to set the state_group on the main instance to increase cache
- # hits.
- if hasattr(self.store, "get_next_state_group"):
- state_group = self.store.get_next_state_group()
-
- prev_group = None
- delta_ids = None
- for old_group, old_ids in state_groups_ids.items():
- if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
- n_delta_ids = {
- k: v
- for k, v in new_state.items()
- if old_ids.get(k) != v
- }
- if not delta_ids or len(n_delta_ids) < len(delta_ids):
- prev_group = old_group
- delta_ids = n_delta_ids
+ with Measure(self.clock, "state.create_group_ids"):
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
+ state_group = None
+ new_state_event_ids = frozenset(new_state.itervalues())
+ for sg, events in state_groups_ids.iteritems():
+ if new_state_event_ids == frozenset(e_id for e_id in events):
+ state_group = sg
+ break
+
+ # TODO: We want to create a state group for this set of events, to
+ # increase cache hits, but we need to make sure that it doesn't
+ # end up as a prev_group without being added to the database
+
+ prev_group = None
+ delta_ids = None
+ for old_group, old_ids in state_groups_ids.iteritems():
+ if not set(new_state) - set(old_ids):
+ n_delta_ids = {
+ k: v
+ for k, v in new_state.iteritems()
+ if old_ids.get(k) != v
+ }
+ if not delta_ids or len(n_delta_ids) < len(delta_ids):
+ prev_group = old_group
+ delta_ids = n_delta_ids
cache = _StateCacheEntry(
state=new_state,
@@ -383,30 +524,6 @@ class StateHandler(object):
defer.returnValue(cache)
- def resolve_events(self, state_sets, event):
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
-
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
-
- with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
-
- new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.items()
- }
-
- return new_state
-
def _ordered_events(events):
def key_func(e):
@@ -415,19 +532,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event. Otherwise, should be
- a dict from event_id to event of all events in state_sets.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
- (type, state_key) to event.
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -436,13 +551,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
- if callable(state_map_factory):
- return _resolve_with_state_fac(
- unconflicted_state, conflicted_state, state_map_factory
- )
-
- state_map = state_map_factory
-
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -456,6 +564,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
+
+ Args:
+ state_sets(list[dict[(str, str), str]]):
+ List of dicts of (type, state_key) -> event_id, which are the
+ different state groups to resolve.
+
+ Returns:
+ (dict[(str, str), str], dict[(str, str), set[str]]):
+ A tuple of (unconflicted_state, conflicted_state), where:
+
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
+
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
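
A small worked example (not part of the patch) of what _seperate produces; the event ids are illustrative:

state_sets = [
    {("m.room.name", ""): "$A", ("m.room.topic", ""): "$B"},
    {("m.room.name", ""): "$A", ("m.room.topic", ""): "$C"},
]
unconflicted_state, conflicted_state = _seperate(state_sets)
# unconflicted_state == {("m.room.name", ""): "$A"}
# conflicted_state   == {("m.room.topic", ""): {"$B", "$C"}}
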
@@ -486,24 +609,63 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
- state_map_factory):
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
+ if event_map is not None:
+ needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d conflicted events", len(needed_events))
+ # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+ # the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
+ if event_map is not None:
+ state_map.update(event_map)
+ # get the ids of the auth events which allow us to authenticate the
+ # conflicted state, picking only from the unconflicting state.
+ #
+ # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
+ if event_map is not None:
+ new_needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d auth events", len(new_needed_events))
@@ -541,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = {
key: state_map[ev_id]
- for key, ev_id in auth_event_ids.items()
+ for key, ev_id in auth_event_ids.iteritems()
if ev_id in state_map
}
@@ -549,7 +711,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
- except:
+ except Exception:
logger.exception("Failed to resolve state")
raise
@@ -579,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in conflicted_state.iteritems():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -589,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in conflicted_state.iteritems():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -599,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in conflicted_state.iteritems():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d604e7668f..8cdfd50f90 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
-from ._base import LoggingTransaction
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore, UserPresenceState
@@ -37,7 +35,7 @@ from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
from .event_push_actions import EventPushActionsStore
from .deviceinbox import DeviceInboxStore
-
+from .group_server import GroupServerStore
from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
@@ -49,6 +47,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
from .client_ips import ClientIpStore
+from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
@@ -86,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
+ UserDirectoryStore,
+ GroupServerStore,
):
def __init__(self, db_conn, hs):
@@ -101,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
- )
- self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -121,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
- self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -133,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
+ self._group_updates_id_gen = StreamIdGenerator(
+ db_conn, "local_group_updates", "stream_id",
+ )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
@@ -141,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore,
else:
self._cache_id_gen = None
- events_max = self._stream_id_gen.get_current_token()
- event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
- entity_column="room_id",
- stream_column="stream_ordering",
- max_value=events_max,
- )
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
- prefilled_cache=event_cache_prefill,
- )
-
- self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
- )
-
- account_max = self._account_data_id_gen.get_current_token()
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache", account_max,
- )
-
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -175,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill
)
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn, "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._push_rules_stream_id_gen.get_current_token()[0],
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache", push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
@@ -221,23 +185,35 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- after_callbacks=[]
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
)
- self._find_stream_orderings_for_times_txn(cur)
- cur.close()
- self.find_stream_orderings_looping_call = self._clock.looping_call(
- self._find_stream_orderings_for_times, 60 * 60 * 1000
+ _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+ db_conn, "local_group_updates",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._group_updates_id_gen.get_current_token(),
+ limit=1000,
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", min_group_updates_id,
+ prefilled_cache=_group_updates_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- super(DataStore, self).__init__(hs)
+ super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
@@ -266,36 +242,110 @@ class DataStore(RoomMemberStore, RoomStore,
return [UserPresenceState(**row) for row in rows]
- @defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
- txn.execute(
- "SELECT COUNT(DISTINCT user_id) AS users"
- " FROM user_ips"
- " WHERE last_seen > ?",
- # This is close enough to a day for our purposes.
- (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
- )
- rows = self.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
- def get_user_ip_and_agents(self, user):
- return self._simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user.to_string()},
- retcols=[
- "access_token", "ip", "user_agent", "last_seen"
- ],
- desc="get_user_ip_and_agents",
- )
+ txn.execute(sql, (yesterday,))
+ count, = txn.fetchone()
+ return count
+
+ return self.runInteraction("count_users", _count_users)
+
+ def count_r30_users(self):
+ """
+ Counts the number of 30 day retained users, defined as:-
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns counts globally as well as broken down
+ by platform
+ """
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == 'unknown':
+ continue
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ count, = txn.fetchone()
+ results['all'] = count
+
+ return results
+
+ return self.runInteraction("count_r30_users", _count_r30_users)
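
The CASE expression above buckets user agents into platforms; for reference, the equivalent logic in Python (order matters, since Electron user agents also contain 'Mozilla'):

def classify_user_agent(user_agent):
    # mirrors the SQL CASE used by count_r30_users
    if "Android" in user_agent:
        return "android"
    if "iOS" in user_agent:
        return "ios"
    if "Electron" in user_agent:
        return "electron"
    if "Mozilla" in user_agent or "Gecko" in user_agent:
        return "web"
    return "unknown"
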
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b0dc391190..2262776ab2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,9 +16,7 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
-from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics
@@ -28,10 +26,6 @@ from twisted.internet import defer
import sys
import time
import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__)
@@ -53,20 +47,27 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+ __slots__ = [
+ "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
+ ]
- def __init__(self, txn, name, database_engine, after_callbacks):
+ def __init__(self, txn, name, database_engine, after_callbacks,
+ exception_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
- def call_after(self, callback, *args):
+ def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
- self.after_callbacks.append((callback, args))
+ self.after_callbacks.append((callback, args, kwargs))
+
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -74,13 +75,22 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
+ def __iter__(self):
+ return self.txn.__iter__()
+
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args):
self._do_execute(self.txn.executemany, sql, *args)
+ def _make_sql_one_line(self, sql):
+ "Strip newlines out of SQL so that the loggers in the DB are on one line"
+ return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
def _do_execute(self, func, sql, *args):
+ sql = self._make_sql_one_line(sql)
+
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
@@ -91,7 +101,7 @@ class LoggingTransaction(object):
"[SQL values] {%s} %r",
self.name, args[0]
)
- except:
+ except Exception:
# Don't let logging failures stop SQL from working
pass
@@ -127,7 +137,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.items():
+ for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -150,7 +160,7 @@ class PerformanceCounters(object):
class SQLBaseStore(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
@@ -168,10 +178,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
- )
-
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -209,8 +215,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, logging_context,
- func, *args, **kwargs):
+ def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
+ logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -229,7 +235,8 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks
+ txn, name, self.database_engine, after_callbacks,
+ exception_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -284,47 +291,66 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ current_context = LoggingContext.current_context()
after_callbacks = []
+ exception_callbacks = []
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ return self._new_transaction(
+ conn, desc, after_callbacks, exception_callbacks, current_context,
+ func, *args, **kwargs
+ )
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
+ try:
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, current_context,
- func, *args, **kwargs
- )
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
- try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
- finally:
- for after_callback, after_args in after_callbacks:
- after_callback(*after_args)
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
@@ -350,9 +376,9 @@ class SQLBaseStore(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(column[0] for column in cursor.description)
+ col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list(
- intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
+ dict(zip(col_headers, row)) for row in cursor
)
return results
@@ -417,6 +443,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
+ def _simple_insert_many(self, table, values, desc):
+ return self.runInteraction(
+ desc, self._simple_insert_many_txn, table, values
+ )
+
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
@@ -452,23 +483,53 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
+ @defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
"""
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following are true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError in which case this function will retry
+ the update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
- insertion_values (dict): key/values to use when inserting
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
"""
- return self.runInteraction(
- desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock
- )
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+ lock=lock
+ )
+ defer.returnValue(result)
+ except self.database_engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warn(
+ "IntegrityError when upserting into %s; retrying: %s",
+ table, e
+ )
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True):
@@ -476,45 +537,38 @@ class SQLBaseStore(object):
if lock:
self.database_engine.lock_table(txn, table)
- # Try to update
+ # First try to update.
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
- logger.debug(
- "[SQL] %s Args=%s",
- sql, sqlargs,
- )
txn.execute(sql, sqlargs)
- if txn.rowcount == 0:
-            # We didn't update any rows so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
- )
- logger.debug(
- "[SQL] %s Args=%s",
- sql, keyvalues.values(),
- )
- txn.execute(sql, allvalues.values())
+ # We didn't update any rows so insert a new one
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
- return True
- else:
- return False
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues)
+ )
+ txn.execute(sql, allvalues.values())
+ # successfully inserted
+ return True
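To make the update-then-insert flow concrete, here is a self-contained sketch against sqlite3 (the table, columns and UNIQUE index are invented); it mirrors the retry-on-IntegrityError loop in _simple_upsert above:

import sqlite3

def simple_upsert(conn, table, keyvalues, values, max_attempts=5):
    for attempt in range(max_attempts):
        try:
            with conn:  # one transaction per attempt
                sql = "UPDATE %s SET %s WHERE %s" % (
                    table,
                    ", ".join("%s = ?" % k for k in values),
                    " AND ".join("%s = ?" % k for k in keyvalues),
                )
                cur = conn.execute(sql, list(values.values()) + list(keyvalues.values()))
                if cur.rowcount > 0:
                    return False        # updated an existing row
                allvalues = dict(keyvalues, **values)
                sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                    table, ", ".join(allvalues), ", ".join("?" for _ in allvalues),
                )
                conn.execute(sql, list(allvalues.values()))
                return True             # inserted a new row
        except sqlite3.IntegrityError:
            # presumably we raced with a concurrent insert: retry, but not forever
            if attempt == max_attempts - 1:
                raise

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pushers (user_id TEXT, app_id TEXT, data TEXT, UNIQUE(user_id, app_id))")
print(simple_upsert(conn, "pushers", {"user_id": "@a:hs", "app_id": "x"}, {"data": "1"}))  # True
print(simple_upsert(conn, "pushers", {"user_id": "@a:hs", "app_id": "x"}, {"data": "2"}))  # False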
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
+ return a single row, returning multiple columns from it.
Args:
table : string giving the table name
@@ -567,22 +621,20 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
- else:
- where = ""
-
sql = (
- "SELECT %(retcol)s FROM %(table)s %(where)s"
+ "SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
- "where": where,
}
- txn.execute(sql, keyvalues.values())
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+ txn.execute(sql, keyvalues.values())
+ else:
+ txn.execute(sql)
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
@@ -591,7 +643,7 @@ class SQLBaseStore(object):
Args:
table (str): table name
- keyvalues (dict): column names and values to select the rows with
+ keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whose value we wish to retrieve.
Returns:
@@ -715,7 +767,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.items():
+ for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -728,6 +780,33 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
+ def _simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc,
+ self._simple_update_txn,
+ table, keyvalues, updatevalues,
+ )
+
+ @staticmethod
+ def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ return txn.rowcount
+
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
@@ -753,27 +832,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- @staticmethod
- def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
- )
+ @classmethod
+ def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
- if txn.rowcount == 0:
+ if rowcount == 0:
raise StoreError(404, "No row found")
- if txn.rowcount > 1:
+ if rowcount > 1:
raise StoreError(500, "More than one row matched")
@staticmethod
@@ -843,6 +908,47 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
+ def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ """
+ if not iterable:
+ return
+
+ sql = "DELETE FROM %s" % table
+
+ clauses = []
+ values = []
+ clauses.append(
+ "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+ )
+ values.extend(iterable)
+
+ for key, value in keyvalues.iteritems():
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (
+ sql,
+ " AND ".join(clauses),
+ )
+ return txn.execute(sql, values)
+
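A self-contained sqlite3 sketch of the same IN-clause construction (table and rows invented):

import sqlite3

def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
    if not iterable:
        return
    clauses = ["%s IN (%s)" % (column, ",".join("?" for _ in iterable))]
    values = list(iterable)
    for key, value in keyvalues.items():
        clauses.append("%s = ?" % key)
        values.append(value)
    sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
    return txn.execute(sql, values)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (user_id TEXT, device_id TEXT)")
conn.executemany("INSERT INTO devices VALUES (?, ?)",
                 [("@a:hs", "D1"), ("@a:hs", "D2"), ("@b:hs", "D1")])
simple_delete_many_txn(conn, "devices", "device_id", ["D1", "D2"],
                       {"user_id": "@a:hs"})
print(conn.execute("SELECT count(*) FROM devices").fetchone())  # (1,)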
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -863,16 +969,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
- rows = txn.fetchall()
- txn.close()
cache = {
row[0]: int(row[1])
- for row in rows
+ for row in txn
}
+ txn.close()
+
if cache:
- min_val = min(cache.values())
+ min_val = min(cache.itervalues())
else:
min_val = max_value
@@ -895,6 +1001,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
+ txn.call_on_exception(ctx.__exit__, None, None, None)
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e92d..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
-import ujson as json
+import abc
+import simplejson as json
import logging
logger = logging.getLogger(__name__)
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_account_data_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ account_max = self.get_max_account_data_stream_id()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+ @abc.abstractmethod
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream ID for account data stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
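The shape of this worker/writer split, boiled down to a runnable sketch: the base class's __init__ may call the abstract method because concrete subclasses must supply it before instantiation. Class names echo the diff, the id generator is replaced by a plain integer, and Python 2's __metaclass__ spelling is kept to match the codebase (Python 3 will run it, just without enforcing abstractness):

import abc

class AccountDataWorkerStoreSketch(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        # safe to call: subclasses implement it before instantiation
        self._stream_cache_position = self.get_max_account_data_stream_id()

    @abc.abstractmethod
    def get_max_account_data_stream_id(self):
        raise NotImplementedError()

class AccountDataStoreSketch(AccountDataWorkerStoreSketch):
    def __init__(self, current_token):
        # the real class sets up a StreamIdGenerator here; an int will do
        self._current_token = current_token
        super(AccountDataStoreSketch, self).__init__()

    def get_max_account_data_stream_id(self):
        return self._current_token

print(AccountDataStoreSketch(42)._stream_cache_position)   # 42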
@cached()
def get_account_data_for_user(self, user_id):
@@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2)
+ @cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
"""
Returns:
@@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
for row in rows
})
+ @cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
@@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
+ @cached(num_args=3, max_entries=5000)
+ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ """Get the client account_data of given type for a user for a room.
+
+ Args:
+ user_id(str): The user to get the account_data for.
+ room_id(str): The room to get the account_data for.
+ account_data_type (str): The account data type to get.
+ Returns:
+ A deferred of the room account_data for that type, or None if
+ there isn't any set.
+ """
+ def get_account_data_for_room_and_type_txn(txn):
+ content_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={
+ "user_id": user_id,
+ "room_id": room_id,
+ "account_data_type": account_data_type,
+ },
+ retcol="content",
+ allow_none=True
+ )
+
+ return json.loads(content_json) if content_json else None
+
+ return self.runInteraction(
+ "get_account_data_for_room_and_type",
+ get_account_data_for_room_and_type_txn,
+ )
+
def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit):
"""Get all the client account_data that has changed on the server
@@ -182,7 +244,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {
- row[0]: json.loads(row[1]) for row in txn.fetchall()
+ row[0]: json.loads(row[1]) for row in txn
}
sql = (
@@ -193,7 +255,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
- for row in txn.fetchall():
+ for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
@@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
+ @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+ ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", ignorer_user_id,
+ on_invalidate=cache_context.invalidate,
+ )
+ if not ignored_account_data:
+ defer.returnValue(False)
+
+ defer.returnValue(
+ ignored_user_id in ignored_account_data.get("ignored_users", {})
+ )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ def __init__(self, db_conn, hs):
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ super(AccountDataStore, self).__init__(db_conn, hs)
+
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream id for the private user data stream
+
+ Returns:
+ A deferred int.
+ """
+ return self._account_data_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@@ -222,9 +314,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
- def add_account_data_txn(txn, next_id):
- self._simple_upsert_txn(
- txn,
+ with self._account_data_id_gen.get_next() as next_id:
+ # no need to lock here as room_account_data has a unique constraint
+ # on (user_id, room_id, account_data_type) so _simple_upsert will
+ # retry if there is a conflict.
+ yield self._simple_upsert(
+ desc="add_room_account_data",
table="room_account_data",
keyvalues={
"user_id": user_id,
@@ -234,18 +329,23 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
- }
- )
- txn.call_after(
- self._account_data_stream_cache.entity_has_changed,
- user_id, next_id,
+ },
+ lock=False,
)
- txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
- self._update_max_stream_id(txn, next_id)
- with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "add_room_account_data", add_account_data_txn, next_id
+ # it's theoretically possible for the above to succeed and the
+ # below to fail - in which case we might reuse a stream id on
+ # restart, and the above update might not get propagated. That
+ # doesn't sound any worse than the whole update getting lost,
+ # which is what would happen if we combined the two into one
+ # transaction.
+ yield self._update_max_stream_id(next_id)
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id,))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type,), content,
)
result = self._account_data_id_gen.get_current_token()
@@ -263,9 +363,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
- def add_account_data_txn(txn, next_id):
- self._simple_upsert_txn(
- txn,
+ with self._account_data_id_gen.get_next() as next_id:
+ # no need to lock here as account_data has a unique constraint on
+ # (user_id, account_data_type) so _simple_upsert will retry if
+ # there is a conflict.
+ yield self._simple_upsert(
+ desc="add_user_account_data",
table="account_data",
keyvalues={
"user_id": user_id,
@@ -274,37 +377,43 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
- }
+ },
+ lock=False,
)
- txn.call_after(
- self._account_data_stream_cache.entity_has_changed,
+
+ # it's theoretically possible for the above to succeed and the
+ # below to fail - in which case we might reuse a stream id on
+ # restart, and the above update might not get propagated. That
+ # doesn't sound any worse than the whole update getting lost,
+ # which is what would happen if we combined the two into one
+ # transaction.
+ yield self._update_max_stream_id(next_id)
+
+ self._account_data_stream_cache.entity_has_changed(
user_id, next_id,
)
- txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
- txn.call_after(
- self.get_global_account_data_by_type_for_user.invalidate,
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id,)
)
- self._update_max_stream_id(txn, next_id)
-
- with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "add_user_account_data", add_account_data_txn, next_id
- )
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
- def _update_max_stream_id(self, txn, next_id):
+ def _update_max_stream_id(self, next_id):
"""Update the max stream_id
Args:
- txn: The database cursor
next_id(int): The revision to advance to.
"""
- update_max_id_sql = (
- "UPDATE account_data_max_stream_id"
- " SET stream_id = ?"
- " WHERE stream_id < ?"
+ def _update(txn):
+ update_max_id_sql = (
+ "UPDATE account_data_max_stream_id"
+ " SET stream_id = ?"
+ " WHERE stream_id < ?"
+ )
+ txn.execute(update_max_id_sql, (next_id, next_id))
+ return self.runInteraction(
+ "update_account_data_max_stream_id",
+ _update,
)
- txn.execute(update_max_id_sql, (next_id, next_id))
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..12ea8a158c 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,39 +14,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
import simplejson as json
from twisted.internet import defer
-from synapse.api.constants import Membership
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.events import EventsWorkerStore
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-class ApplicationServiceStore(SQLBaseStore):
+def _make_exclusive_regex(services_cache):
+    # We precompile a regex constructed from all the regexes that the ASes
+ # have registered for exclusive users.
+ exclusive_user_regexes = [
+ regex.pattern
+ for service in services_cache
+ for regex in service.get_exlusive_user_regexes()
+ ]
+ if exclusive_user_regexes:
+ exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
+ exclusive_user_regex = re.compile(exclusive_user_regex)
+ else:
+ # We handle this case specially otherwise the constructed regex
+ # will always match
+ exclusive_user_regex = None
- def __init__(self, hs):
- super(ApplicationServiceStore, self).__init__(hs)
- self.hostname = hs.hostname
+ return exclusive_user_regex
+
+
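A standalone sketch of the same precompilation trick (the patterns here are invented; the real ones come from each appservice's registration):

import re

def make_exclusive_regex(exclusive_user_patterns):
    if not exclusive_user_patterns:
        # no patterns: return None rather than a regex that matches everything
        return None
    combined = "|".join("(" + p + ")" for p in exclusive_user_patterns)
    return re.compile(combined)

regex = make_exclusive_regex([r"@irc_.*:example\.com", r"@gitter_.*:example\.com"])
print(bool(regex.match("@irc_alice:example.com")))   # True
print(bool(regex.match("@alice:example.com")))       # False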
+class ApplicationServiceWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname,
hs.config.app_service_config_files
)
+ self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+
+ super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
- """Check if the user is one associated with an app service
+ """Check if the user is one associated with an app service (exclusively)
"""
- for service in self.services_cache:
- if service.is_interested_in_user(user_id):
- return True
- return False
+ if self.exclusive_user_regex:
+ return bool(self.exclusive_user_regex.match(user_id))
+ else:
+ return False
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
@@ -78,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore):
return service
return None
- def get_app_service_rooms(self, service):
- """Get a list of RoomsForUser for this application service.
-
- Application services may be "interested" in lots of rooms depending on
- the room ID, the room aliases, or the members in the room. This function
- takes all of these into account and returns a list of RoomsForUser which
- represent the entire list of room IDs that this application service
- wants to know about.
+ def get_app_service_by_id(self, as_id):
+ """Get the application service with the given appservice ID.
Args:
- service: The application service to get a room list for.
+ as_id (str): The application service ID.
Returns:
- A list of RoomsForUser.
+ synapse.appservice.ApplicationService or None.
"""
- return self.runInteraction(
- "get_app_service_rooms",
- self._get_app_service_rooms_txn,
- service,
- )
-
- def _get_app_service_rooms_txn(self, txn, service):
- # get all rooms matching the room ID regex.
- room_entries = self._simple_select_list_txn(
- txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
- )
- matching_room_list = set([
- r["room_id"] for r in room_entries if
- service.is_interested_in_room(r["room_id"])
- ])
-
- # resolve room IDs for matching room alias regex.
- room_alias_mappings = self._simple_select_list_txn(
- txn=txn, table="room_aliases", keyvalues=None,
- retcols=["room_id", "room_alias"]
- )
- matching_room_list |= set([
- r["room_id"] for r in room_alias_mappings if
- service.is_interested_in_alias(r["room_alias"])
- ])
-
- # get all rooms for every user for this AS. This is scoped to users on
- # this HS only.
- user_list = self._simple_select_list_txn(
- txn=txn, table="users", keyvalues=None, retcols=["name"]
- )
- user_list = [
- u["name"] for u in user_list if
- service.is_interested_in_user(u["name"])
- ]
- rooms_for_user_matching_user_id = set() # RoomsForUser list
- for user_id in user_list:
- # FIXME: This assumes this store is linked with RoomMemberStore :(
- rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
- txn=txn,
- user_id=user_id,
- membership_list=[Membership.JOIN]
- )
- rooms_for_user_matching_user_id |= set(rooms_for_user)
-
- # make RoomsForUser tuples for room ids and aliases which are not in the
- # main rooms_for_user_list - e.g. they are rooms which do not have AS
- # registered users in it.
- known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
- missing_rooms_for_user = [
- RoomsForUser(r, service.sender, "join") for r in
- matching_room_list if r not in known_room_ids
- ]
- rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
-
- return rooms_for_user_matching_user_id
+ for service in self.services_cache:
+ if service.id == as_id:
+ return service
+ return None
-class ApplicationServiceTransactionStore(SQLBaseStore):
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+ # This is currently empty due to there not being any AS storage functions
+ # that can't be run on the workers. Since this may change in future, and
+ # to keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
- def __init__(self, hs):
- super(ApplicationServiceTransactionStore, self).__init__(hs)
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+ EventsWorkerStore):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@@ -399,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+ # This is currently empty due to there not being any AS storage functions
+ # that can't be run on the workers. Since this may change in future, and
+ # to keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc54a..8af325a9f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse.util.async
from ._base import SQLBaseStore
from . import engines
from twisted.internet import defer
-import ujson as json
+import simplejson as json
import logging
logger = logging.getLogger(__name__)
@@ -79,35 +80,26 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs):
- super(BackgroundUpdateStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(BackgroundUpdateStore, self).__init__(db_conn, hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
- self._background_update_timer = None
+ self._all_done = False
@defer.inlineCallbacks
def start_doing_background_updates(self):
- assert self._background_update_timer is None, \
- "background updates already running"
-
logger.info("Starting background schema updates")
while True:
- sleep = defer.Deferred()
- self._background_update_timer = self._clock.call_later(
- self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
- )
- try:
- yield sleep
- finally:
- self._background_update_timer = None
+ yield synapse.util.async.sleep(
+ self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
- except:
+ except Exception:
logger.exception("Error doing update")
else:
if result is None:
@@ -115,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore):
"No more background updates to do."
" Unscheduling background update task."
)
+ self._all_done = True
defer.returnValue(None)
@defer.inlineCallbacks
+ def has_completed_background_updates(self):
+ """Check if all the background updates have completed
+
+ Returns:
+ Deferred[bool]: True if all background updates have completed
+ """
+ # if we've previously determined that there is nothing left to do, that
+ # is easy
+ if self._all_done:
+ defer.returnValue(True)
+
+ # obviously, if we have things in our queue, we're not done.
+ if self._background_update_queue:
+ defer.returnValue(False)
+
+ # otherwise, check if there are updates to be run. This is important,
+ # as we may be running on a worker which doesn't perform the bg updates
+ # itself, but still wants to wait for them to happen.
+ updates = yield self._simple_select_onecol(
+ "background_updates",
+ keyvalues=None,
+ retcol="1",
+ desc="check_background_updates",
+ )
+ if not updates:
+ self._all_done = True
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
@@ -218,8 +242,29 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_noop_background_update(self, update_name):
+ """Register a noop handler for a background update.
+
+ This is useful when we previously did a background update, but no
+ longer wish to do the update. In this case the background update should
+ be removed from the schema delta files, but there may still be some
+ users who have the background update queued, so this method should
+ also be called to clear the update.
+
+ Args:
+ update_name (str): Name of update
+ """
+ @defer.inlineCallbacks
+ def noop_update(progress, batch_size):
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, noop_update)
+
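A toy registry shows why this is useful: even after an update's real handler has been removed, rows naming it may still be queued, and the noop handler simply marks them done. Everything below is a stand-in for the real store:

class BackgroundUpdaterSketch(object):
    def __init__(self):
        self._handlers = {}
        self._queue = ["old_removed_update"]   # pretend this was left in the db

    def register_background_update_handler(self, name, handler):
        self._handlers[name] = handler

    def register_noop_background_update(self, name):
        def noop_update(progress, batch_size):
            self._queue.remove(name)           # stands in for _end_background_update
            return 1
        self.register_background_update_handler(name, noop_update)

    def run_next(self):
        name = self._queue[0]
        return self._handlers[name](progress={}, batch_size=100)

updater = BackgroundUpdaterSketch()
updater.register_noop_background_update("old_removed_update")
updater.run_next()
print(updater._queue)   # []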
def register_background_index_update(self, update_name, index_name,
- table, columns, where_clause=None):
+ table, columns, where_clause=None,
+ unique=False,
+ psql_only=False):
"""Helper for store classes to do a background index addition
To use:
@@ -235,48 +280,80 @@ class BackgroundUpdateStore(SQLBaseStore):
index_name (str): name of index to add
table (str): table to add index to
columns (list[str]): columns/expressions to include in index
+ unique (bool): true to make a UNIQUE index
+            psql_only (bool): true to only create this index on psql databases (useful
+ for virtual sqlite tables)
"""
- # if this is postgres, we add the indexes concurrently. Otherwise
- # we fall back to doing it inline
- if isinstance(self.database_engine, engines.PostgresEngine):
- conc = True
- else:
- conc = False
- # We don't use partial indices on SQLite as it wasn't introduced
- # until 3.8, and wheezy has 3.7
- where_clause = None
-
- sql = (
- "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
- " %(where_clause)s"
- ) % {
- "conc": "CONCURRENTLY" if conc else "",
- "name": index_name,
- "table": table,
- "columns": ", ".join(columns),
- "where_clause": "WHERE " + where_clause if where_clause else ""
- }
-
- def create_index_concurrently(conn):
+ def create_index_psql(conn):
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
- c = conn.cursor()
- c.execute(sql)
- conn.set_session(autocommit=False)
- def create_index(conn):
+ try:
+ c = conn.cursor()
+
+ # If a previous attempt to create the index was interrupted,
+ # we may already have a half-built index. Let's just drop it
+ # before trying to create it again.
+
+ sql = "DROP INDEX IF EXISTS %s" % (index_name,)
+ logger.debug("[SQL] %s", sql)
+ c.execute(sql)
+
+ sql = (
+ "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
+ " ON %(table)s"
+ " (%(columns)s) %(where_clause)s"
+ ) % {
+ "unique": "UNIQUE" if unique else "",
+ "name": index_name,
+ "table": table,
+ "columns": ", ".join(columns),
+ "where_clause": "WHERE " + where_clause if where_clause else ""
+ }
+ logger.debug("[SQL] %s", sql)
+ c.execute(sql)
+ finally:
+ conn.set_session(autocommit=False)
+
+ def create_index_sqlite(conn):
+ # Sqlite doesn't support concurrent creation of indexes.
+ #
+ # We don't use partial indices on SQLite as it wasn't introduced
+ # until 3.8, and wheezy and CentOS 7 have 3.7
+ #
+ # We assume that sqlite doesn't give us invalid indices; however
+ # we may still end up with the index existing but the
+ # background_updates not having been recorded if synapse got shut
+            # down at the wrong moment - hence we use IF NOT EXISTS. (SQLite
+ # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
+ sql = (
+ "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
+ " (%(columns)s)"
+ ) % {
+ "unique": "UNIQUE" if unique else "",
+ "name": index_name,
+ "table": table,
+ "columns": ", ".join(columns),
+ }
+
c = conn.cursor()
+ logger.debug("[SQL] %s", sql)
c.execute(sql)
+ if isinstance(self.database_engine, engines.PostgresEngine):
+ runner = create_index_psql
+ elif psql_only:
+ runner = None
+ else:
+ runner = create_index_sqlite
+
@defer.inlineCallbacks
def updater(progress, batch_size):
- logger.info("Adding index %s to %s", index_name, table)
- if conc:
- yield self.runWithConnection(create_index_concurrently)
- else:
- yield self.runWithConnection(create_index)
+ if runner is not None:
+ logger.info("Adding index %s to %s", index_name, table)
+ yield self.runWithConnection(runner)
yield self._end_background_update(update_name)
defer.returnValue(1)
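To show how the two runners differ, a small sketch that just builds the SQL each branch would execute (engine selection and names are illustrative):

def index_sql(engine, index_name, table, columns, unique=False, where_clause=None):
    if engine == "postgres":
        return [
            # drop any half-built leftover, then build concurrently
            "DROP INDEX IF EXISTS %s" % (index_name,),
            "CREATE %s INDEX CONCURRENTLY %s ON %s (%s)%s" % (
                "UNIQUE" if unique else "", index_name, table,
                ", ".join(columns),
                " WHERE " + where_clause if where_clause else "",
            ),
        ]
    # sqlite: no CONCURRENTLY, no partial index, rely on IF NOT EXISTS
    return [
        "CREATE %s INDEX IF NOT EXISTS %s ON %s (%s)" % (
            "UNIQUE" if unique else "", index_name, table, ", ".join(columns),
        ),
    ]

for engine in ("postgres", "sqlite"):
    for stmt in index_sql(engine, "user_ips_last_seen", "user_ips",
                          ["user_id", "last_seen"]):
        print(stmt)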
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 71e5ea112f..7b44dae0fc 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,11 +15,14 @@
import logging
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import Cache
from . import background_updates
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,13 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
- def __init__(self, hs):
+ def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
)
- super(ClientIpStore, self).__init__(hs)
+ super(ClientIpStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",
@@ -44,10 +48,26 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
- @defer.inlineCallbacks
- def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
- now = int(self._clock.time_msec())
- key = (user.to_string(), access_token, ip)
+ self.register_background_index_update(
+ "user_ips_last_seen_index",
+ index_name="user_ips_last_seen",
+ table="user_ips",
+ columns=["user_id", "last_seen"],
+ )
+
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+
+ def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
+ now=None):
+ if not now:
+ now = int(self._clock.time_msec())
+ key = (user_id, access_token, ip)
try:
last_seen = self.client_ip_last_seen.get(key)
@@ -56,34 +76,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
+ return
self.client_ip_last_seen.prefill(key, now)
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
+ self._batch_row_update[key] = (user_agent, device_id, now)
+
+ def _update_client_ips_batch(self):
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
+ def _update_client_ips_batch_txn(self, txn, to_update):
+ self.database_engine.lock_table(txn, "user_ips")
+
+ for entry in to_update.iteritems():
+ (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
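The batching can be illustrated without Twisted or a database: pending rows are coalesced into a dict keyed by (user_id, access_token, ip), rate-limited by LAST_SEEN_GRANULARITY, and flushed in one go. Names mirror the diff; the flush callback stands in for the upsert transaction:

import time

LAST_SEEN_GRANULARITY = 120 * 1000   # ms, as above

class ClientIpBatcher(object):
    def __init__(self):
        self._last_seen = {}          # (user_id, token, ip) -> last seen ms
        self._batch_row_update = {}   # pending rows awaiting the next flush

    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
        now = int(time.time() * 1000)
        key = (user_id, access_token, ip)
        last_seen = self._last_seen.get(key)
        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
            return                    # seen recently: skip the write entirely
        self._last_seen[key] = now
        self._batch_row_update[key] = (user_agent, device_id, now)

    def flush(self, write_row):
        to_update, self._batch_row_update = self._batch_row_update, {}
        for key, value in to_update.items():
            write_row(key, value)     # one upsert per distinct client

def show(key, value):
    print(key, value)

batcher = ClientIpBatcher()
batcher.insert_client_ip("@a:hs", "token", "10.0.0.1", "curl/7.0", "DEVICE")
batcher.insert_client_ip("@a:hs", "token", "10.0.0.1", "curl/7.0", "DEVICE")  # rate-limited
batcher.flush(show)                   # exactly one row written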
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+ user_id (str)
+ device_id (str): If None fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -94,6 +128,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -102,23 +137,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+ ret[(user_id, device_id)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
+
+ if not where_clauses:
+ return []
inner_select = (
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
@@ -143,3 +192,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..a879e5bfc1 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,12 +14,14 @@
# limitations under the License.
import logging
-import ujson
+import simplejson
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
+from synapse.util.caches.expiringcache import ExpiringCache
+
logger = logging.getLogger(__name__)
@@ -27,8 +29,8 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, hs):
- super(DeviceInboxStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
self._background_drop_index_device_inbox,
)
+ # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no-op deletions.
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination):
@@ -74,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = ujson.dumps(edu)
+ edu_json = simplejson.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -166,8 +177,8 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
- message_json = ujson.dumps(messages_by_device["*"])
- for row in txn.fetchall():
+ message_json = simplejson.dumps(messages_by_device["*"])
+ for row in txn:
# Add the message for all devices for this user on this
# server.
device = row[0]
@@ -184,11 +195,11 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
- for row in txn.fetchall():
+ for row in txn:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = ujson.dumps(messages_by_device[device])
+ message_json = simplejson.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
@@ -240,9 +251,9 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(simplejson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
+ @defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
Returns:
A deferred that resolves to the number of messages deleted.
"""
+ # If we have cached the last stream id we've deleted up to, we can
+ # check if there is likely to be anything that needs deleting
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), None
+ )
+ if last_deleted_stream_id:
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_deleted_stream_id
+ )
+ if not has_changed:
+ defer.returnValue(0)
+
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- return self.runInteraction(
+ count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
+ # Update the cache, ensuring that we only ever increase the value
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), 0
+ )
+ self._last_device_delete_cache[(user_id, device_id)] = max(
+ last_deleted_stream_id, up_to_stream_id
+ )
+
+ defer.returnValue(count)
+
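The short-circuit can be sketched without the real stream-change cache; a dict of "latest stream id per device" stands in for it, and every name below is invented:

class DeviceInboxSketch(object):
    def __init__(self):
        self._last_delete = {}    # (user_id, device_id) -> stream_id purged up to
        self._latest_msg = {}     # (user_id, device_id) -> latest message stream_id
        self.deletes_run = 0

    def add_message(self, user_id, device_id, stream_id):
        self._latest_msg[(user_id, device_id)] = stream_id

    def delete_up_to(self, user_id, device_id, up_to_stream_id):
        key = (user_id, device_id)
        last = self._last_delete.get(key)
        if last is not None and self._latest_msg.get(key, 0) <= last:
            return 0              # nothing new since the last purge: no-op
        self.deletes_run += 1     # stands in for the real DELETE transaction
        self._last_delete[key] = max(last or 0, up_to_stream_id)
        return 1

inbox = DeviceInboxSketch()
inbox.add_message("@a:hs", "DEV", 5)
print(inbox.delete_up_to("@a:hs", "DEV", 5))  # 1: runs the delete
print(inbox.delete_up_to("@a:hs", "DEV", 6))  # 0: short-circuited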
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -291,22 +325,25 @@ class DeviceInboxStore(BackgroundUpdateStore):
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
- "SELECT stream_id, user_id"
+ "SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
+ " GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
- "SELECT stream_id, destination"
+ "SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
+ " GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn.fetchall())
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
return rows
@@ -323,12 +360,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
"""
Args:
destination(str): The name of the remote server.
- last_stream_id(int): The last position of the device message stream
+ last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
- current_stream_id(int): The current position of the device
+ current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int): List of messages for the device and where
+ Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
@@ -350,9 +387,9 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(simplejson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8e17800364..712106b83a 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,24 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import ujson as json
+import simplejson as json
from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, Cache
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
- def __init__(self, hs):
- super(DeviceStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(DeviceStore, self).__init__(db_conn, hs)
+
+ # Map of (user_id, device_id) -> bool. If there is an entry that implies
+ # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists",
+ keylen=2,
+ max_entries=10000,
+ )
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ self.register_background_index_update(
+ "device_lists_stream_idx",
+ index_name="device_lists_stream_user_id",
+ table="device_lists_stream",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name):
@@ -45,6 +62,10 @@ class DeviceStore(SQLBaseStore):
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
"""
+ key = (user_id, device_id)
+ if self.device_id_exists_cache.get(key, None):
+ defer.returnValue(False)
+
try:
inserted = yield self._simple_insert(
"devices",
@@ -56,6 +77,7 @@ class DeviceStore(SQLBaseStore):
desc="store_device",
or_ignore=True,
)
+ self.device_id_exists_cache.prefill(key, True)
defer.returnValue(inserted)
except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -84,6 +106,7 @@ class DeviceStore(SQLBaseStore):
desc="get_device",
)
+ @defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -93,12 +116,34 @@ class DeviceStore(SQLBaseStore):
Returns:
defer.Deferred
"""
- return self._simple_delete_one(
+ yield self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
+ self.device_id_exists_cache.invalidate((user_id, device_id))
+
+ @defer.inlineCallbacks
+ def delete_devices(self, user_id, device_ids):
+ """Deletes several devices.
+
+ Args:
+ user_id (str): The ID of the user which owns the devices
+ device_ids (list): The IDs of the devices to delete
+ Returns:
+ defer.Deferred
+ """
+ yield self._simple_delete_many(
+ table="devices",
+ column="device_id",
+ iterable=device_ids,
+ keyvalues={"user_id": user_id},
+ desc="delete_devices",
+ )
+ for device_id in device_ids:
+ self.device_id_exists_cache.invalidate((user_id, device_id))
+
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
@@ -144,6 +189,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
+ @cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
@@ -156,16 +202,36 @@ class DeviceStore(SQLBaseStore):
allow_none=True,
)
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_user_devices_from_cache",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- return self._simple_delete(
+ yield self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
+ self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
@@ -191,6 +257,12 @@ class DeviceStore(SQLBaseStore):
}
)
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -234,6 +306,12 @@ class DeviceStore(SQLBaseStore):
]
)
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -249,7 +327,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers
Returns:
- (now_stream_id, [ { updates }, .. ])
+ (int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -270,24 +348,27 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
+ LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
- rows = txn.fetchall()
- if not rows:
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
return (now_stream_id, [])
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in rows}
+ if len(query_map) >= 20:
+ now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_pokes
+ FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
@@ -320,6 +401,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
+ @defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
@@ -332,27 +414,11 @@ class DeviceStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- return self.runInteraction(
- "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
- query_list,
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
)
-
- def _get_user_devices_from_cache_txn(self, txn, query_list):
- user_ids = {user_id for user_id, _ in query_list}
-
- user_ids_in_cache = set()
- for user_id in user_ids:
- stream_ids = self._simple_select_onecol_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
- retcol="stream_id",
- )
- if stream_ids:
- user_ids_in_cache.add(user_id)
-
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -361,32 +427,40 @@ class DeviceStore(SQLBaseStore):
continue
if device_id:
- content = self._simple_select_one_onecol_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- )
- results.setdefault(user_id, {})[device_id] = json.loads(content)
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
else:
- devices = self._simple_select_list_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- )
- results[user_id] = {
- device["device_id"]: json.loads(device["content"])
- for device in devices
- }
- user_ids_in_cache.discard(user_id)
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
- return user_ids_not_in_cache, results
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(json.loads(content))
+
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: json.loads(device["content"])
+ for device in devices
+ })
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -436,32 +510,43 @@ class DeviceStore(SQLBaseStore):
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # First we DELETE all rows such that only the latest row for each
- # (destination, user_id is left. We do this by selecting first and
- # deleting.
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
sql = """
- SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
- HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND stream_id < ?
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
+ )
+
+ sql = """
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows)
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
- # Mark everything that is left as sent
+ # Delete all sent outbound pokes
sql = """
- UPDATE device_lists_outbound_pokes SET sent = ?
+ DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
- txn.execute(sql, (True, destination, stream_id,))
+ txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
@@ -473,12 +558,12 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(changed))
sql = """
- SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
- def get_all_device_list_changes_for_remotes(self, from_key):
+ def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
@@ -486,11 +571,11 @@ class DeviceStore(SQLBaseStore):
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
- WHERE stream_id > ?
+ WHERE ? < stream_id AND stream_id <= ?
"""
return self._execute(
- "get_users_and_hosts_device_list", None,
- sql, from_key,
+ "get_all_device_list_changes_for_remotes", None,
+ sql, from_key, to_key
)
@defer.inlineCallbacks
@@ -518,6 +603,16 @@ class DeviceStore(SQLBaseStore):
host, stream_id,
)
+ # Delete older entries in the table, as we really only care about
+ # when the latest change happened.
+ txn.executemany(
+ """
+ DELETE FROM device_lists_stream
+ WHERE user_id = ? AND device_id = ? AND stream_id < ?
+ """,
+ [(user_id, device_id, stream_id) for device_id in device_ids]
+ )
+
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
@@ -586,6 +681,14 @@ class DeviceStore(SQLBaseStore):
)
)
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # for the last successful send so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 9caaf81f2c..d0c0059757 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
)
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias
@@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
+ def get_room_alias_creator(self, room_alias):
+ return self._simple_select_one_onecol(
+ table="room_aliases",
+ keyvalues={
+ "room_alias": room_alias,
+ },
+ retcol="creator",
+ desc="get_room_alias_creator",
+ allow_none=True
+ )
+
+ @cached(max_entries=5000)
+ def get_aliases_for_room(self, room_id):
+ return self._simple_select_onecol(
+ "room_aliases",
+ {"room_id": room_id},
+ "room_alias",
+ desc="get_aliases_for_room",
+ )
+
+
+class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers
@@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
)
defer.returnValue(ret)
- def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
- table="room_aliases",
- keyvalues={
- "room_alias": room_alias,
- },
- retcol="creator",
- desc="get_room_alias_creator",
- allow_none=True
- )
-
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
@@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,13 +169,22 @@ class DirectoryStore(SQLBaseStore):
(room_alias.to_string(),)
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (room_id,)
+ )
+
return room_id
- @cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
- "room_aliases",
- {"room_id": room_id},
- "room_alias",
- desc="get_aliases_for_room",
+ def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+ def _update_aliases_for_room_txn(txn):
+ sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
+ txn.execute(sql, (new_room_id, creator, old_room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (old_room_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (new_room_id,)
+ )
+ return self.runInteraction(
+ "_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f92..ff8538ddf8 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,8 +14,10 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.util.caches.descriptors import cached
+
from canonicaljson import encode_canonical_json
-import ujson as json
+import simplejson as json
from ._base import SQLBaseStore
@@ -120,26 +122,77 @@ class EndToEndKeyStore(SQLBaseStore):
return result
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+ @defer.inlineCallbacks
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
+ """
+
+ rows = yield self._simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json",),
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ desc="add_e2e_one_time_keys_check",
+ )
+
+ defer.returnValue({
+ (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to add keys for
+ device_id(str): id of device to add keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
+
def _add_e2e_one_time_keys(txn):
- for (algorithm, key_id, json_bytes) in key_list:
- self._simple_upsert_txn(
- txn, table="e2e_one_time_keys_json",
- keyvalues={
+ # We are protected from a race between lookup and insertion by a
+ # unique constraint: if two calls to `add_e2e_one_time_keys` race,
+ # they'll conflict and we will only insert one set.
+ self._simple_insert_many_txn(
+ txn, table="e2e_one_time_keys_json",
+ values=[
+ {
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
- },
- values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
- )
- return self.runInteraction(
- "add_e2e_one_time_keys", _add_e2e_one_time_keys
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
+ yield self.runInteraction(
+ "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
+ @cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device
Returns:
@@ -153,7 +206,7 @@ class EndToEndKeyStore(SQLBaseStore):
)
txn.execute(sql, (user_id, device_id))
result = {}
- for algorithm, key_count in txn.fetchall():
+ for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.runInteraction(
@@ -174,7 +227,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn.fetchall():
+ for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
@@ -184,20 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
- yield self._simple_delete(
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_device_keys_by_device"
- )
- yield self._simple_delete(
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_one_time_keys_by_device"
+ def delete_e2e_keys_by_device_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
+ return self.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
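Since add_e2e_one_time_keys now raises on duplicate keys instead of upserting, callers are expected to look up existing keys first. A rough sketch of that flow, assuming a `store` exposing the two methods above and `keys` mapping "algorithm:key_id" to key JSON; the wrapper itself is hypothetical, not the real upload handler:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def upload_one_time_keys(store, user_id, device_id, time_now, keys):
        # Parse "algorithm:key_id" -> (algorithm, key_id), skip any key the
        # server already has, then insert only the genuinely new ones.
        parsed = {tuple(k.split(":", 1)): v for k, v in keys.items()}
        existing = yield store.get_e2e_one_time_keys(
            user_id, device_id, [key_id for _, key_id in parsed]
        )
        new_keys = [
            (alg, key_id, json_bytes)
            for (alg, key_id), json_bytes in parsed.items()
            if (alg, key_id) not in existing
        ]
        if new_keys:
            yield store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)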
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..8c868ece75 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,6 +18,7 @@ from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
import importlib
+import platform
SUPPORTED_MODULE = {
@@ -31,6 +32,10 @@ def create_engine(database_config):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
+ # pypy requires psycopg2cffi rather than psycopg2
+ if (name == "psycopg2" and
+ platform.python_implementation() == "PyPy"):
+ name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ txn.execute("SELECT nextval('state_group_id_seq')")
+ return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database
import struct
+import threading
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ # The current max state_group, or None if we haven't looked
+ # in the DB yet.
+ self._current_state_group_id = None
+ self._current_state_group_id_lock = threading.Lock()
+
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ # We do application locking here since, if we're using sqlite, we
+ # are a single-process synapse.
+ with self._current_state_group_id_lock:
+ if self._current_state_group_id is None:
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ self._current_state_group_id = txn.fetchone()[0]
+
+ self._current_state_group_id += 1
+ return self._current_state_group_id
+
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 256e50dc20..8fbf7ffba7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,50 +12,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from twisted.internet import defer
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+
from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64
import logging
-from Queue import PriorityQueue, Empty
+from six.moves.queue import PriorityQueue, Empty
+
+from six.moves import range
logger = logging.getLogger(__name__)
-class EventFederationStore(SQLBaseStore):
- """ Responsible for storing and serving up the various graphs associated
- with an event. Including the main event graph and the auth chains for an
- event.
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+ SQLBaseStore):
+ def get_auth_chain(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
- Also has methods for getting the front (latest) and back (oldest) edges
- of the event graphs. These are used to generate the parents for new events
- and backfilling from another server respectively.
- """
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def __init__(self, hs):
- super(EventFederationStore, self).__init__(hs)
+ Returns:
+ list of events
+ """
+ return self.get_auth_chain_ids(
+ event_ids, include_given=include_given,
+ ).addCallback(self._get_events)
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
+ def get_auth_chain_ids(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
- def get_auth_chain(self, event_ids):
- return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def get_auth_chain_ids(self, event_ids):
+ Returns:
+ list of event_ids
+ """
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_ids
+ event_ids, include_given
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
- results = set()
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ if include_given:
+ results = set(event_ids)
+ else:
+ results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -67,14 +81,14 @@ class EventFederationStore(SQLBaseStore):
front_list = list(front)
chunks = [
front_list[x:x + 100]
- for x in xrange(0, len(front), 100)
+ for x in range(0, len(front), 100)
]
for chunk in chunks:
txn.execute(
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
- new_front.update([r[0] for r in txn.fetchall()])
+ new_front.update([r[0] for r in txn])
new_front -= results
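The chunked queries above implement a breadth-first walk of the auth graph. A small in-memory model of that walk, where `get_auth_parents` is a hypothetical stand-in for the chunked SELECTs against event_auth:

    def auth_chain_ids(event_ids, get_auth_parents, include_given=False):
        # get_auth_parents(ids) -> set of auth_ids referenced by any of `ids`.
        results = set(event_ids) if include_given else set()
        front = set(event_ids)
        while front:
            new_front = get_auth_parents(front) - results
            results.update(new_front)
            front = new_front
        return results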
@@ -110,7 +124,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,))
- return dict(txn.fetchall())
+ return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@@ -122,7 +136,47 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id",
)
+ @defer.inlineCallbacks
+ def get_prev_events_for_room(self, room_id):
+ """
+ Gets a subset of the current forward extremities in the given room.
+
+ Limits the result to 10 extremities, so that we can avoid creating
+ events which refer to hundreds of prev_events.
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+ res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
+ if len(res) > 10:
+ # Sort by reverse depth, so we point to the most recent.
+ res.sort(key=lambda a: -a[2])
+
+ # we use half of the limit for the actual most recent events, and
+ # the other half to randomly point to some of the older events, to
+ # make sure that we don't completely ignore the older events.
+ res = res[0:5] + random.sample(res[5:], 5)
+
+ defer.returnValue(res)
+
def get_latest_event_ids_and_hashes_in_room(self, room_id):
+ """
+ Gets the current forward extremities in the given room
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+
return self.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
@@ -171,22 +225,6 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_max_depth_of_events(self, event_ids):
- sql = (
- "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
- ) % (",".join(["?"] * len(event_ids)),)
-
- rows = yield self._execute(
- "get_max_depth_of_events", None,
- sql, *event_ids
- )
-
- if rows:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(1)
-
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,
@@ -198,88 +236,6 @@ class EventFederationStore(SQLBaseStore):
return int(min_depth) if min_depth is not None else None
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
- min_depth = self._get_min_depth_interaction(txn, room_id)
-
- do_insert = depth < min_depth if min_depth else True
-
- if do_insert:
- self._simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
- )
-
- def _handle_mult_prev_events(self, txn, events):
- """
- For the given event, update the event edges table and forward and
- backward extremities tables.
- """
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
- for ev in events
- for e_id, _ in ev.prev_events
- ],
- )
-
- self._update_backward_extremeties(txn, events)
-
- def _update_backward_extremeties(self, txn, events):
- """Updates the event_backward_extremities tables based on the new/updated
- events being persisted.
-
- This is called for new events *and* for events that were outliers, but
- are now being persisted as non-outliers.
-
- Forward extremities are handled when we first start persisting the events.
- """
- events_by_room = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
-
- txn.executemany(query, [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events for e_id, _ in ev.prev_events
- if not ev.internal_metadata.is_outlier()
- ])
-
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.executemany(
- query,
- [
- (ev.event_id, ev.room_id) for ev in events
- if not ev.internal_metadata.is_outlier()
- ]
- )
-
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -334,36 +290,13 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
- rows = txn.fetchall()
- return [event_id for event_id, in rows]
+ return [event_id for event_id, in txn]
return self.runInteraction(
"get_forward_extremeties_for_room",
get_forward_extremeties_for_room_txn
)
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = ("""
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """)
- txn.execute(
- sql,
- (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
- )
- return self.runInteraction(
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn
- )
-
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -436,7 +369,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for row in txn.fetchall():
+ for row in txn:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
@@ -482,7 +415,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for e_id, in txn.fetchall():
+ for e_id, in txn:
new_front.add(e_id)
new_front -= earliest_events
@@ -493,6 +426,135 @@ class EventFederationStore(SQLBaseStore):
return event_results
+
+class EventFederationStore(EventFederationWorkerStore):
+ """ Responsible for storing and serving up the various graphs associated
+ with an event. Including the main event graph and the auth chains for an
+ event.
+
+ Also has methods for getting the front (latest) and back (oldest) edges
+ of the event graphs. These are used to generate the parents for new events
+ and backfilling from another server respectively.
+ """
+
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+ def __init__(self, db_conn, hs):
+ super(EventFederationStore, self).__init__(db_conn, hs)
+
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ if min_depth and depth >= min_depth:
+ return
+
+ self._simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={
+ "room_id": room_id,
+ },
+ values={
+ "min_depth": depth,
+ },
+ )
+
+ def _handle_mult_prev_events(self, txn, events):
+ """
+ For the given event, update the event edges table and forward and
+ backward extremities tables.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": ev.event_id,
+ "prev_event_id": e_id,
+ "room_id": ev.room_id,
+ "is_state": False,
+ }
+ for ev in events
+ for e_id, _ in ev.prev_events
+ ],
+ )
+
+ self._update_backward_extremeties(txn, events)
+
+ def _update_backward_extremeties(self, txn, events):
+ """Updates the event_backward_extremities tables based on the new/updated
+ events being persisted.
+
+ This is called for new events *and* for events that were outliers, but
+ are now being persisted as non-outliers.
+
+ Forward extremities are handled when we first start persisting the events.
+ """
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
+
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
+
+ txn.executemany(query, [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ])
+
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [
+ (ev.event_id, ev.room_id) for ev in events
+ if not ev.internal_metadata.is_outlier()
+ ]
+ )
+
+ def _delete_old_forward_extrem_cache(self):
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = ("""
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """)
+ txn.execute(
+ sql,
+ (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+ )
+ return self.runInteraction(
+ "_delete_old_forward_extrem_cache",
+ _delete_old_forward_extrem_cache_txn
+ )
+
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
@@ -505,3 +567,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+ @defer.inlineCallbacks
+ def _background_delete_non_state_event_auth(self, progress, batch_size):
+ def delete_event_auth(txn):
+ target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+ max_stream_id = progress.get("max_stream_id_exclusive")
+
+ if not target_min_stream_id or not max_stream_id:
+ txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ target_min_stream_id = rows[0][0]
+
+ txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ max_stream_id = rows[0][0]
+
+ min_stream_id = max_stream_id - batch_size
+
+ sql = """
+ DELETE FROM event_auth
+ WHERE event_id IN (
+ SELECT event_id FROM events
+ LEFT JOIN state_events USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND state_key IS null
+ )
+ """
+
+ txn.execute(sql, (min_stream_id, max_stream_id,))
+
+ new_progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+ )
+
+ return min_stream_id >= target_min_stream_id
+
+ result = yield self.runInteraction(
+ self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+ defer.returnValue(batch_size)
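The bookkeeping done by one pass of this background update can be modelled without the database: each call deletes the window [max - batch_size, max), records the new upper bound as its progress, and reports whether another pass is needed. A toy model under those assumptions (the helper is illustrative, not part of the background-update API):

    def delete_event_auth_step(progress, batch_size, events_min, events_max):
        # Mirrors the window/progress logic of
        # _background_delete_non_state_event_auth, without any SQL.
        target_min = progress.get("target_min_stream_id_inclusive", events_min)
        max_stream_id = progress.get("max_stream_id_exclusive", events_max)
        min_stream_id = max_stream_id - batch_size

        window = (min_stream_id, max_stream_id)  # rows purged this pass
        new_progress = {
            "target_min_stream_id_inclusive": target_min,
            "max_stream_id_exclusive": min_stream_id,
        }
        more_to_do = min_stream_id >= target_min
        return window, new_progress, more_to_do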
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 522d0114cb..c22762eb5c 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,131 +14,159 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
+from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import RoomStreamToken
from .stream import lower_bound
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
-class EventPushActionsStore(SQLBaseStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
+DEFAULT_HIGHLIGHT_ACTION = [
+ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+]
- def __init__(self, hs):
- self.stream_ordering_month_ago = None
- super(EventPushActionsStore, self).__init__(hs)
- self.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
- )
+def _serialize_action(actions, is_highlight):
+ """Custom serializer for actions. This allows us to "compress" common actions.
- self.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1"
- )
+ We use the fact that most users have the same actions for notifs (and for
+ highlights).
+ We store these default actions as the empty string rather than the full JSON.
+ Since the empty string isn't valid JSON, there is no risk of this clashing with
+ any real JSON actions.
+ """
+ if is_highlight:
+ if actions == DEFAULT_HIGHLIGHT_ACTION:
+ return "" # We use empty string as the column is non-NULL
+ else:
+ if actions == DEFAULT_NOTIF_ACTION:
+ return ""
+ return json.dumps(actions)
- def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
- """
- Args:
- event: the event set actions for
- tuples: list of tuples of (user_id, actions)
- """
- values = []
- for uid, actions in tuples:
- values.append({
- 'room_id': event.room_id,
- 'event_id': event.event_id,
- 'user_id': uid,
- 'actions': json.dumps(actions),
- 'stream_ordering': event.internal_metadata.stream_ordering,
- 'topological_ordering': event.depth,
- 'notif': 1,
- 'highlight': 1 if _action_has_highlight(actions) else 0,
- })
-
- for uid, __ in tuples:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid)
- )
- self._simple_insert_many_txn(txn, "event_push_actions", values)
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
- def _get_unread_event_push_actions_by_room(txn):
- sql = (
- "SELECT stream_ordering, topological_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(
- sql, (room_id, last_read_event_id)
- )
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
-
- stream_ordering = results[0][0]
- topological_ordering = results[0][1]
- token = RoomStreamToken(
- topological_ordering, stream_ordering
- )
+def _deserialize_action(actions, is_highlight):
+ """Custom deserializer for actions. This allows us to "compress" common actions
+ """
+ if actions:
+ return json.loads(actions)
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ if is_highlight:
+ return DEFAULT_HIGHLIGHT_ACTION
+ else:
+ return DEFAULT_NOTIF_ACTION
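A quick round-trip check of the compression scheme, assuming the two functions and the default-action constants above are in scope:

    # Default actions collapse to the empty string; anything else is real JSON.
    assert _serialize_action(DEFAULT_NOTIF_ACTION, is_highlight=False) == ""
    assert _serialize_action(DEFAULT_HIGHLIGHT_ACTION, is_highlight=True) == ""
    assert _deserialize_action("", is_highlight=True) == DEFAULT_HIGHLIGHT_ACTION

    custom = ["notify", {"set_tweak": "sound", "value": "ping"}]
    assert _deserialize_action(_serialize_action(custom, False), False) == custom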
- txn.execute(sql, (user_id, room_id))
- row = txn.fetchone()
- notify_count = row[0] if row else 0
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+class EventPushActionsWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
- txn.execute(sql, (user_id, room_id))
- row = txn.fetchone()
- highlight_count = row[0] if row else 0
+ # These get correctly set by _find_stream_orderings_for_times_txn
+ self.stream_ordering_month_ago = None
+ self.stream_ordering_day_ago = None
+
+ cur = LoggingTransaction(
+ db_conn.cursor(),
+ name="_find_stream_orderings_for_times_txn",
+ database_engine=self.database_engine,
+ after_callbacks=[],
+ exception_callbacks=[],
+ )
+ self._find_stream_orderings_for_times_txn(cur)
+ cur.close()
- return {
- "notify_count": notify_count,
- "highlight_count": highlight_count,
- }
+ self.find_stream_orderings_looping_call = self._clock.looping_call(
+ self._find_stream_orderings_for_times, 10 * 60 * 1000
+ )
+ @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
+ def get_unread_event_push_actions_by_room_for_user(
+ self, room_id, user_id, last_read_event_id
+ ):
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
- _get_unread_event_push_actions_by_room
+ self._get_unread_counts_by_receipt_txn,
+ room_id, user_id, last_read_event_id
)
defer.returnValue(ret)
+ def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
+ last_read_event_id):
+ sql = (
+ "SELECT stream_ordering, topological_ordering"
+ " FROM events"
+ " WHERE room_id = ? AND event_id = ?"
+ )
+ txn.execute(
+ sql, (room_id, last_read_event_id)
+ )
+ results = txn.fetchall()
+ if len(results) == 0:
+ return {"notify_count": 0, "highlight_count": 0}
+
+ stream_ordering = results[0][0]
+ topological_ordering = results[0][1]
+
+ return self._get_unread_counts_by_pos_txn(
+ txn, room_id, user_id, topological_ordering, stream_ordering
+ )
+
+ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
+ stream_ordering):
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
+ )
+
+ # First get number of notifications.
+ # We don't need to put a notif=1 clause as all rows always have
+ # notif=1
+ sql = (
+ "SELECT count(*)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " user_id = ?"
+ " AND room_id = ?"
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
+ row = txn.fetchone()
+ notify_count = row[0] if row else 0
+
+ txn.execute("""
+ SELECT notif_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """, (room_id, user_id, stream_ordering,))
+ rows = txn.fetchall()
+ if rows:
+ notify_count += rows[0][0]
+
+ # Now get the number of highlights
+ sql = (
+ "SELECT count(*)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " highlight = 1"
+ " AND user_id = ?"
+ " AND room_id = ?"
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
+ row = txn.fetchone()
+ highlight_count = row[0] if row else 0
+
+ return {
+ "notify_count": notify_count,
+ "highlight_count": highlight_count,
+ }
+
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
def f(txn):
@@ -146,7 +175,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
@@ -176,7 +205,8 @@ class EventPushActionsStore(SQLBaseStore):
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " ep.highlight "
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
@@ -217,7 +247,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight "
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
@@ -246,7 +276,7 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
- "actions": json.loads(row[3]),
+ "actions": _deserialize_action(row[3], row[4]),
} for row in after_read_receipt + no_read_receipt
]
@@ -285,7 +315,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight, e.received_ts"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
@@ -327,7 +357,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight, e.received_ts"
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
@@ -357,8 +387,8 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
- "actions": json.loads(row[3]),
- "received_ts": row[4],
+ "actions": _deserialize_action(row[3], row[4]),
+ "received_ts": row[5],
} for row in after_read_receipt + no_read_receipt
]
@@ -371,6 +401,290 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit`
defer.returnValue(notifs[:limit])
+ def add_push_actions_to_staging(self, event_id, user_id_actions):
+ """Add the push actions for the event to the push action staging area.
+
+ Args:
+ event_id (str)
+ user_id_actions (dict[str, list[dict|str]]): A dictionary mapping
+ user_id to list of push actions, where an action can either be
+ a string or dict.
+
+ Returns:
+ Deferred
+ """
+
+ if not user_id_actions:
+ return
+
+ # This is a helper function for generating the necessary tuple that
+ # can be used to insert into the `event_push_actions_staging` table.
+ def _gen_entry(user_id, actions):
+ is_highlight = 1 if _action_has_highlight(actions) else 0
+ return (
+ event_id, # event_id column
+ user_id, # user_id column
+ _serialize_action(actions, is_highlight), # actions column
+ 1, # notif column
+ is_highlight, # highlight column
+ )
+
+ def _add_push_actions_to_staging_txn(txn):
+ # We don't use _simple_insert_many here to avoid the overhead
+ # of generating lists of dicts.
+
+ sql = """
+ INSERT INTO event_push_actions_staging
+ (event_id, user_id, actions, notif, highlight)
+ VALUES (?, ?, ?, ?, ?)
+ """
+
+ txn.executemany(sql, (
+ _gen_entry(user_id, actions)
+ for user_id, actions in user_id_actions.iteritems()
+ ))
+
+ return self.runInteraction(
+ "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ )
+
+ @defer.inlineCallbacks
+ def remove_push_actions_from_staging(self, event_id):
+ """Called if we failed to persist the event to ensure that stale push
+ actions don't build up in the DB
+
+ Args:
+ event_id (str)
+ """
+
+ try:
+ res = yield self._simple_delete(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event_id,
+ },
+ desc="remove_push_actions_from_staging",
+ )
+ defer.returnValue(res)
+ except Exception:
+ # this method is called from an exception handler, so propagating
+ # another exception here really isn't helpful - there's nothing
+ # the caller can do about it. Just log the exception and move on.
+ logger.exception(
+ "Error removing push actions after event persistence failure",
+ )
+
+ @defer.inlineCallbacks
+ def _find_stream_orderings_for_times(self):
+ yield self.runInteraction(
+ "_find_stream_orderings_for_times",
+ self._find_stream_orderings_for_times_txn
+ )
+
+ def _find_stream_orderings_for_times_txn(self, txn):
+ logger.info("Searching for stream ordering 1 month ago")
+ self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 month ago: it's %d",
+ self.stream_ordering_month_ago
+ )
+ logger.info("Searching for stream ordering 1 day ago")
+ self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 day ago: it's %d",
+ self.stream_ordering_day_ago
+ )
+
+ def find_first_stream_ordering_after_ts(self, ts):
+ """Gets the stream ordering corresponding to a given timestamp.
+
+ Specifically, finds the stream_ordering of the first event that was
+ received on or after the timestamp. This is done by a binary search on
+ the events table; since there is no index on received_ts, it is
+ relatively slow.
+
+ Args:
+ ts (int): timestamp in millis
+
+ Returns:
+ Deferred[int]: stream ordering of the first event received on/after
+ the timestamp
+ """
+ return self.runInteraction(
+ "_find_first_stream_ordering_after_ts_txn",
+ self._find_first_stream_ordering_after_ts_txn,
+ ts,
+ )
+
+ @staticmethod
+ def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ """
+ Find the stream_ordering of the first event that was received on or
+ after a given timestamp. This is relatively slow as there is no index
+ on received_ts but we can then use this to delete push actions before
+ this.
+
+ received_ts must necessarily be in the same order as stream_ordering
+ and stream_ordering is indexed, so we manually binary search using
+ stream_ordering
+
+ Args:
+ txn (twisted.enterprise.adbapi.Transaction):
+ ts (int): timestamp to search for
+
+ Returns:
+ int: stream ordering
+ """
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ max_stream_ordering = txn.fetchone()[0]
+
+ if max_stream_ordering is None:
+ return 0
+
+ # We want the first stream_ordering in which received_ts is greater
+ # than or equal to ts. Call this point X.
+ #
+ # We maintain the invariants:
+ #
+ # range_start <= X <= range_end
+ #
+ range_start = 0
+ range_end = max_stream_ordering + 1
+
+ # Given a stream_ordering, look up the timestamp at that
+ # stream_ordering.
+ #
+ # The array may be sparse (we may be missing some stream_orderings).
+ # We treat a gap as having the same value as the preceding entry,
+ # because we will pick the lowest stream_ordering
+ # which satisfies our requirement of received_ts >= ts.
+ #
+ # For example, if our array of events indexed by stream_ordering is
+ # [10, <none>, 20], we should treat this as being equivalent to
+ # [10, 10, 20].
+ #
+ sql = (
+ "SELECT received_ts FROM events"
+ " WHERE stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT 1"
+ )
+
+ while range_end - range_start > 0:
+ middle = (range_end + range_start) // 2
+ txn.execute(sql, (middle,))
+ row = txn.fetchone()
+ if row is None:
+ # no rows with stream_ordering<=middle
+ range_start = middle + 1
+ continue
+
+ middle_ts = row[0]
+ if ts > middle_ts:
+ # we got a timestamp lower than the one we were looking for.
+ # definitely need to look higher: X > middle.
+ range_start = middle + 1
+ else:
+ # we got a timestamp higher than (or the same as) the one we
+ # were looking for. We aren't yet sure about the point we
+ # looked up, but we can be sure that X <= middle.
+ range_end = middle
+
+ return range_end
+
+
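The search above can be sanity-checked against a plain in-memory model. Here `events` is a hypothetical sorted list of (stream_ordering, received_ts) pairs standing in for the events table; the semantics match the docstring, with gaps inheriting the preceding timestamp:

    def first_stream_ordering_after_ts(events, ts):
        # events: [(stream_ordering, received_ts)] sorted by stream_ordering,
        # with received_ts non-decreasing; stream_orderings may be sparse.
        for stream_ordering, received_ts in events:
            if received_ts >= ts:
                return stream_ordering
        return events[-1][0] + 1 if events else 0

    # With a gap at stream_ordering 2, the gap inherits received_ts 10, so:
    # first_stream_ordering_after_ts([(1, 10), (3, 20)], 15) == 3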
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1"
+ )
+
+ self._doing_notif_rotation = False
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
+
+ def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+ all_events_and_contexts):
+ """Handles moving push actions from staging table to main
+ event_push_actions table for all events in `events_and_contexts`.
+
+ Also ensures that all events in `all_events_and_contexts` are removed
+ from the push action staging area.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc, that wouldn't appear in
+ events_and_contexts.
+ """
+
+ sql = """
+ INSERT INTO event_push_actions (
+ room_id, event_id, user_id, actions, stream_ordering,
+ topological_ordering, notif, highlight
+ )
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ FROM event_push_actions_staging
+ WHERE event_id = ?
+ """
+
+ if events_and_contexts:
+ txn.executemany(sql, (
+ (
+ event.room_id, event.internal_metadata.stream_ordering,
+ event.depth, event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ))
+
+ for event, _ in events_and_contexts:
+ user_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event.event_id,
+ },
+ retcol="user_id",
+ )
+
+ for uid in user_ids:
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid,)
+ )
+
+ # Now we delete the staging area for *all* events that were being
+ # persisted.
+ txn.executemany(
+ "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+ (
+ (event.event_id,)
+ for event, _ in all_events_and_contexts
+ )
+ )
+
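The staging methods are intended to bracket event persistence: actions are staged before the event is persisted, moved into event_push_actions by the transaction above, and cleaned up explicitly if persistence fails. A rough sketch of that lifecycle, assuming a `store` combining this class with EventsStore; the orchestration shown is illustrative, not the actual handler code:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def persist_with_push_actions(store, event, context, actions_by_user):
        # 1. Stage the push actions before the event hits the database.
        yield store.add_push_actions_to_staging(event.event_id, actions_by_user)
        try:
            # 2. Persisting the event moves the staged rows into
            #    event_push_actions and deletes the staging rows.
            yield store.persist_event(event, context)
        except Exception:
            # 3. On failure, drop the staged rows so they don't build up.
            yield store.remove_push_actions_from_staging(event.event_id)
            raise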
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50,
only_highlight=False):
@@ -392,7 +706,7 @@ class EventPushActionsStore(SQLBaseStore):
sql = (
"SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.profile_tag, e.received_ts"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e"
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
@@ -407,7 +721,7 @@ class EventPushActionsStore(SQLBaseStore):
"get_push_actions_for_user", f
)
for pa in push_actions:
- pa["actions"] = json.loads(pa["actions"])
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions)
@defer.inlineCallbacks
@@ -448,7 +762,7 @@ class EventPushActionsStore(SQLBaseStore):
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- topological_ordering):
+ topological_ordering, stream_ordering):
"""
Purges old push actions for a user and room before a given
topological_ordering.
@@ -479,65 +793,140 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
- " topological_ordering < ?"
+ " topological_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
)
+ txn.execute("""
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """, (room_id, user_id, stream_ordering))
+
@defer.inlineCallbacks
- def _find_stream_orderings_for_times(self):
- yield self.runInteraction(
- "_find_stream_orderings_for_times",
- self._find_stream_orderings_for_times_txn
- )
+ def _rotate_notifs(self):
+ if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
+ return
+ self._doing_notif_rotation = True
- def _find_stream_orderings_for_times_txn(self, txn):
- logger.info("Searching for stream ordering 1 month ago")
- self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
- txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
- )
- logger.info(
- "Found stream ordering 1 month ago: it's %d",
- self.stream_ordering_month_ago
+ try:
+ while True:
+ logger.info("Rotating notifications")
+
+ caught_up = yield self.runInteraction(
+ "_rotate_notifs",
+ self._rotate_notifs_txn
+ )
+ if caught_up:
+ break
+ yield sleep(5)
+ finally:
+ self._doing_notif_rotation = False
+
+ def _rotate_notifs_txn(self, txn):
+ """Archives older notifications into event_push_summary. Returns whether
+ the archiving process has caught up or not.
+ """
+
+ old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
)
- def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
- """
- Find the stream_ordering of the first event that was received after
- a given timestamp. This is relatively slow as there is no index on
- received_ts but we can then use this to delete push actions before
- this.
+ # We don't want to try and rotate millions of rows at once, so we cap the
+ # maximum stream ordering we'll rotate before.
+ txn.execute("""
+ SELECT stream_ordering FROM event_push_actions
+ WHERE stream_ordering > ?
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
+ """, (old_rotate_stream_ordering,))
+ stream_row = txn.fetchone()
+ if stream_row:
+ offset_stream_ordering, = stream_row
+ rotate_to_stream_ordering = min(
+ self.stream_ordering_day_ago, offset_stream_ordering
+ )
+ caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+ else:
+ rotate_to_stream_ordering = self.stream_ordering_day_ago
+ caught_up = True
- received_ts must necessarily be in the same order as stream_ordering
- and stream_ordering is indexed, so we manually binary search using
- stream_ordering
+ logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
+
+ self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+
+ # We have caught up iff we were limited by `stream_ordering_day_ago`
+ return caught_up
+
+ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+ old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ # Calculate the new counts that should be upserted into event_push_summary
+ sql = """
+ SELECT user_id, room_id,
+ coalesce(old.notif_count, 0) + upd.notif_count,
+ upd.stream_ordering,
+ old.user_id
+ FROM (
+ SELECT user_id, room_id, count(*) as notif_count,
+ max(stream_ordering) as stream_ordering
+ FROM event_push_actions
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND highlight = 0
+ GROUP BY user_id, room_id
+ ) AS upd
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
- if max_stream_ordering is None:
- return 0
+ txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+ rows = txn.fetchall()
+
+ logger.info("Rotating notifications, handling %d rows", len(rows))
+
+ # If the `old.user_id` above is NULL then we know there isn't already an
+ # entry in the table, so we simply insert it. Otherwise we update the
+ # existing row.
+ self._simple_insert_many_txn(
+ txn,
+ table="event_push_summary",
+ values=[
+ {
+ "user_id": row[0],
+ "room_id": row[1],
+ "notif_count": row[2],
+ "stream_ordering": row[3],
+ }
+ for row in rows if row[4] is None
+ ]
+ )
- range_start = 0
- range_end = max_stream_ordering
+ txn.executemany(
+ """
+ UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ WHERE user_id = ? AND room_id = ?
+ """,
+ ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+ )
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering > ?"
- " ORDER BY stream_ordering"
- " LIMIT 1"
+ txn.execute(
+ "DELETE FROM event_push_actions"
+ " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
+ (old_rotate_stream_ordering, rotate_to_stream_ordering,)
)
- while range_end - range_start > 1:
- middle = int((range_end + range_start) / 2)
- txn.execute(sql, (middle,))
- middle_ts = txn.fetchone()[0]
- if ts > middle_ts:
- range_start = middle
- else:
- range_end = middle
+ logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
- return range_end
+ txn.execute(
+ "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+ (rotate_to_stream_ordering,)
+ )
def _action_has_highlight(actions):
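The net effect of rotation is that old non-highlight actions are folded into per-(user, room) counts in event_push_summary and then deleted, so an unread count becomes the archived count plus the count of newer event_push_actions rows. A toy model of the folding step, with names chosen purely for illustration:

    def rotate_summary(summary, actions, rotate_to_stream_ordering):
        # summary: {(user_id, room_id): notif_count}, standing in for
        # event_push_summary; actions: [(user_id, room_id, stream_ordering)]
        # non-highlight rows standing in for event_push_actions.
        remaining = []
        for user_id, room_id, stream_ordering in actions:
            if stream_ordering < rotate_to_stream_ordering:
                key = (user_id, room_id)
                summary[key] = summary.get(key, 0) + 1
            else:
                remaining.append((user_id, room_id, stream_ordering))
        return summary, remaining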
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c88f689d3a..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,59 +13,61 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+import itertools
+import logging
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+import simplejson as json
+from twisted.internet import defer
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred
+from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+ PreserveLoggingContext, make_deferred_yieldable,
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
-from synapse.util.caches.descriptors import cached
-
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
-
-import synapse
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.types import get_domain_from_id
import synapse.metrics
-
-import logging
-import math
-import ujson as json
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
persist_event_counter = metrics.register_counter("persisted_events")
+event_counter = metrics.register_counter(
+ "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
+)
-
-def encode_json(json_object):
- if USE_FROZEN_DICTS:
- # ujson doesn't like frozen_dicts
- return encode_canonical_json(json_object)
- else:
- return json.dumps(json_object, ensure_ascii=False)
+# The number of times we are recalculating the current state
+state_delta_counter = metrics.register_counter(
+ "state_delta",
+)
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = metrics.register_counter(
+ "state_delta_single_event",
+)
+# The number of times we are recalculating state when we could have reasonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = metrics.register_counter(
+ "state_delta_reuse_delta",
+)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+def encode_json(json_object):
+ return frozendict_json_encoder.encode(json_object)
class _EventPeristenceQueue(object):
@@ -82,15 +85,30 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
+
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
end_item = queue[-1]
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -103,11 +121,11 @@ class _EventPeristenceQueue(object):
def handle_queue(self, room_id, per_item_callback):
"""Attempts to handle the queue for a room if not already being handled.
- The given callback will be invoked with for each item in the queue,1
+ The given callback will be invoked for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously
be called with new items, unless the queue becomnes empty. The return
value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferres as well.
+ exceptions will be passed to the deferreds as well.
This function should therefore be called whenever anything is added
to the queue.
@@ -126,18 +144,25 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
+ # handle_queue_loop runs in the sentinel logcontext, so
+ # there is no need to preserve_fn when running the
+ # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
- except Exception as e:
- item.deferred.errback(e)
+ except Exception:
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+ # set handle_queue_loop off in the background. We don't want to
+ # attribute work done in it to the current request, so we drop the
+ # logcontext altogether.
+ with PreserveLoggingContext():
+ handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
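Continuing the Twisted-free sketch above (purely illustrative, and ignoring logcontexts entirely), draining a room's queue in the spirit of handle_queue looks roughly like this: each batch is handed to the per-item callback, and the result or exception is propagated to everyone waiting on that batch.

from collections import deque

def drain_room(simple_queue, room_id, per_item_callback):
    # simple_queue is the SimpleQueue sketched after add_to_queue above
    queue = simple_queue._queues.pop(room_id, None) or deque()
    while queue:
        item = queue.popleft()
        try:
            item.done.set_result(per_item_callback(item))
        except Exception as exc:
            # mirror the errback path: every waiter sees the failure too
            item.done.set_exception(exc)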
@@ -173,13 +198,12 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
- def __init__(self, hs):
- super(EventsStore, self).__init__(hs)
- self._clock = hs.get_clock()
+ def __init__(self, db_conn, hs):
+ super(EventsStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@@ -196,8 +220,22 @@ class EventsStore(SQLBaseStore):
where_clause="contains_url = true AND outlier = false",
)
+ # an event_id index on event_search is useful for the purge_history
+ # api. Plus it means we get to enforce some integrity with a UNIQUE
+ # clause
+ self.register_background_index_update(
+ "event_search_event_id_idx",
+ index_name="event_search_event_id_idx",
+ table="event_search",
+ columns=["event_id"],
+ unique=True,
+ psql_only=True,
+ )
+
self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
@@ -210,23 +248,34 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in partitioned.items():
- d = preserve_fn(self._event_persist_queue.add_to_queue)(
+ for room_id, evs_ctxs in partitioned.iteritems():
+ d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs,
backfilled=backfilled,
)
deferreds.append(d)
- for room_id in partitioned.keys():
+ for room_id in partitioned:
self._maybe_start_persisting(room_id)
- return preserve_context_over_deferred(
+ return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False):
+ """
+
+ Args:
+ event (EventBase):
+ context (EventContext):
+ backfilled (bool):
+
+ Returns:
+ Deferred: resolves to (int, int): the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
@@ -234,7 +283,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id)
- yield preserve_context_over_deferred(deferred)
+ yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -242,10 +291,11 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
- yield self._persist_events(
- item.events_and_contexts,
- backfilled=item.backfilled,
- )
+ with Measure(self._clock, "persist_events"):
+ yield self._persist_events(
+ item.events_and_contexts,
+ backfilled=item.backfilled,
+ )
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@@ -253,6 +303,16 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False,
delete_existing=False):
+ """Persist events to db
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+ delete_existing (bool):
+
+ Returns:
+ Deferred: resolves when the events have been persisted
+ """
if not events_and_contexts:
return
@@ -282,8 +342,20 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room
# at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events
current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where each entry is
+ # a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -295,7 +367,7 @@ class EventsStore(SQLBaseStore):
(event, context)
)
- for room_id, ev_ctx_rm in events_by_room.items():
+ for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -305,17 +377,64 @@ class EventsStore(SQLBaseStore):
room_id, ev_ctx_rm, latest_event_ids
)
- if new_latest_event_ids == set(latest_event_ids):
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
# No change in extremities, so no change in state
continue
new_forward_extremeties[room_id] = new_latest_event_ids
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_events) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(e for e, _ in ev.prev_events)
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.info(
+ "Calculating state delta for room %s", room_id,
)
- if state:
- current_state_for_room[room_id] = state
+ current_state = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ if delta is not None:
+ state_delta_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -323,10 +442,35 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
- current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
+ for event, context in chunk:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
+
+ event_counter.inc(event.type, origin_type, origin_entity)
+
+ for room_id, new_state in current_state_for_room.iteritems():
+ self.get_current_state_ids.prefill(
+ (room_id, ), new_state
+ )
+
+ for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
@defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
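The per-event labelling above reduces to a small classification rule. A hypothetical standalone version (the real code goes via context.app_service, hs.is_mine_id and get_domain_from_id) might look like this:

def classify_origin(sender, local_domain, app_service_id=None):
    if app_service_id is not None:
        return "local", app_service_id      # sent via an application service
    domain = sender.split(":", 1)[1]        # "@user:domain" -> "domain"
    if domain == local_domain:
        return "local", "*client*"          # one of our own users
    return "remote", domain                 # attributed to the origin server

assert classify_origin("@bob:example.com", "example.com") == ("local", "*client*")
assert classify_origin("@ann:other.org", "example.com") == ("remote", "other.org")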
@@ -370,71 +514,137 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 2-tuple (to_delete, to_insert) where both are state dicts, i.e.
- (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert.
- May return None if there are no changes to be applied.
+ Deferred[dict[(str,str), str]|None]:
+ None if there are no changes to the room state, or
+ a dict of (type, state_key) -> event_id.
"""
- # Now we need to work out the different state sets for
- # each state extremities
- state_sets = []
- missing_event_ids = []
- was_updated = False
+
+ if not new_latest_event_ids:
+ return
+
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
- state_sets.append(ctx.current_state_ids)
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
- groups = set(event_to_groups.values())
- group_to_state = yield self._get_state_for_groups(groups)
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
- state_sets.extend(group_to_state.values())
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- if not new_latest_event_ids:
- current_state = {}
- elif was_updated:
- current_state = yield resolve_events(
- state_sets,
- state_map_factory=lambda ev_ids: self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
- )
- else:
+ # If the old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
return
- existing_state_rows = yield self._simple_select_list(
- table="current_state_events",
- keyvalues={"room_id": room_id},
- retcols=["event_id", "type", "state_key"],
- desc="_calculate_state_delta",
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ def get_events(ev_ids):
+ return self.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
+ }
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups, events_map, get_events
)
- existing_events = set(row["event_id"] for row in existing_state_rows)
+ defer.returnValue(res.state)
+
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ 2-tuple (to_delete, to_insert) where both are state dicts,
+ i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+ first be deleted from current_state_events, `to_insert` are entries
+ to insert.
+ """
+ existing_state = yield self.get_current_state_ids(room_id)
+
+ existing_events = set(existing_state.itervalues())
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
@@ -442,9 +652,8 @@ class EventsStore(SQLBaseStore):
return
to_delete = {
- (row["type"], row["state_key"]): row["event_id"]
- for row in existing_state_rows
- if row["event_id"] in changed_events
+ key: ev_id for key, ev_id in existing_state.iteritems()
+ if ev_id in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
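A toy walk-through of the delta computation in _calculate_state_delta, with invented data: entries whose event_id changed are queued for deletion from current_state_events, and entries whose event_id is new are queued for insertion.

existing_state = {
    ("m.room.member", "@alice:hs"): "$old_member",
    ("m.room.topic", ""): "$topic",
}
current_state = {
    ("m.room.member", "@alice:hs"): "$new_member",
    ("m.room.topic", ""): "$topic",
}

existing_events = set(existing_state.values())
new_events = set(current_state.values())
changed_events = existing_events ^ new_events

to_delete = {k: v for k, v in existing_state.items() if v in changed_events}
to_insert = {k: v for k, v in current_state.items()
             if v in new_events - existing_events}

assert to_delete == {("m.room.member", "@alice:hs"): "$old_member"}
assert to_insert == {("m.room.member", "@alice:hs"): "$new_member"}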
@@ -454,77 +663,104 @@ class EventsStore(SQLBaseStore):
defer.returnValue((to_delete, to_insert))
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
+ @log_function
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ delete_existing=False, state_delta_for_room={},
+ new_forward_extremeties={}):
+ """Insert some number of room events into the necessary database tables.
+
+ Rejected events are only inserted into the events table, the events_json table,
+ and the rejections table. Things reading from those tables will need to check
+ whether the event was rejected.
Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]):
+ events to persist
+ backfilled (bool): True if the events were backfilled
+ delete_existing (bool): True to purge existing table rows for the
+ events from the database. This is useful when retrying due to
+ IntegrityError.
+ state_delta_for_room (dict[str, (list[str], list[str])]):
+ The current-state delta for each room. For each room, a tuple
+ (to_delete, to_insert), being a list of event ids to be removed
+ from the current state, and a list of event ids to be added to
+ the current state.
+ new_forward_extremeties (dict[str, list[str]]):
+ The new forward extremities for each room. For each room, a
+ list of the event ids which are the forward extremities.
- Returns:
- Deferred : A FrozenEvent.
"""
- events = yield self._get_events(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
+ all_events_and_contexts = events_and_contexts
+
+ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+ self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
+
+ self._update_forward_extremities_txn(
+ txn,
+ new_forward_extremities=new_forward_extremeties,
+ max_stream_order=max_stream_order,
)
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
+ # Ensure that we don't have the same event twice.
+ events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+ events_and_contexts,
+ )
- defer.returnValue(events[0] if events else None)
+ self._update_room_depths_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ )
- @defer.inlineCallbacks
- def get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- """Get events from the database
+ # _update_outliers_txn filters out any events which have already been
+ # persisted, and returns the filtered list.
+ events_and_contexts = self._update_outliers_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ # From this point onwards the events are only events that we haven't
+ # seen before.
- Returns:
- Deferred : Dict from event_id to event.
- """
- events = yield self._get_events(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+ self._delete_existing_rows_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ self._store_event_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
)
- defer.returnValue({e.event_id: e for e in events})
+ # Insert into event_to_state_groups.
+ self._store_event_state_mappings_txn(txn, events_and_contexts)
- @log_function
- def _persist_events_txn(self, txn, events_and_contexts, backfilled,
- delete_existing=False, current_state_for_room={},
- new_forward_extremeties={}):
- """Insert some number of room events into the necessary database tables.
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
- Rejected events are only inserted into the events table, the events_json table,
- and the rejections table. Things reading from those table will need to check
- whether the event was rejected.
+ # From this point onwards the events are only ones that weren't
+ # rejected.
- If delete_existing is True then existing events will be purged from the
- database before insertion. This is useful when retrying due to IntegrityError.
- """
- max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- for room_id, current_state_tuple in current_state_for_room.iteritems():
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
+ for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
@@ -545,6 +781,29 @@ class EventsStore(SQLBaseStore):
],
)
+ state_deltas = {key: None for key in to_delete}
+ state_deltas.update(to_insert)
+
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_delta_stream",
+ values=[
+ {
+ "stream_id": max_stream_order,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": ev_id,
+ "prev_event_id": to_delete.get(key, None),
+ }
+ for key, ev_id in state_deltas.iteritems()
+ ]
+ )
+
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ room_id, max_stream_order,
+ )
+
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
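The rows written to current_state_delta_stream above combine both halves of the delta. A toy illustration with invented ids: deleted keys map to None (the entry is going away), inserted or replaced keys map to the new event_id, and prev_event_id records what previously occupied that slot.

to_delete = {("m.room.name", ""): "$old_name"}
to_insert = {("m.room.name", ""): "$new_name", ("m.room.topic", ""): "$topic"}

state_deltas = {key: None for key in to_delete}
state_deltas.update(to_insert)

rows = [
    {"type": k[0], "state_key": k[1], "event_id": ev_id,
     "prev_event_id": to_delete.get(k)}
    for k, ev_id in state_deltas.items()
]
assert {r["event_id"] for r in rows} == {"$new_name", "$topic"}
assert {r["prev_event_id"] for r in rows} == {"$old_name", None}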
@@ -553,24 +812,34 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in to_delete.iterkeys()
- if ev_type == EventTypes.Member
- )
- members_changed.update(
- state_key for ev_type, state_key in to_insert.iterkeys()
+ state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
- txn, self.get_rooms_for_user, (member,)
+ txn, self.get_rooms_for_user_with_stream_ordering, (member,)
+ )
+
+ for host in set(get_domain_from_id(u) for u in members_changed):
+ self._invalidate_cache_and_stream(
+ txn, self.is_host_joined, (room_id, host)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.was_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
- for room_id, new_extrem in new_forward_extremeties.items():
+ self._invalidate_cache_and_stream(
+ txn, self.get_current_state_ids, (room_id,)
+ )
+
+ def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+ max_stream_order):
+ for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@@ -588,7 +857,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id,
"room_id": room_id,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem
],
)
@@ -605,13 +874,22 @@ class EventsStore(SQLBaseStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem
]
)
- # Ensure that we don't have the same event twice.
- # Pick the earliest non-outlier if there is one, else the earliest one.
+ @classmethod
+ def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ """Ensure that we don't have the same event twice.
+
+ Pick the earliest non-outlier if there is one, else the earliest one.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ Returns:
+ list[(EventBase, EventContext)]: filtered list
+ """
new_events_and_contexts = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -624,9 +902,17 @@ class EventsStore(SQLBaseStore):
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
+ return new_events_and_contexts.values()
- events_and_contexts = new_events_and_contexts.values()
+ def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ """Update min_depth for each room
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ backfilled (bool): True if the events were backfilled
+ """
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
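The selection rule documented in _filter_events_and_contexts_for_duplicates above can be sketched standalone (simplified: it replaces in place rather than re-inserting, and uses plain dicts instead of EventBase objects): keep one entry per event_id, preferring a non-outlier copy over an outlier one, otherwise keeping the earliest copy seen.

from collections import OrderedDict

def dedupe(events_and_contexts):
    out = OrderedDict()
    for event, context in events_and_contexts:
        prev = out.get(event["event_id"])
        if prev is None:
            out[event["event_id"]] = (event, context)
        elif prev[0]["outlier"] and not event["outlier"]:
            # a non-outlier copy beats the outlier we saw earlier
            out[event["event_id"]] = (event, context)
    return list(out.values())

evs = [
    ({"event_id": "$1", "outlier": True}, "ctx_a"),
    ({"event_id": "$1", "outlier": False}, "ctx_b"),
]
assert dedupe(evs) == [({"event_id": "$1", "outlier": False}, "ctx_b")]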
@@ -642,9 +928,24 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in depth_updates.items():
+ for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth)
+ def _update_outliers_txn(self, txn, events_and_contexts):
+ """Update any outliers with new event info.
+
+ This turns outliers into ex-outliers (unless the new event was
+ rejected).
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)]: new list, without events which
+ are already in the events table.
+ """
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
@@ -654,34 +955,30 @@ class EventsStore(SQLBaseStore):
have_persisted = {
event_id: outlier
- for event_id, outlier in txn.fetchall()
+ for event_id, outlier in txn
}
to_remove = set()
for event, context in events_and_contexts:
- if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
- if event.event_id in have_persisted:
- # If we have already seen the event then ignore it.
- to_remove.add(event)
- continue
-
if event.event_id not in have_persisted:
continue
to_remove.add(event)
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ continue
+
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
- # insert into the state_group, state_groups_state and
- # event_to_state_groups tables.
+ # insert into event_to_state_groups.
try:
- self._store_mult_state_groups_txn(txn, ((event, context),))
+ self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception:
logger.exception("")
raise
@@ -726,37 +1023,19 @@ class EventsStore(SQLBaseStore):
# event isn't an outlier any more.
self._update_backward_extremeties(txn, [event])
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ @classmethod
+ def _delete_existing_rows_txn(cls, txn, events_and_contexts):
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only events that we haven't
- # seen before.
-
- def event_dict(event):
- return {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
-
- logger.info("Deleting existing")
+ logger.info("Deleting existing")
- for table in (
+ for table in (
"events",
"event_auth",
"event_json",
@@ -779,11 +1058,30 @@ class EventsStore(SQLBaseStore):
"redactions",
"room_memberships",
"topics"
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts]
- )
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
+ def _store_event_txn(self, txn, events_and_contexts):
+ """Insert new events into the event and event_json tables
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ """
+
+ if not events_and_contexts:
+ # nothing to do here
+ return
+
+ def event_dict(event):
+ d = event.get_dict()
+ d.pop("redacted", None)
+ d.pop("redacted_because", None)
+ return d
self._simple_insert_many_txn(
txn,
@@ -827,6 +1125,19 @@ class EventsStore(SQLBaseStore):
],
)
+ def _store_rejected_events_txn(self, txn, events_and_contexts):
+ """Add rows to the 'rejections' table for received events which were
+ rejected
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)]: new list, without the rejected
+ events.
+ """
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
@@ -838,24 +1149,37 @@ class EventsStore(SQLBaseStore):
)
to_remove.add(event)
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ def _update_metadata_tables_txn(self, txn, events_and_contexts,
+ all_events_and_contexts, backfilled):
+ """Update all the miscellaneous tables for new events
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc., that wouldn't appear in
+ events_and_contexts.
+ backfilled (bool): True if the events were backfilled
+ """
+
+ # Insert all the push actions into the event_push_actions table.
+ self._set_push_actions_for_event_and_users_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ )
+
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only ones that weren't rejected.
-
for event, context in events_and_contexts:
- # Insert all the push actions into the event_push_actions table.
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
@@ -874,13 +1198,10 @@ class EventsStore(SQLBaseStore):
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
+ if event.is_state()
],
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -967,13 +1288,6 @@ class EventsStore(SQLBaseStore):
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- if backfilled:
- # Backfilled events come before the current state so we don't need
- # to update the current state table
- return
-
- return
-
def _add_to_cache(self, txn, events_and_contexts):
to_prefill = []
@@ -1037,13 +1351,49 @@ class EventsStore(SQLBaseStore):
defer.returnValue(set(r["event_id"] for r in rows))
- def have_events(self, event_ids):
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
+ Args:
+ event_ids (iterable[str]):
+
Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+ Deferred[dict[str, str|None]]:
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
"""
if not event_ids:
return defer.succeed({})
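The chunking idiom used by have_seen_events above is worth seeing on its own: the two-argument form of iter() keeps pulling fixed-size slices from the iterator until an empty slice comes back.

import itertools

def chunks_of(iterable, n=100):
    it = iter(iterable)
    for chunk in iter(lambda: list(itertools.islice(it, n)), []):
        yield chunk

assert list(chunks_of(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]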
@@ -1065,280 +1415,7 @@ class EventsStore(SQLBaseStore):
return res
- return self.runInteraction(
- "have_events", f,
- )
-
- @defer.inlineCallbacks
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids,
- allow_rejected=allow_rejected,
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- allow_rejected=allow_rejected,
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected):
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get((event_id,), None)
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- event_list = []
- i = 0
- while True:
- try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
- event_ids = [
- item for sublist in event_id_lists for item in sublist
- ]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
- )
-
- row_dict = {
- r["event_id"]: r
- for r in rows
- }
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
- except:
- logger.exception("Failed to callback")
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(e)
-
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append(
- (events, events_d)
- )
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self._get_event_from_row)(
- row["internal_metadata"], row["json"], row["redacts"],
- rejected_reason=row["rejects"],
- )
- for row in rows
- ],
- consumeErrors=True
- ))
-
- defer.returnValue({
- e.event.event_id: e
- for e in res if e
- })
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) / N):
- evs = events[i * N:(i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- original_ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
+ return self.runInteraction("get_rejection_reasons", f)
@defer.inlineCallbacks
def count_daily_messages(self):
@@ -1349,66 +1426,52 @@ class EventsStore(SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn):
- now = self.hs.get_clock().time()
-
- txn.execute(
- "SELECT reported_stream_token, reported_time FROM stats_reporting"
- )
- last_reported = self.cursor_to_dict(txn)
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ count, = txn.fetchone()
+ return count
- txn.execute(
- "SELECT stream_ordering"
- " FROM events"
- " ORDER BY stream_ordering DESC"
- " LIMIT 1"
- )
- now_reporting = self.cursor_to_dict(txn)
- if not now_reporting:
- logger.info("Calculating daily messages skipped; no now_reporting")
- return None
- now_reporting = now_reporting[0]["stream_ordering"]
-
- txn.execute("DELETE FROM stats_reporting")
- txn.execute(
- "INSERT INTO stats_reporting"
- " (reported_stream_token, reported_time)"
- " VALUES (?, ?)",
- (now_reporting, now,)
- )
-
- if not last_reported:
- logger.info("Calculating daily messages skipped; no last_reported")
- return None
-
- # Close enough to correct for our purposes.
- yesterday = (now - 24 * 60 * 60)
- since_yesterday_seconds = yesterday - last_reported[0]["reported_time"]
- any_since_yesterday = math.fabs(since_yesterday_seconds) > 60 * 60
- if any_since_yesterday:
- logger.info(
- "Calculating daily messages skipped; since_yesterday_seconds: %d" %
- (since_yesterday_seconds,)
- )
- return None
+ ret = yield self.runInteraction("count_messages", _count_messages)
+ defer.returnValue(ret)
- txn.execute(
- "SELECT COUNT(*) as messages"
- " FROM events NATURAL JOIN event_json"
- " WHERE json like '%m.room.message%'"
- " AND stream_ordering > ?"
- " AND stream_ordering <= ?",
- (
- last_reported[0]["reported_stream_token"],
- now_reporting,
- )
- )
- rows = self.cursor_to_dict(txn)
- if not rows:
- logger.info("Calculating daily messages skipped; messages count missing")
- return None
- return rows[0]["messages"]
+ @defer.inlineCallbacks
+ def count_daily_sent_messages(self):
+ def _count_messages(txn):
+ # This is good enough; if you have silly characters in your own
+ # hostname then that's your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago,))
+ count, = txn.fetchone()
+ return count
+
+ ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
+ defer.returnValue(ret)
- ret = yield self.runInteraction("count_messages", _count_messages)
+ @defer.inlineCallbacks
+ def count_daily_active_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ count, = txn.fetchone()
+ return count
+
+ ret = yield self.runInteraction("count_daily_active_rooms", _count)
defer.returnValue(ret)
@defer.inlineCallbacks
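The sender LIKE clause in count_daily_sent_messages above relies on Matrix user ids having the form "@localpart:domain". An illustrative Python equivalent of the pattern (example.com is an invented hostname):

def is_local_sender(sender, hostname):
    # SQL equivalent: sender LIKE '%:' || hostname
    return sender.endswith(":" + hostname)

assert is_local_sender("@alice:example.com", "example.com")
assert not is_local_sender("@bob:other.org", "example.com")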
@@ -1569,6 +1632,94 @@ class EventsStore(SQLBaseStore):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_forward_event_rows(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_id, upper_bound))
+ new_event_updates.extend(txn)
+
+ return new_event_updates
+ return self.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows(txn):
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+ return self.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
@cached(num_args=5, max_entries=10)
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
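Both row-fetchers above share the same pagination rule; restated as a hypothetical helper: if the page came back full, more rows may remain, so any follow-up query should use the last row's ordering as its upper bound rather than current_id.

def effective_upper_bound(rows, current_id, limit):
    if len(rows) == limit:
        # first column is the (possibly negated) stream ordering
        return rows[-1][0]
    return current_id

assert effective_upper_bound([(1,), (2,), (3,)], current_id=10, limit=3) == 3
assert effective_upper_bound([(1,), (2,)], current_id=10, limit=3) == 10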
@@ -1582,14 +1733,13 @@ class EventsStore(SQLBaseStore):
def get_all_new_events_txn(txn):
sql = (
- "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
@@ -1615,15 +1765,13 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = []
sql = (
- "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
- " eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
- " ORDER BY e.stream_ordering DESC"
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
@@ -1654,16 +1802,32 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
- def delete_old_state(self, room_id, topological_ordering):
- return self.runInteraction(
- "delete_old_state",
- self._delete_old_state_txn, room_id, topological_ordering
- )
+ def purge_history(
+ self, room_id, topological_ordering, delete_local_events,
+ ):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
- def _delete_old_state_txn(self, txn, room_id, topological_ordering):
- """Deletes old room state
+ topological_ordering (int):
+ minimum topo ordering to preserve
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
"""
+ return self.runInteraction(
+ "purge_history",
+ self._purge_history_txn, room_id, topological_ordering,
+ delete_local_events,
+ )
+
+ def _purge_history_txn(
+ self, txn, room_id, topological_ordering, delete_local_events,
+ ):
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -1684,6 +1848,30 @@ class EventsStore(SQLBaseStore):
# state_groups
# state_groups_state
+ # we will build a temporary table listing the events so that we don't
+ # have to keep shovelling the list back and forth across the
+ # connection. Annoyingly the python sqlite driver commits the
+ # transaction on CREATE, so let's do this first.
+ #
+ # furthermore, we might already have the table from a previous (failed)
+ # purge attempt, so let's drop the table first.
+
+ txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+ txn.execute(
+ "CREATE TEMPORARY TABLE events_to_purge ("
+ " event_id TEXT NOT NULL,"
+ " should_delete BOOLEAN NOT NULL"
+ ")"
+ )
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1704,29 +1892,49 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties"
)
+ logger.info("[purge] looking for events to delete")
+
+ should_delete_expr = "state_key IS NULL"
+ should_delete_params = ()
+ if not delete_local_events:
+ should_delete_expr += " AND event_id NOT LIKE ?"
+ should_delete_params += ("%:" + self.hs.hostname, )
+
+ should_delete_params += (room_id, topological_ordering)
+
+ txn.execute(
+ "INSERT INTO events_to_purge"
+ " SELECT event_id, %s"
+ " FROM events AS e LEFT JOIN state_events USING (event_id)"
+ " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ should_delete_expr,
+ ),
+ should_delete_params,
+ )
txn.execute(
- "SELECT event_id, state_key FROM events"
- " LEFT JOIN state_events USING (room_id, event_id)"
- " WHERE room_id = ? AND topological_ordering < ?",
- (room_id, topological_ordering,)
+ "SELECT event_id, should_delete FROM events_to_purge"
)
event_rows = txn.fetchall()
+ logger.info(
+ "[purge] found %i events before cutoff, of which %i can be deleted",
+ len(event_rows), sum(1 for e in event_rows if e[1]),
+ )
- for event_id, state_key in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+ logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
# all events that point to events that are to be purged
txn.execute(
- "SELECT DISTINCT e.event_id FROM events as e"
- " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
- " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
- " WHERE e.room_id = ? AND e.topological_ordering < ?"
- " AND e2.topological_ordering >= ?",
- (room_id, topological_ordering, topological_ordering)
+ "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+ " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
+ " WHERE e2.topological_ordering >= ?",
+ (topological_ordering, )
)
new_backwards_extrems = txn.fetchall()
+ logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
(room_id,)
@@ -1741,30 +1949,36 @@ class EventsStore(SQLBaseStore):
]
)
+ logger.info("[purge] finding redundant state groups")
+
# Get all state groups that are only referenced by events that are
# to be deleted.
txn.execute(
"SELECT state_group FROM event_to_state_groups"
" INNER JOIN events USING (event_id)"
" WHERE state_group IN ("
- " SELECT DISTINCT state_group FROM events"
+ " SELECT DISTINCT state_group FROM events_to_purge"
" INNER JOIN event_to_state_groups USING (event_id)"
- " WHERE room_id = ? AND topological_ordering < ?"
" )"
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
- (room_id, topological_ordering, topological_ordering)
+ (topological_ordering, )
)
state_rows = txn.fetchall()
- state_groups_to_delete = [sg for sg, in state_rows]
+ logger.info("[purge] found %i redundant state groups", len(state_rows))
+
+ # make a set of the redundant state groups, so that we can look them up
+ # efficiently
+ state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups
- new_state_edges = []
- chunks = [
- state_groups_to_delete[i:i + 100]
- for i in xrange(0, len(state_groups_to_delete), 100)
- ]
- for chunk in chunks:
+ logger.info("[purge] finding state groups which depend on redundant"
+ " state groups")
+ remaining_state_groups = []
+ for i in xrange(0, len(state_rows), 100):
+ chunk = [sg for sg, in state_rows[i:i + 100]]
+ # look for state groups whose prev_state_group is one we are about
+ # to delete
rows = self._simple_select_many_txn(
txn,
table="state_group_edges",
@@ -1773,21 +1987,28 @@ class EventsStore(SQLBaseStore):
retcols=["state_group"],
keyvalues={},
)
- new_state_edges.extend(row["state_group"] for row in rows)
+ remaining_state_groups.extend(
+ row["state_group"] for row in rows
+
+ # exclude state groups we are about to delete: no point in
+ # updating them
+ if row["state_group"] not in state_groups_to_delete
+ )
- # Now we turn the state groups that reference to-be-deleted state groups
- # to non delta versions.
- for new_state_edge in new_state_edges:
+ # Now we turn the state groups that reference to-be-deleted state
+ # groups to non delta versions.
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
- txn, [new_state_edge], types=None
+ txn, [sg], types=None
)
- curr_state = curr_state[new_state_edge]
+ curr_state = curr_state[sg]
self._simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={
- "state_group": new_state_edge,
+ "state_group": sg,
}
)
@@ -1795,7 +2016,7 @@ class EventsStore(SQLBaseStore):
txn,
table="state_group_edges",
keyvalues={
- "state_group": new_state_edge,
+ "state_group": sg,
}
)
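"De-delta-ing" a state group, as performed above, simply replaces a stored delta with the fully resolved state map. A toy illustration with invented events (not synapse's actual data structures):

parent_state = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@a:hs"): "$join_a",
}
delta = {("m.room.member", "@b:hs"): "$join_b"}   # child group stored as a delta

# what gets re-inserted into state_groups_state for the child group once its
# parent is about to be purged
full_state = dict(parent_state)
full_state.update(delta)

assert len(full_state) == 3 and ("m.room.create", "") in full_state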
@@ -1804,16 +2025,17 @@ class EventsStore(SQLBaseStore):
table="state_groups_state",
values=[
{
- "state_group": new_state_edge,
+ "state_group": sg,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in curr_state.items()
+ for key, state_id in curr_state.iteritems()
],
)
+ logger.info("[purge] removing redundant state groups")
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
state_rows
@@ -1822,22 +2044,18 @@ class EventsStore(SQLBaseStore):
"DELETE FROM state_groups WHERE id = ?",
state_rows
)
- # Delete all non-state
- txn.executemany(
- "DELETE FROM event_to_state_groups WHERE event_id = ?",
- [(event_id,) for event_id, _ in event_rows]
- )
+ logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
- "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (topological_ordering, room_id,)
+ "DELETE FROM event_to_state_groups "
+ "WHERE event_id IN (SELECT event_id from events_to_purge)"
)
+ for event_id, _ in event_rows:
+ txn.call_after(self._get_state_group_for_event.invalidate, (
+ event_id,
+ ))
# Delete all remote non-state events
- to_delete = [
- (event_id,) for event_id, state_key in event_rows
- if state_key is None and not self.hs.is_mine_id(event_id)
- ]
for table in (
"events",
"event_json",
@@ -1847,29 +2065,102 @@ class EventsStore(SQLBaseStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
"rejections",
):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- to_delete
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ )
+
+ # event_push_actions lacks an index on event_id, and has one on
+ # (room_id, event_id) instead.
+ for table in (
+ "event_push_actions",
+ ):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ (room_id, )
)
- txn.executemany(
- "DELETE FROM events WHERE event_id = ?",
- to_delete
- )
# Mark all state and own events as outliers
- txn.executemany(
+ logger.info("[purge] marking remaining events as outliers")
+ txn.execute(
"UPDATE events SET outlier = ?"
- " WHERE event_id = ?",
- [
- (True, event_id,) for event_id, state_key in event_rows
- if state_key is not None or self.hs.is_mine_id(event_id)
- ]
+ " WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")",
+ (True,),
+ )
+
+ # synapse tries to take out an exclusive lock on room_depth whenever it
+ # persists events (because upsert), and once we run this update, we
+ # will block that for the rest of our transaction.
+ #
+ # So, let's stick it at the end so that we don't block event
+ # persistence.
+ logger.info("[purge] updating room_depth")
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # finally, drop the temp table. this will commit the txn in sqlite,
+ # so make sure to keep this actually last.
+ txn.execute(
+ "DROP TABLE events_to_purge"
+ )
+
+ logger.info("[purge] done")
+
+ @defer.inlineCallbacks
+ def is_event_after(self, event_id1, event_id2):
+ """Returns True if event_id1 is after event_id2 in the stream
+ """
+ to_1, so_1 = yield self._get_event_ordering(event_id1)
+ to_2, so_2 = yield self._get_event_ordering(event_id2)
+ defer.returnValue((to_1, so_1) > (to_2, so_2))
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def _get_event_ordering(self, event_id):
+ res = yield self._simple_select_one(
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True
+ )
+
+ if not res:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+
+ def get_max_current_state_delta_stream_id(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
)
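The ordering comparison in is_event_after above is plain lexicographic tuple comparison: topological ordering wins, and stream ordering breaks ties. A minimal, standalone sketch of that rule (the orderings used here are made up for illustration):

    # Illustrative only: mirrors the (topological_ordering, stream_ordering)
    # comparison used by is_event_after.
    def is_after(ordering_a, ordering_b):
        """Return True if ordering_a sorts after ordering_b in the stream."""
        return ordering_a > ordering_b  # lexicographic tuple comparison

    assert is_after((5, 120), (5, 119))     # same depth, later stream position
    assert is_after((6, 10), (5, 999))      # greater topological ordering wins
    assert not is_after((5, 119), (5, 120))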
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..ba834854e1
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,416 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background,
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import simplejson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+ def get_received_ts(self, event_id):
+ """Get received_ts (when it was persisted) for the event.
+
+ Raises an exception for unknown events.
+
+ Args:
+ event_id (str)
+
+ Returns:
+ Deferred[int|None]: Timestamp in milliseconds, or None for events
+ that were persisted before received_ts was implemented.
+ """
+ return self._simple_select_one_onecol(
+ table="events",
+ keyvalues={
+ "event_id": event_id,
+ },
+ retcol="received_ts",
+ desc="get_received_ts",
+ )
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous state's content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event is found; if
+ False, raise an exception.
+
+ Returns:
+ Deferred: A FrozenEvent.
+ """
+ events = yield self._get_events(
+ [event_id],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not events and not allow_none:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ defer.returnValue(events[0] if events else None)
+
+ @defer.inlineCallbacks
+ def get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous state's content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred: Dict from event_id to event.
+ """
+ events = yield self._get_events(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ defer.returnValue({e.event_id: e for e in events})
+
+ @defer.inlineCallbacks
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ defer.returnValue([])
+
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
+ event_entry_map = self._get_events_from_cache(
+ event_ids,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ allow_rejected=allow_rejected,
+ )
+
+ event_entry_map.update(missing_events)
+
+ events = []
+ for event_id in event_id_list:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if allow_rejected or not entry.event.rejected_reason:
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ defer.returnValue(events)
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (list(str)): list of event_ids to fetch
+ allow_rejected (bool): Whether to return events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+ dict of event_id -> _EventCacheEntry for each event_id in the cache.
+ If allow_rejected is `False` then rejected events will still have an
+ entry, but its value will be `None`.
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None,
+ update_metrics=update_metrics,
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ event_list = []
+ i = 0
+ while True:
+ try:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ event_id_lists = zip(*event_list)[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+
+ rows = self._new_transaction(
+ conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+ )
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ # We only want to resolve deferreds from the main thread
+ def fire(lst, res):
+ for ids, d in lst:
+ if not d.called:
+ try:
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
+ except Exception:
+ logger.exception("Failed to callback")
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(e)
+
+ if event_list:
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+ """
+ if not events:
+ defer.returnValue({})
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, events_d)
+ )
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
+
+ logger.debug("Loading %d events", len(events))
+ with PreserveLoggingContext():
+ rows = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = yield make_deferred_yieldable(defer.gatherResults(
+ [
+ run_in_background(
+ self._get_event_from_row,
+ row["internal_metadata"], row["json"], row["redacts"],
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True
+ ))
+
+ defer.returnValue({
+ e.event.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
+ rows = []
+ N = 200
+ for i in range(1 + len(events) / N):
+ evs = events[i * N:(i + 1) * N]
+ if not evs:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"] * len(evs)),)
+
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
+ @defer.inlineCallbacks
+ def _get_event_from_row(self, internal_metadata, js, redacted,
+ rejected_reason=None):
+ with Measure(self._clock, "_get_event_from_row"):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row_rejected_reason",
+ )
+
+ original_ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
+
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": redacted_event.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row_redactions",
+ )
+
+ redacted_event.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False,
+ allow_none=True,
+ )
+
+ if because:
+ # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = because
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
+ )
+
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+ defer.returnValue(cache_entry)
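The fetch path in this new module is a producer/consumer arrangement: callers append (event_ids, deferred) pairs to a shared list under _event_fetch_lock, and up to EVENT_QUEUE_THREADS database threads drain the list, run one query per batch and fire the deferreds back on the main thread. A stripped-down, Twisted-free sketch of the same batching idea; the names (enqueue, fetcher, query_many) are illustrative, not Synapse APIs:

    import threading

    fetch_lock = threading.Condition()
    fetch_list = []  # list of (event_ids, on_done) pairs

    def enqueue(event_ids, on_done):
        """Producer: queue a request and wake a fetcher thread."""
        with fetch_lock:
            fetch_list.append((event_ids, on_done))
            fetch_lock.notify()

    def fetcher(query_many, iterations=3, timeout=0.1):
        """Consumer: batch all outstanding requests into one query per loop."""
        global fetch_list
        idle_loops = 0
        while idle_loops <= iterations:
            with fetch_lock:
                batch, fetch_list = fetch_list, []
                if not batch:
                    fetch_lock.wait(timeout)
                    idle_loops += 1
                    continue
            idle_loops = 0
            wanted = {eid for ids, _ in batch for eid in ids}
            rows = query_many(wanted)  # one DB round-trip per batch
            for ids, on_done in batch:
                on_done([rows[eid] for eid in ids if eid in rows])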
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index a2ccc66ea7..78b1e30945 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -19,6 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from canonicaljson import encode_canonical_json
import simplejson as json
@@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore):
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
- def_json = json.dumps(user_filter).encode("utf-8")
+ def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
+ "SELECT filter_id FROM user_filters "
+ "WHERE user_id = ? AND filter_json = ?"
+ )
+ txn.execute(sql, (user_localpart, def_json))
+ filter_id_response = txn.fetchone()
+ if filter_id_response is not None:
+ return filter_id_response[0]
+
+ sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
new file mode 100644
index 0000000000..da05ccb027
--- /dev/null
+++ b/synapse/storage/group_server.py
@@ -0,0 +1,1253 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+
+from ._base import SQLBaseStore
+
+import simplejson as json
+
+
+# The category ID for the "default" category. We don't store as null in the
+# database to avoid the fun of null != null
+_DEFAULT_CATEGORY_ID = ""
+_DEFAULT_ROLE_ID = ""
+
+
+class GroupServerStore(SQLBaseStore):
+ def set_group_join_policy(self, group_id, join_policy):
+ """Set the join policy of a group.
+
+ join_policy can be one of:
+ * "invite"
+ * "open"
+ """
+ return self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues={
+ "join_policy": join_policy,
+ },
+ desc="set_group_join_policy",
+ )
+
+ def get_group(self, group_id):
+ return self._simple_select_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=(
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public", "join_policy",
+ ),
+ allow_none=True,
+ desc="get_group",
+ )
+
+ def get_users_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_users",
+ keyvalues=keyvalues,
+ retcols=("user_id", "is_public", "is_admin",),
+ desc="get_users_in_group",
+ )
+
+ def get_invited_users_in_group(self, group_id):
+ # TODO: Pagination
+
+ return self._simple_select_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcol="user_id",
+ desc="get_invited_users_in_group",
+ )
+
+ def get_rooms_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_rooms",
+ keyvalues=keyvalues,
+ retcols=("room_id", "is_public",),
+ desc="get_rooms_in_group",
+ )
+
+ def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ """Get the rooms and categories that should be included in a summary request
+
+ Returns ([rooms], [categories])
+ """
+ def _get_rooms_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT room_id, is_public, category_id, room_order
+ FROM group_summary_rooms
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ rooms = [
+ {
+ "room_id": row[0],
+ "is_public": row[1],
+ "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT category_id, is_public, profile, cat_order
+ FROM group_summary_room_categories
+ INNER JOIN group_room_categories USING (group_id, category_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ categories = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return rooms, categories
+ return self.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
+
+ def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
+ return self.runInteraction(
+ "add_room_to_summary", self._add_room_to_summary_txn,
+ group_id, room_id, category_id, order, is_public,
+ )
+
+ def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
+ is_public):
+ """Add (or update) room's entry in summary.
+
+ Args:
+ group_id (str)
+ room_id (str)
+ category_id (str): If not None then adds the category to the end of
+ the summary if it's not already there. [Optional]
+ order (int): If not None inserts the room at that position, e.g.
+ an order of 1 will put the room first. Otherwise, the room gets
+ added to the end.
+ """
+ room_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ retcol="room_id",
+ allow_none=True,
+ )
+ if not room_in_group:
+ raise SynapseError(400, "room not in group")
+
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+ else:
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ raise SynapseError(400, "Category doesn't exist")
+
+ # TODO: Check category is part of summary already
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_room_categories
+ (group_id, category_id, cat_order)
+ SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
+ FROM group_summary_room_categories
+ WHERE group_id = ? AND category_id = ?
+ """, (group_id, category_id, group_id, category_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ "category_id": category_id,
+ },
+ retcols=("room_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+ # Shuffle other room orders that come after the given order
+ sql = """
+ UPDATE group_summary_rooms SET room_order = room_order + 1
+ WHERE group_id = ? AND category_id = ? AND room_order >= ?
+ """
+ txn.execute(sql, (group_id, category_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
+ WHERE group_id = ? AND category_id = ?
+ """
+ txn.execute(sql, (group_id, category_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["room_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_rooms",
+ values={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ "room_order": order,
+ "is_public": is_public,
+ },
+ )
+
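The order handling in _add_room_to_summary_txn boils down to: an explicit order shifts every existing entry at or after that position up by one, while a missing order appends at MAX(order) + 1. A small in-memory model of that behaviour, just to make the semantics concrete (this is not the storage code; categories and visibility are ignored):

    def add_to_summary(summary, room_id, order=None):
        """summary: dict of room_id -> 1-based position, mutated in place."""
        if order is not None:
            # Shift everything at or after the requested position up by one.
            for rid in list(summary):
                if summary[rid] >= order:
                    summary[rid] += 1
        else:
            # No explicit order: append after the current maximum.
            order = max(summary.values() or [0]) + 1
        summary[room_id] = order
        return summary

    s = add_to_summary({"!a": 1, "!b": 2}, "!c")   # appended at position 3
    s = add_to_summary(s, "!d", order=1)           # the others shift to 2, 3, 4
    assert s == {"!d": 1, "!a": 2, "!b": 3, "!c": 4}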
+ def remove_room_from_summary(self, group_id, room_id, category_id):
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+
+ return self._simple_delete(
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ desc="remove_room_from_summary",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("category_id", "is_public", "profile"),
+ desc="get_group_categories",
+ )
+
+ defer.returnValue({
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, category_id):
+ category = yield self._simple_select_one(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_category",
+ )
+
+ category["profile"] = json.loads(category["profile"])
+
+ defer.returnValue(category)
+
+ def upsert_group_category(self, group_id, category_id, profile, is_public):
+ """Add/update room category for group
+ """
+ insertion_values = {}
+ update_values = {"category_id": category_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_category",
+ )
+
+ def remove_group_category(self, group_id, category_id):
+ return self._simple_delete(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ desc="remove_group_category",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("role_id", "is_public", "profile"),
+ desc="get_group_roles",
+ )
+
+ defer.returnValue({
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, role_id):
+ role = yield self._simple_select_one(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_role",
+ )
+
+ role["profile"] = json.loads(role["profile"])
+
+ defer.returnValue(role)
+
+ def upsert_group_role(self, group_id, role_id, profile, is_public):
+ """Add/remove user role
+ """
+ insertion_values = {}
+ update_values = {"role_id": role_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_role",
+ )
+
+ def remove_group_role(self, group_id, role_id):
+ return self._simple_delete(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ desc="remove_group_role",
+ )
+
+ def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
+ return self.runInteraction(
+ "add_user_to_summary", self._add_user_to_summary_txn,
+ group_id, user_id, role_id, order, is_public,
+ )
+
+ def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
+ is_public):
+ """Add (or update) user's entry in summary.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ role_id (str): If not None then adds the role to the end of
+ the summary if it's not already there. [Optional]
+ order (int): If not None inserts the user at that position, e.g.
+ an order of 1 will put the user first. Otherwise, the user gets
+ added to the end.
+ """
+ user_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+ if not user_in_group:
+ raise SynapseError(400, "user not in group")
+
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+ else:
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ raise SynapseError(400, "Role doesn't exist")
+
+ # TODO: Check role is part of the summary already
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_roles
+ (group_id, role_id, role_order)
+ SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
+ FROM group_summary_roles
+ WHERE group_id = ? AND role_id = ?
+ """, (group_id, role_id, group_id, role_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ "role_id": role_id,
+ },
+ retcols=("user_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+ # Shuffle other users' orders that come after the given order
+ sql = """
+ UPDATE group_summary_users SET user_order = user_order + 1
+ WHERE group_id = ? AND role_id = ? AND user_order >= ?
+ """
+ txn.execute(sql, (group_id, role_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
+ WHERE group_id = ? AND role_id = ?
+ """
+ txn.execute(sql, (group_id, role_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["user_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_users",
+ values={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ "user_order": order,
+ "is_public": is_public,
+ },
+ )
+
+ def remove_user_from_summary(self, group_id, user_id, role_id):
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+
+ return self._simple_delete(
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ desc="remove_user_from_summary",
+ )
+
+ def get_users_for_summary_by_role(self, group_id, include_private=False):
+ """Get the users and roles that should be included in a summary request
+
+ Returns ([users], [roles])
+ """
+ def _get_users_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT user_id, is_public, role_id, user_order
+ FROM group_summary_users
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ users = [
+ {
+ "user_id": row[0],
+ "is_public": row[1],
+ "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT role_id, is_public, profile, role_order
+ FROM group_summary_roles
+ INNER JOIN group_roles USING (group_id, role_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ roles = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return users, roles
+ return self.runInteraction(
+ "get_users_for_summary_by_role", _get_users_for_summary_txn
+ )
+
+ def is_user_in_group(self, user_id, group_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ desc="is_user_in_group",
+ ).addCallback(lambda r: bool(r))
+
+ def is_user_admin_in_group(self, group_id, user_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="is_admin",
+ allow_none=True,
+ desc="is_user_admin_in_group",
+ )
+
+ def add_group_invite(self, group_id, user_id):
+ """Record that the group server has invited a user
+ """
+ return self._simple_insert(
+ table="group_invites",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ desc="add_group_invite",
+ )
+
+ def is_user_invited_to_local_group(self, group_id, user_id):
+ """Has the group server invited a user?
+ """
+ return self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ desc="is_user_invited_to_local_group",
+ allow_none=True,
+ )
+
+ def get_users_membership_info_in_group(self, group_id, user_id):
+ """Get a dict describing the membership of a user in a group.
+
+ Example if joined:
+
+ {
+ "membership": "join",
+ "is_public": True,
+ "is_privileged": False,
+ }
+
+ Returns an empty dict if the user is not joined/invited/etc.
+ """
+ def _get_users_membership_in_group_txn(txn):
+ row = self._simple_select_one_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("is_admin", "is_public"),
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "join",
+ "is_public": row["is_public"],
+ "is_privileged": row["is_admin"],
+ }
+
+ row = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "invite",
+ }
+
+ return {}
+
+ return self.runInteraction(
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+ )
+
+ def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
+ local_attestation=None, remote_attestation=None):
+ """Add a user to the group server.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ is_admin (bool)
+ is_public (bool)
+ local_attestation (dict): The attestation the GS created to give
+ to the remote server. Optional if the user and group are on the
+ same server
+ remote_attestation (dict): The attestation given to GS by remote
+ server. Optional if the user and group are on the same server
+ """
+ def _add_user_to_group_txn(txn):
+ self._simple_insert_txn(
+ txn,
+ table="group_users",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "is_public": is_public,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ },
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ },
+ )
+
+ return self.runInteraction(
+ "add_user_to_group", _add_user_to_group_txn
+ )
+
+ def remove_user_from_group(self, group_id, user_id):
+ def _remove_user_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+ def add_room_to_group(self, group_id, room_id, is_public):
+ return self._simple_insert(
+ table="group_rooms",
+ values={
+ "group_id": group_id,
+ "room_id": room_id,
+ "is_public": is_public,
+ },
+ desc="add_room_to_group",
+ )
+
+ def update_room_in_group_visibility(self, group_id, room_id, is_public):
+ return self._simple_update(
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ updatevalues={
+ "is_public": is_public,
+ },
+ desc="update_room_in_group_visibility",
+ )
+
+ def remove_room_from_group(self, group_id, room_id):
+ def _remove_room_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+ return self.runInteraction(
+ "remove_room_from_group", _remove_room_from_group_txn,
+ )
+
+ def get_publicised_groups_for_user(self, user_id):
+ """Get all groups a user is publicising
+ """
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ "is_publicised": True,
+ },
+ retcol="group_id",
+ desc="get_publicised_groups_for_user",
+ )
+
+ def update_group_publicity(self, group_id, user_id, publicise):
+ """Update whether the user is publicising their membership of the group
+ """
+ return self._simple_update_one(
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "is_publicised": publicise,
+ },
+ desc="update_group_publicity"
+ )
+
+ @defer.inlineCallbacks
+ def register_user_group_membership(self, group_id, user_id, membership,
+ is_admin=False, content={},
+ local_attestation=None,
+ remote_attestation=None,
+ is_publicised=False,
+ ):
+ """Registers that a local user is a member of a (local or remote) group.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ membership (str)
+ is_admin (bool)
+ content (dict): Content of the membership, e.g. includes the inviter
+ if the user has been invited.
+ local_attestation (dict): If remote group then store the fact that we
+ have given out an attestation, else None.
+ remote_attestation (dict): If remote group then store the remote
+ attestation from the group, else None.
+ """
+ def _register_user_group_membership_txn(txn, next_id):
+ # TODO: Upsert?
+ self._simple_delete_txn(
+ txn,
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ table="local_group_membership",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "membership": membership,
+ "is_publicised": is_publicised,
+ "content": json.dumps(content),
+ },
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="local_group_updates",
+ values={
+ "stream_id": next_id,
+ "group_id": group_id,
+ "user_id": user_id,
+ "type": "membership",
+ "content": json.dumps({"membership": membership, "content": content}),
+ }
+ )
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+ # TODO: Insert profile to ensure it comes down stream if it's a join.
+
+ if membership == "join":
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ }
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ }
+ )
+ else:
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ return next_id
+
+ with self._group_updates_id_gen.get_next() as next_id:
+ res = yield self.runInteraction(
+ "register_user_group_membership",
+ _register_user_group_membership_txn, next_id,
+ )
+ defer.returnValue(res)
+
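The `with self._group_updates_id_gen.get_next() as next_id:` pattern above allocates a stream id and only advances the publicly visible token once the block that used it has finished, so readers never see a stream position whose row might not yet be committed. A toy version of that context-manager idea, a deliberate simplification of Synapse's real stream-id generator (which also copes with ids completing out of order):

    import contextlib
    import threading

    class ToyStreamIdGenerator(object):
        """Hands out increasing ids; the current token only advances once
        the block that used an id has completed."""

        def __init__(self, start=0):
            self._lock = threading.Lock()
            self._next = start
            self._current = start

        @contextlib.contextmanager
        def get_next(self):
            with self._lock:
                self._next += 1
                allocated = self._next
            try:
                yield allocated
            finally:
                with self._lock:
                    # Simplification: assumes ids complete in order.
                    self._current = allocated

        def get_current_token(self):
            with self._lock:
                return self._current

    gen = ToyStreamIdGenerator()
    with gen.get_next() as next_id:
        pass  # write the local_group_updates row using next_id here
    assert gen.get_current_token() == next_id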
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, name, avatar_url, short_description,
+ long_description,):
+ yield self._simple_insert(
+ table="groups",
+ values={
+ "group_id": group_id,
+ "name": name,
+ "avatar_url": avatar_url,
+ "short_description": short_description,
+ "long_description": long_description,
+ "is_public": True,
+ },
+ desc="create_group",
+ )
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, profile,):
+ yield self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues=profile,
+ desc="update_group_profile",
+ )
+
+ def get_attestations_need_renewals(self, valid_until_ms):
+ """Get all attestations that need to be renewed until givent time
+ """
+ def _get_attestations_need_renewals_txn(txn):
+ sql = """
+ SELECT group_id, user_id FROM group_attestations_renewals
+ WHERE valid_until_ms <= ?
+ """
+ txn.execute(sql, (valid_until_ms,))
+ return self.cursor_to_dict(txn)
+ return self.runInteraction(
+ "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+ )
+
+ def update_attestation_renewal(self, group_id, user_id, attestation):
+ """Update an attestation that we have renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ },
+ desc="update_attestation_renewal",
+ )
+
+ def update_remote_attestion(self, group_id, user_id, attestation):
+ """Update an attestation that a remote has renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ "attestation_json": json.dumps(attestation)
+ },
+ desc="update_remote_attestion",
+ )
+
+ def remove_attestation_renewal(self, group_id, user_id):
+ """Remove an attestation that we thought we should renew, but actually
+ shouldn't. Ideally this would never get called as we would never
+ incorrectly try and do attestations for local users on local groups.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ """
+ return self._simple_delete(
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ desc="remove_attestation_renewal",
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_attestation(self, group_id, user_id):
+ """Get the attestation that proves the remote agrees that the user is
+ in the group.
+ """
+ row = yield self._simple_select_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("valid_until_ms", "attestation_json"),
+ desc="get_remote_attestation",
+ allow_none=True,
+ )
+
+ now = int(self._clock.time_msec())
+ if row and now < row["valid_until_ms"]:
+ defer.returnValue(json.loads(row["attestation_json"]))
+
+ defer.returnValue(None)
+
+ def get_joined_groups(self, user_id):
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ },
+ retcol="group_id",
+ desc="get_joined_groups",
+ )
+
+ def get_all_groups_for_user(self, user_id, now_token):
+ def _get_all_groups_for_user_txn(txn):
+ sql = """
+ SELECT group_id, type, membership, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND membership != 'leave'
+ AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, now_token,))
+ return [
+ {
+ "group_id": row[0],
+ "type": row[1],
+ "membership": row[2],
+ "content": json.loads(row[3]),
+ }
+ for row in txn
+ ]
+ return self.runInteraction(
+ "get_all_groups_for_user", _get_all_groups_for_user_txn,
+ )
+
+ def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_entity_changed(
+ user_id, from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_groups_changes_for_user_txn(txn):
+ sql = """
+ SELECT group_id, membership, type, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, from_token, to_token,))
+ return [{
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ } for group_id, membership, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+ )
+
+ def get_all_groups_changes(self, from_token, to_token, limit):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+ from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_all_groups_changes_txn(txn):
+ sql = """
+ SELECT stream_id, group_id, user_id, type, content
+ FROM local_group_updates
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit,))
+ return [(
+ stream_id,
+ group_id,
+ user_id,
+ gtype,
+ json.loads(content_json),
+ ) for stream_id, group_id, user_id, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_all_groups_changes", _get_all_groups_changes_txn,
+ )
+
+ def get_group_stream_token(self):
+ return self._group_updates_id_gen.get_current_token()
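get_all_groups_changes and get_group_stream_token form the usual replication-style pair: a consumer remembers the last stream_id it has processed and repeatedly asks for everything in (from_token, to_token], bounded by a limit. A hedged sketch of that polling loop; handle_row and the loop itself are illustrative and not part of this store:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def catch_up_group_changes(store, handle_row, from_token, batch_size=100):
        """Process group updates newer than from_token, in batches."""
        to_token = store.get_group_stream_token()  # plain int, not a Deferred
        while from_token < to_token:
            rows = yield store.get_all_groups_changes(
                from_token, to_token, limit=batch_size,
            )
            if not rows:
                break
            for stream_id, group_id, user_id, gtype, content in rows:
                handle_row(stream_id, group_id, user_id, gtype, content)
            from_token = max(row[0] for row in rows)
        defer.returnValue(from_token)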
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..87aeaf71d6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
key_ids
Args:
server_name (str): The name of the server.
- key_ids (list of str): List of key_ids to try and look up.
+ key_ids (iterable[str]): key_ids to try and look up.
Returns:
- (list of VerifyKey): The verification keys.
+ Deferred: resolves to dict[str, VerifyKey]: map from
+ key_id to verification key.
"""
keys = {}
for key_id in key_ids:
@@ -112,30 +113,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
- @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
- key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up
- ts_now_ms (int): The time now in milliseconds
- verification_key (VerifyKey): The NACL verify key.
+ time_now_ms (int): The time now in milliseconds
+ verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
- yield self._simple_upsert(
- table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
- },
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": buffer(verify_key.encode()),
- },
- desc="store_server_verify_key",
- )
+ key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ def _txn(txn):
+ self._simple_upsert_txn(
+ txn,
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ },
+ values={
+ "from_server": from_server,
+ "ts_added_ms": time_now_ms,
+ "verify_key": buffer(verify_key.encode()),
+ },
+ )
+ txn.call_after(
+ self._get_server_verify_key.invalidate,
+ (server_name, key_id)
+ )
+
+ return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 4c0f82353d..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
-from ._base import SQLBaseStore
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def get_default_thumbnails(self, top_level_type, sub_type):
- return []
+ def __init__(self, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ update_name='local_media_repository_url_idx',
+ index_name='local_media_repository_url_idx',
+ table='local_media_repository',
+ columns=['created_ts'],
+ where_clause='url_cache IS NOT NULL',
+ )
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
@@ -30,13 +37,16 @@ class MediaRepositoryStore(SQLBaseStore):
return self._simple_select_one(
"local_media_repository",
{"media_id": media_id},
- ("media_type", "media_length", "upload_name", "created_ts"),
+ (
+ "media_type", "media_length", "upload_name", "created_ts",
+ "quarantined_by", "url_cache",
+ ),
allow_none=True,
desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
- media_length, user_id):
+ media_length, user_id, url_cache=None):
return self._simple_insert(
"local_media_repository",
{
@@ -46,6 +56,7 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
+ "url_cache": url_cache,
},
desc="store_local_media",
)
@@ -58,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
@@ -70,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
@@ -82,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None
return dict(zip((
- 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+ 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
- def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+ def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
@@ -97,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url,
"response_code": response_code,
"etag": etag,
- "expires": expires,
+ "expires_ts": expires_ts,
"og": og,
"media_id": media_id,
"download_ts": download_ts,
@@ -138,7 +149,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_origin": origin, "media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
- "filesystem_id",
+ "filesystem_id", "quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
@@ -162,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ time_ms: Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -170,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -234,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore):
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+ def get_expired_url_cache(self, now_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository_url_cache"
+ " WHERE expires_ts < ?"
+ " ORDER BY expires_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_expired_url_cache_txn(txn):
+ txn.execute(sql, (now_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+ def delete_url_cache(self, media_ids):
+ if len(media_ids) == 0:
+ return
+
+ sql = (
+ "DELETE FROM local_media_repository_url_cache"
+ " WHERE media_id = ?"
+ )
+
+ def _delete_url_cache_txn(txn):
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+ def get_url_cache_media_before(self, before_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository"
+ " WHERE created_ts < ? AND url_cache IS NOT NULL"
+ " ORDER BY created_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_url_cache_media_before_txn(txn):
+ txn.execute(sql, (before_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction(
+ "get_url_cache_media_before", _get_url_cache_media_before_txn,
+ )
+
+ def delete_url_cache_media(self, media_ids):
+ if len(media_ids) == 0:
+ return
+
+ def _delete_url_cache_media_txn(txn):
+ sql = (
+ "DELETE FROM local_media_repository"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ sql = (
+ "DELETE FROM local_media_repository_thumbnails"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction(
+ "delete_url_cache_media", _delete_url_cache_media_txn,
+ )
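get_expired_url_cache / delete_url_cache (and the *_media variants above) are clearly meant to be driven by a periodic sweep: fetch up to 500 expired ids, delete them, and repeat until nothing comes back. A hedged sketch of such a sweep; remove_files is a hypothetical callback for the on-disk cleanup, which lives outside this store:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def expire_url_cache(store, now_ts, remove_files):
        """Purge expired URL-cache rows in batches of at most 500."""
        while True:
            media_ids = yield store.get_expired_url_cache(now_ts)
            if not media_ids:
                break
            # Hypothetical hook: delete the cached files before the rows go.
            remove_files(media_ids)
            yield store.delete_url_cache(media_ids)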
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b357f22be7..04411a665f 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 40
+SCHEMA_VERSION = 48
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -44,6 +45,13 @@ def prepare_database(db_conn, database_engine, config):
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
+
+ Args:
+ db_conn:
+ database_engine:
+ config (synapse.config.homeserver.HomeServerConfig|None):
+ application config, or None if we are connecting to an existing
+ database which we expect to be configured already
"""
try:
cur = db_conn.cursor()
@@ -64,9 +72,13 @@ def prepare_database(db_conn, database_engine, config):
else:
_setup_new_database(cur, database_engine)
+ # check if any of our configured dynamic modules want a database
+ if config is not None:
+ _apply_module_schemas(cur, database_engine, config)
+
cur.close()
db_conn.commit()
- except:
+ except Exception:
db_conn.rollback()
raise
@@ -283,6 +295,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
+def _apply_module_schemas(txn, database_engine, config):
+ """Apply the module schemas for the dynamic modules, if any
+
+ Args:
+ txn: database cursor
+ database_engine: synapse database engine class
+ config (synapse.config.homeserver.HomeServerConfig):
+ application config
+ """
+ for (mod, _config) in config.password_providers:
+ if not hasattr(mod, 'get_db_schema_files'):
+ continue
+ modname = ".".join((mod.__module__, mod.__name__))
+ _apply_module_schema_files(
+ txn, database_engine, modname, mod.get_db_schema_files(),
+ )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+ """Apply the module schemas for a single module
+
+ Args:
+ cur: database cursor
+ database_engine: synapse database engine class
+ modname (str): fully qualified name of the module
+ names_and_streams (Iterable[(str, file)]): the names and streams of
+ schemas to be applied
+ """
+ cur.execute(
+ database_engine.convert_param_style(
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+ ),
+ (modname,)
+ )
+ applied_deltas = set(d for d, in cur)
+ for (name, stream) in names_and_streams:
+ if name in applied_deltas:
+ continue
+
+ root_name, ext = os.path.splitext(name)
+ if ext != '.sql':
+ raise PrepareDatabaseException(
+ "only .sql files are currently supported for module schemas",
+ )
+
+ logger.info("applying schema %s for %s", name, modname)
+ for statement in get_statements(stream):
+ cur.execute(statement)
+
+ # Mark as done.
+ cur.execute(
+ database_engine.convert_param_style(
+ "INSERT INTO applied_module_schemas (module_name, file)"
+ " VALUES (?,?)",
+ ),
+ (modname, name)
+ )
+
+
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -356,7 +427,7 @@ def _get_or_create_schema_state(txn, database_engine):
),
(current_version,)
)
- applied_deltas = [d for d, in txn.fetchall()]
+ applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded
return None
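For context, a hedged sketch of what a password-provider module might look like to opt into the `_apply_module_schemas` hook above: it only needs a `get_db_schema_files()` hook returning `(name, stream)` pairs for `.sql` files. The provider class and its schema below are hypothetical; only the hook name and the .sql-only restriction come from the diff.

from io import StringIO

_SCHEMA_001 = """
CREATE TABLE IF NOT EXISTS my_provider_external_ids (
    user_id TEXT NOT NULL,
    external_id TEXT NOT NULL,
    UNIQUE (user_id, external_id)
);
"""

class MyPasswordProvider(object):
    @staticmethod
    def get_db_schema_files():
        # Each entry is (file name, readable stream); only .sql is accepted,
        # and applied files are recorded in applied_module_schemas so they
        # are not re-run on the next startup.
        return [("001_external_ids.sql", StringIO(_SCHEMA_001))]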
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4d1590d2b4..9e9d3c2591 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
- self._invalidate_cache_and_stream(
- txn, self._get_presence_for_user, (state.user_id,)
+ txn.call_after(
+ self._get_presence_for_user.invalidate, (state.user_id,)
)
# Actually insert new rows
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..8612bd5ecc 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,15 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from synapse.storage.roommember import ProfileInfo
+from synapse.api.errors import StoreError
+
from ._base import SQLBaseStore
-class ProfileStore(SQLBaseStore):
- def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
+class ProfileWorkerStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def get_profileinfo(self, user_localpart):
+ try:
+ profile = yield self._simple_select_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcols=("displayname", "avatar_url"),
+ desc="get_profileinfo",
+ )
+ except StoreError as e:
+ if e.code == 404:
+ # no match
+ defer.returnValue(ProfileInfo(None, None))
+ return
+ else:
+ raise
+
+ defer.returnValue(
+ ProfileInfo(
+ avatar_url=profile['avatar_url'],
+ display_name=profile['displayname'],
+ )
)
def get_profile_displayname(self, user_localpart):
@@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_displayname",
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
- desc="set_profile_displayname",
- )
-
def get_profile_avatar_url(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",
@@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_from_remote_profile_cache(self, user_id):
+ return self._simple_select_one(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("displayname", "avatar_url",),
+ allow_none=True,
+ desc="get_from_remote_profile_cache",
+ )
+
+
+class ProfileStore(ProfileWorkerStore):
+ def create_profile(self, user_localpart):
+ return self._simple_insert(
+ table="profiles",
+ values={"user_id": user_localpart},
+ desc="create_profile",
+ )
+
+ def set_profile_displayname(self, user_localpart, new_displayname):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"displayname": new_displayname},
+ desc="set_profile_displayname",
+ )
+
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self._simple_update_one(
table="profiles",
@@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore):
updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)
+
+ def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ """Ensure we are caching the remote user's profiles.
+
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ return self._simple_upsert(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
+ )
+
+ def update_remote_profile_cache(self, user_id, displayname, avatar_url):
+ return self._simple_update(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="update_remote_profile_cache",
+ )
+
+ @defer.inlineCallbacks
+ def maybe_delete_remote_profile_cache(self, user_id):
+ """Check if we still care about the remote user's profile, and if we
+ don't then remove their profile from the cache
+ """
+ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+ if not subscribed:
+ yield self._simple_delete(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ desc="delete_remote_profile_cache",
+ )
+
+ def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ """Get all users who haven't been checked since `last_checked`
+ """
+ def _get_remote_profile_cache_entries_that_expire_txn(txn):
+ sql = """
+ SELECT user_id, displayname, avatar_url
+ FROM remote_profile_cache
+ WHERE last_check < ?
+ """
+
+ txn.execute(sql, (last_checked,))
+
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_remote_profile_cache_entries_that_expire",
+ _get_remote_profile_cache_entries_that_expire_txn,
+ )
+
+ @defer.inlineCallbacks
+ def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
+
+ res = yield self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
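To illustrate the convention introduced by `get_profileinfo` above (a missing `profiles` row becomes `ProfileInfo(None, None)` instead of a 404 `StoreError`), here is a tiny dict-backed stand-in; it is not Synapse's store, just a sketch of the caller-facing behaviour.

from collections import namedtuple

ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))

class FakeProfileStore(object):
    def __init__(self):
        self._profiles = {}  # user_localpart -> (displayname, avatar_url)

    def set_profile_displayname(self, user_localpart, new_displayname):
        _, avatar = self._profiles.get(user_localpart, (None, None))
        self._profiles[user_localpart] = (new_displayname, avatar)

    def get_profileinfo(self, user_localpart):
        if user_localpart not in self._profiles:
            # No match: callers get an empty profile rather than an error.
            return ProfileInfo(None, None)
        displayname, avatar = self._profiles[user_localpart]
        return ProfileInfo(avatar_url=avatar, display_name=displayname)

store = FakeProfileStore()
assert store.get_profileinfo("alice") == ProfileInfo(None, None)
store.set_profile_displayname("alice", "Alice")
assert store.get_profileinfo("alice").display_name == "Alice"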
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index cbec255966..04a0b59a39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +15,17 @@
# limitations under the License.
from ._base import SQLBaseStore
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes
from twisted.internet import defer
+import abc
import logging
import simplejson as json
@@ -47,8 +55,44 @@ def _load_rules(rawrules, enabled_map):
return rules
-class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks()
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self._get_cache_dict(
+ db_conn, "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache", push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -72,7 +116,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@@ -88,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
@@ -184,6 +244,18 @@ class PushRuleStore(SQLBaseStore):
if uid in local_users_in_room:
user_ids.add(uid)
+ forgotten = yield self.who_forgot_in_room(
+ event.room_id, on_invalidate=cache_context.invalidate,
+ )
+
+ for row in forgotten:
+ user_id = row["user_id"]
+ event_id = row["event_id"]
+
+ mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
+ if event_id == mem_id:
+ user_ids.discard(user_id)
+
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
@@ -215,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results)
+
+class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
@@ -513,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
class RuleNotFoundException(Exception):
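The `PushRulesWorkerStore` split above relies on a pattern that recurs throughout this patch: the worker base class calls an abstract getter from its `__init__`, and only the concrete (master) store, which owns the ID generator, implements it. A stripped-down sketch with invented class names:

import abc

class StreamWorkerStore(object):
    # Python 2 spelling, as in the diff; under Python 3 this attribute has no
    # effect and the class would declare `metaclass=abc.ABCMeta` instead.
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        # Safe to call here because concrete subclasses set up whatever the
        # getter needs before delegating to this constructor.
        self._stream_cache_position = self.get_max_stream_id()

    @abc.abstractmethod
    def get_max_stream_id(self):
        """Return the current maximum stream ID (int)."""
        raise NotImplementedError()

class MasterStreamStore(StreamWorkerStore):
    def __init__(self, id_gen_token):
        self._id_gen_token = id_gen_token  # stands in for a StreamIdGenerator
        super(MasterStreamStore, self).__init__()

    def get_max_stream_id(self):
        return self._id_gen_token

assert MasterStreamStore(id_gen_token=42)._stream_cache_position == 42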
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8cc9f0353b..307660b99a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
logger = logging.getLogger(__name__)
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows)
- def get_pushers_stream_token(self):
- return self._pushers_id_gen.get_current_token()
-
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
@@ -135,6 +133,48 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
+ def get_all_updated_pushers_rows(self, last_id, current_id, limit):
+ """Get all the pushers that have changed between the given tokens.
+
+ Returns:
+ Deferred(list(tuple)): each tuple consists of:
+ stream_id (int)
+ user_id (str)
+ app_id (str)
+ pushkey (str)
+ was_deleted (bool): whether the pusher was added/updated (False)
+ or deleted (True)
+ """
+
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_pushers_rows_txn(txn):
+ sql = (
+ "SELECT id, user_name, app_id, pushkey"
+ " FROM pushers"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ results = [list(row) + [False] for row in txn]
+
+ sql = (
+ "SELECT stream_id, user_id, app_id, pushkey"
+ " FROM deleted_pushers"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+
+ results.extend(list(row) + [True] for row in txn)
+ results.sort() # Sort so that they're ordered by stream id
+
+ return results
+ return self.runInteraction(
+ "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
+ )
+
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
@@ -156,56 +196,74 @@ class PusherStore(SQLBaseStore):
defer.returnValue(result)
+
+class PusherStore(PusherWorkerStore):
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id:
- def f(txn):
- newly_inserted = self._simple_upsert_txn(
- txn,
- "pushers",
- {
- "app_id": app_id,
- "pushkey": pushkey,
- "user_name": user_id,
- },
- {
- "access_token": access_token,
- "kind": kind,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "ts": pushkey_ts,
- "lang": lang,
- "data": encode_canonical_json(data),
- "last_stream_ordering": last_stream_ordering,
- "profile_tag": profile_tag,
- "id": stream_id,
- },
- )
- if newly_inserted:
- # get_if_user_has_pusher only cares if the user has
- # at least *one* pusher.
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+ # no need to lock because `pushers` has a unique key on
+ # (app_id, pushkey, user_name) so _simple_upsert will retry
+ newly_inserted = yield self._simple_upsert(
+ table="pushers",
+ keyvalues={
+ "app_id": app_id,
+ "pushkey": pushkey,
+ "user_name": user_id,
+ },
+ values={
+ "access_token": access_token,
+ "kind": kind,
+ "app_display_name": app_display_name,
+ "device_display_name": device_display_name,
+ "ts": pushkey_ts,
+ "lang": lang,
+ "data": encode_canonical_json(data),
+ "last_stream_ordering": last_stream_ordering,
+ "profile_tag": profile_tag,
+ "id": stream_id,
+ },
+ desc="add_pusher",
+ lock=False,
+ )
- yield self.runInteraction("add_pusher", f)
+ if newly_inserted:
+ self.runInteraction(
+ "add_pusher",
+ self._invalidate_cache_and_stream,
+ self.get_if_user_has_pusher, (user_id,)
+ )
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_if_user_has_pusher, (user_id,)
+ )
self._simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
)
- self._simple_upsert_txn(
+
+ # it's possible for us to end up with duplicate rows for
+ # (app_id, pushkey, user_id) at different stream_ids, but that
+ # doesn't really matter.
+ self._simple_insert_txn(
txn,
- "deleted_pushers",
- {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
- {"stream_id": stream_id},
+ table="deleted_pushers",
+ values={
+ "stream_id": stream_id,
+ "app_id": app_id,
+ "pushkey": pushkey,
+ "user_id": user_id,
+ },
)
with self._pushers_id_gen.get_next() as stream_id:
@@ -268,9 +326,12 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
+ # no need to lock because `pusher_throttle` has a primary key on
+ # (pusher, room_id) so _simple_upsert will retry
yield self._simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
- desc="set_throttle_params"
+ desc="set_throttle_params",
+ lock=False,
)
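A small illustration of what `get_all_updated_pushers_rows` above produces for replication: rows from `pushers` get `was_deleted=False`, rows from `deleted_pushers` get `True`, and the merged list is sorted by stream id so consumers see changes in order. The data below is made up.

updated = [            # (stream_id, user_id, app_id, pushkey) from `pushers`
    (7, "@alice:hs", "m.http", "key1"),
    (9, "@bob:hs", "m.http", "key2"),
]
deleted = [            # (stream_id, user_id, app_id, pushkey) from `deleted_pushers`
    (8, "@carol:hs", "m.http", "key3"),
]

rows = [list(r) + [False] for r in updated]
rows.extend(list(r) + [True] for r in deleted)
rows.sort()            # order by stream id, mirroring the txn above

assert [r[0] for r in rows] == [7, 8, 9]
assert rows[1][-1] is True   # the deletion sits at stream id 8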
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f72d15f5ed..63997ed449 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,46 +15,50 @@
# limitations under the License.
from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
+import abc
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
-class ReceiptsStore(SQLBaseStore):
- def __init__(self, hs):
- super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
- if receipt_type != "m.read":
- return
-
- # Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
-
- if res and res.called and user_id in res.result:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
-
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -265,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
}
defer.returnValue(results)
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ args = [last_id, current_id]
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns an ObservableDeferred
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
+
+ if res:
+ if isinstance(res, defer.Deferred) and res.called:
+ res = res.result
+ if user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -313,10 +371,9 @@ class ReceiptsStore(SQLBaseStore):
)
txn.execute(sql, (room_id, receipt_type, user_id))
- results = txn.fetchall()
- if results and topological_ordering:
- for to, so, _ in results:
+ if topological_ordering:
+ for to, so, _ in txn:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
@@ -351,6 +408,7 @@ class ReceiptsStore(SQLBaseStore):
room_id=room_id,
user_id=user_id,
topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
)
return True
@@ -452,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
-
- return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_receipts", get_all_updated_receipts_txn
- )
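The relocated `_invalidate_get_users_with_receipts_in_room` keeps a useful short-cut: `m.read` receipts only ever add users to the per-room set, so if the cached result already contains the user there is nothing to invalidate. A plain-dict sketch of that logic (not Synapse's deferred-aware cache):

class ReadReceiptUserCache(object):
    def __init__(self):
        self._users_by_room = {}      # room_id -> set(user_id)

    def get(self, room_id):
        return self._users_by_room.get(room_id)

    def invalidate(self, room_id):
        self._users_by_room.pop(room_id, None)

    def on_receipt(self, room_id, receipt_type, user_id):
        if receipt_type != "m.read":
            return
        cached = self.get(room_id)
        if cached is not None and user_id in cached:
            # Already present: a new read receipt cannot change the set.
            return
        self.invalidate(room_id)

cache = ReadReceiptUserCache()
cache._users_by_room["!room:hs"] = {"@alice:hs"}
cache.on_receipt("!room:hs", "m.read", "@alice:hs")
assert cache.get("!room:hs") == {"@alice:hs"}   # kept: no-op
cache.on_receipt("!room:hs", "m.read", "@bob:hs")
assert cache.get("!room:hs") is None            # dropped: new reader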
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060c3..a50717db2d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,13 +19,75 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from six.moves import range
-class RegistrationStore(background_updates.BackgroundUpdateStore):
- def __init__(self, hs):
- super(RegistrationStore, self).__init__(hs)
+class RegistrationWorkerStore(SQLBaseStore):
+ @cached()
+ def get_user_by_id(self, user_id):
+ return self._simple_select_one(
+ table="users",
+ keyvalues={
+ "name": user_id,
+ },
+ retcols=["name", "password_hash", "is_guest"],
+ allow_none=True,
+ desc="get_user_by_id",
+ )
+
+ @cached()
+ def get_user_by_access_token(self, token):
+ """Get a user from the given access token.
+
+ Args:
+ token (str): The access token of a user.
+ Returns:
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ """
+ return self.runInteraction(
+ "get_user_by_access_token",
+ self._query_for_auth,
+ token
+ )
+
+ @defer.inlineCallbacks
+ def is_server_admin(self, user):
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user.to_string()},
+ retcol="admin",
+ allow_none=True,
+ desc="is_server_admin",
+ )
+
+ defer.returnValue(res if res else False)
+
+ def _query_for_auth(self, txn, token):
+ sql = (
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
+ " FROM users"
+ " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+ " WHERE token = ?"
+ )
+
+ txn.execute(sql, (token,))
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return rows[0]
+
+ return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+ background_updates.BackgroundUpdateStore):
+
+ def __init__(self, db_conn, hs):
+ super(RegistrationStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
@@ -36,12 +98,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id"],
)
- self.register_background_index_update(
- "refresh_tokens_device_index",
- index_name="refresh_tokens_device_id",
- table="refresh_tokens",
- columns=["user_id", "device_id"],
- )
+ # we no longer use refresh tokens, but it's possible that some people
+ # might have a background update queued to build this index. Just
+ # clear the background update.
+ self.register_noop_background_update("refresh_tokens_device_index")
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
@@ -177,9 +237,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
if create_profile_with_localpart:
+ # set a default displayname serverside to avoid ugly race
+ # between auto-joins and clients trying to set displaynames
txn.execute(
- "INSERT INTO profiles(user_id) VALUES (?)",
- (create_profile_with_localpart,)
+ "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
+ (create_profile_with_localpart, create_profile_with_localpart)
)
self._invalidate_cache_and_stream(
@@ -187,18 +249,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- def get_user_by_id(self, user_id):
- return self._simple_select_one(
- table="users",
- keyvalues={
- "name": user_id,
- },
- retcols=["name", "password_hash", "is_guest"],
- allow_none=True,
- desc="get_user_by_id",
- )
-
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
@@ -209,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
- return dict(txn.fetchall())
+ return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
@@ -236,12 +286,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
"user_set_password_hash", user_set_password_hash_txn
)
- @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_id=None,
- device_id=None,
- delete_refresh_tokens=False):
+ device_id=None):
"""
- Invalidate access/refresh tokens belonging to a user
+ Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
@@ -250,10 +298,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
- delete_refresh_tokens (bool): True to delete refresh tokens as
- well as access tokens.
Returns:
- defer.Deferred:
+ defer.Deferred[list[(str, int, str|None)]]: a list of
+ (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
keyvalues = {
@@ -262,13 +309,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
- if delete_refresh_tokens:
- self._simple_delete_txn(
- txn,
- table="refresh_tokens",
- keyvalues=keyvalues,
- )
-
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
@@ -277,14 +317,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id)
txn.execute(
- "SELECT token FROM access_tokens WHERE %s" % where_clause,
+ "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values
)
- rows = self.cursor_to_dict(txn)
+ tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
- for row in rows:
+ for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream(
- txn, self.get_user_by_access_token, (row["token"],)
+ txn, self.get_user_by_access_token, (token,)
)
txn.execute(
@@ -292,7 +332,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values
)
- yield self.runInteraction(
+ return tokens_and_devices
+
+ return self.runInteraction(
"user_delete_access_tokens", f,
)
@@ -312,34 +354,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("delete_access_token", f)
- @cached()
- def get_user_by_access_token(self, token):
- """Get a user from the given access token.
-
- Args:
- token (str): The access token of a user.
- Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
- """
- return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
- )
-
- @defer.inlineCallbacks
- def is_server_admin(self, user):
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user.to_string()},
- retcol="admin",
- allow_none=True,
- desc="is_server_admin",
- )
-
- defer.returnValue(res if res else False)
-
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
@@ -352,22 +366,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(res if res else False)
- def _query_for_auth(self, txn, token):
- sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
- " FROM users"
- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
- " WHERE token = ?"
- )
-
- txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
- if rows:
- return rows[0]
-
- return None
-
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
@@ -438,6 +436,19 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(ret)
@defer.inlineCallbacks
+ def count_nonbridged_users(self):
+ def _count_users(txn):
+ txn.execute("""
+ SELECT COALESCE(COUNT(*), 0) FROM users
+ WHERE appservice_id IS NULL
+ """)
+ count, = txn.fetchone()
+ return count
+
+ ret = yield self.runInteraction("count_users", _count_users)
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
@@ -451,18 +462,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
- rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):")
found = set()
- for r in rows:
- user_id = r["name"]
+ for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
- for i in xrange(len(found) + 1):
+ for i in range(len(found) + 1):
if i not in found:
return i
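The reworked `find_next_generated_user_id_localpart` above boils down to: collect the numeric localparts already in use and return the smallest non-negative integer not among them. A standalone re-implementation for illustration (the user list is made up):

import re

def find_next_generated_localpart(user_ids):
    regex = re.compile(r"^@(\d+):")
    found = set()
    for user_id in user_ids:
        match = regex.search(user_id)
        if match:
            found.add(int(match.group(1)))
    # range(len(found) + 1) always contains at least one unused value.
    for i in range(len(found) + 1):
        if i not in found:
            return i

assert find_next_generated_localpart(["@0:hs", "@1:hs", "@alice:hs"]) == 2
assert find_next_generated_localpart(["@1:hs", "@3:hs"]) == 0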
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fdf5..ea6a189185 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,14 +16,14 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
-
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.search import SearchStore
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
import collections
import logging
-import ujson as json
+import simplejson as json
+import re
logger = logging.getLogger(__name__)
@@ -33,8 +33,144 @@ OpsLevel = collections.namedtuple(
("ban_level", "kick_level", "redact_level",)
)
+RatelimitOverride = collections.namedtuple(
+ "RatelimitOverride",
+ ("messages_per_second", "burst_count",)
+)
+
-class RoomStore(SQLBaseStore):
+class RoomWorkerStore(SQLBaseStore):
+ def get_public_room_ids(self):
+ return self._simple_select_onecol(
+ table="rooms",
+ keyvalues={
+ "is_public": True,
+ },
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ @cached(num_args=2, max_entries=100)
+ def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+ """Get pulbic rooms for a particular list, or across all lists.
+
+ Args:
+ stream_id (int)
+ network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+ means the main list, None means all lists.
+ """
+ return self.runInteraction(
+ "get_public_room_ids_at_stream_id",
+ self.get_public_room_ids_at_stream_id_txn,
+ stream_id, network_tuple=network_tuple
+ )
+
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+ network_tuple):
+ return {
+ rm
+ for rm, vis in self.get_published_at_stream_id_txn(
+ txn, stream_id, network_tuple=network_tuple
+ ).items()
+ if vis
+ }
+
+ def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+ if network_tuple:
+ # We want to get from a particular list. No aggregation required.
+
+ sql = ("""
+ SELECT room_id, visibility FROM public_room_list_stream
+ INNER JOIN (
+ SELECT room_id, max(stream_id) AS stream_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ? %s
+ GROUP BY room_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ if network_tuple.appservice_id is not None:
+ txn.execute(
+ sql % ("AND appservice_id = ? AND network_id = ?",),
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ )
+ else:
+ txn.execute(
+ sql % ("AND appservice_id IS NULL",),
+ (stream_id,)
+ )
+ return dict(txn)
+ else:
+ # We want to get from all lists, so we need to aggregate the results
+
+ logger.info("Executing full list")
+
+ sql = ("""
+ SELECT room_id, visibility
+ FROM public_room_list_stream
+ INNER JOIN (
+ SELECT
+ room_id, max(stream_id) AS stream_id, appservice_id,
+ network_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ?
+ GROUP BY room_id, appservice_id, network_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ txn.execute(
+ sql,
+ (stream_id,)
+ )
+
+ results = {}
+ # A room is visible if it's visible on any list.
+ for room_id, visibility in txn:
+ results[room_id] = bool(visibility) or results.get(room_id, False)
+
+ return results
+
+ def get_public_room_changes(self, prev_stream_id, new_stream_id,
+ network_tuple):
+ def get_public_room_changes_txn(txn):
+ then_rooms = self.get_public_room_ids_at_stream_id_txn(
+ txn, prev_stream_id, network_tuple
+ )
+
+ now_rooms_dict = self.get_published_at_stream_id_txn(
+ txn, new_stream_id, network_tuple
+ )
+
+ now_rooms_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if vis
+ )
+ now_rooms_not_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if not vis
+ )
+
+ newly_visible = now_rooms_visible - then_rooms
+ newly_unpublished = now_rooms_not_visible & then_rooms
+
+ return newly_visible, newly_unpublished
+
+ return self.runInteraction(
+ "get_public_room_changes", get_public_room_changes_txn
+ )
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -221,16 +357,6 @@ class RoomStore(SQLBaseStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={
- "is_public": True,
- },
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
def get_room_count(self):
"""Retrieve a list of all rooms
"""
@@ -257,8 +383,8 @@ class RoomStore(SQLBaseStore):
},
)
- self._store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"],
)
def _store_room_name_txn(self, txn, event):
@@ -273,14 +399,14 @@ class RoomStore(SQLBaseStore):
}
)
- self._store_event_search_txn(
- txn, event, "content.name", event.content["name"]
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"],
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
- self._store_event_search_txn(
- txn, event, "content.body", event.content["body"]
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"],
)
def _store_history_visibility_txn(self, txn, event):
@@ -302,31 +428,6 @@ class RoomStore(SQLBaseStore):
event.content[key]
))
- def _store_event_search_txn(self, txn, event, key, value):
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
- txn.execute(
- sql,
- (
- event.event_id, event.room_id, key, value,
- event.internal_metadata.stream_ordering,
- event.origin_server_ts,
- )
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- txn.execute(sql, (event.event_id, event.room_id, key, value,))
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -347,129 +448,180 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ def get_all_new_public_rooms(txn):
+ sql = ("""
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """)
+
+ txn.execute(sql, (prev_id, current_id, limit,))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
- """
return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id, network_tuple=network_tuple
+ "get_all_new_public_rooms", get_all_new_public_rooms
)
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
- network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
+ Args:
+ user_id (str)
- sql = ("""
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """)
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+ of RatelimitOverride are None or 0 then ratelimiting has been
+ disabled for that user entirely.
+ """
+ row = yield self._simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
- )
- else:
- txn.execute(
- sql % ("AND appservice_id IS NULL",),
- (stream_id,)
- )
- return dict(txn.fetchall())
+ if row:
+ defer.returnValue(RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ ))
else:
- # We want to get from all lists, so we need to aggregate the results
+ defer.returnValue(None)
- logger.info("Executing full list")
-
- sql = ("""
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """)
-
- txn.execute(
- sql,
- (stream_id,)
- )
-
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn.fetchall():
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ yield self._simple_insert(
+ table="blocked_rooms",
+ values={
+ "room_id": room_id,
+ "user_id": user_id,
+ },
+ desc="block_room",
+ )
+ yield self.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked, (room_id,),
+ )
- def get_public_room_changes(self, prev_stream_id, new_stream_id,
- network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
+ Args:
+ room_id (str)
- now_rooms_visible = set(
- rm for rm, vis in now_rooms_dict.items() if vis
- )
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
+ Returns:
+ A tuple of two lists of mxc:// URI strings: the local media and
+ the remote media referenced in the room.
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
+ def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ """For a room loops through all events with media and quarantines
+ the associated media
+ """
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
+
+ # Now update all the tables to set the quarantined_by flag
+
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
+ )
)
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
- return newly_visible, newly_unpublished
+ return total_media_quarantined
return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
)
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = ("""
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """)
-
- txn.execute(sql, (prev_id, current_id, limit,))
- return txn.fetchall()
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
- if prev_id == current_id:
- return defer.succeed([])
+ Args:
+ txn (cursor)
+ room_id (str)
- return self.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
+ Returns:
+ A tuple of (local media IDs, remote media), where the remote media
+ is a list of (hostname, media ID) tuples.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, json FROM events
+ JOIN event_json USING (room_id, event_id)
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ event_json = json.loads(content_json)
+ content = event_json["content"]
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hs.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
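For reference, the MXC parsing that `_get_media_mxcs_in_room_txn` performs on event content can be sketched in isolation: pull the hostname and media ID out of each `mxc://` URI and split them into local and remote media. The hostname and URIs below are invented.

import re

MXC_RE = re.compile("^mxc://([^/]+)/([^/#?]+)")
LOCAL_HOSTNAME = "example.org"          # stand-in for self.hs.hostname

def split_media(urls):
    local_media_ids, remote_media = [], []
    for url in urls:
        match = MXC_RE.match(url)
        if not match:
            continue
        hostname, media_id = match.group(1), match.group(2)
        if hostname == LOCAL_HOSTNAME:
            local_media_ids.append(media_id)
        else:
            remote_media.append((hostname, media_id))
    return local_media_ids, remote_media

local, remote = split_media([
    "mxc://example.org/abc123",
    "mxc://other.server/def456?width=100",
    "https://not-an-mxc.example/x",
])
assert local == ["abc123"]
assert remote == [("other.server", "def456")]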
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 545d3d3a99..6a861943a2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,14 +18,17 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.util.async import Linearizer
+from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.stringutils import to_ascii
from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -34,112 +38,47 @@ RoomsForUser = namedtuple(
("room_id", "sender", "membership", "event_id", "stream_ordering")
)
+GetRoomsForUserWithStreamOrdering = namedtuple(
+ "_GetRoomsForUserWithStreamOrdering",
+ ("room_id", "stream_ordering",)
+)
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-
-
-class RoomMemberStore(SQLBaseStore):
- def __init__(self, hs):
- super(RoomMemberStore, self).__init__(hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ]
- )
+# We store this using a namedtuple so that we save about 3x space over using a
+# dict.
+ProfileInfo = namedtuple(
+ "ProfileInfo", ("avatar_url", "display_name")
+)
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key, event.internal_metadata.stream_ordering
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- }
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
- txn.execute(sql, (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ))
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
+class RoomMemberWorkerStore(EventsWorkerStore):
+ @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+ def get_hosts_in_room(self, room_id, cache_context):
+ """Returns the set of all hosts currently in the room
+ """
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate,
)
+ hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+ defer.returnValue(hosts)
- def f(txn, stream_ordering):
- txn.execute(sql, (
- stream_ordering,
- True,
- room_id,
- user_id,
- ))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
- @cached(max_entries=500000, iterable=True)
+ @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
-
- rows = self._get_members_rows_txn(
- txn,
- room_id=room_id,
- membership=Membership.JOIN,
+ sql = (
+ "SELECT m.user_id FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
+ " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
- return [r["user_id"] for r in rows]
+ txn.execute(sql, (room_id, Membership.JOIN,))
+ return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached()
@@ -246,57 +185,382 @@ class RoomMemberStore(SQLBaseStore):
return results
- def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
- where_clause = "c.room_id = ?"
- where_values = [room_id]
-
- if membership:
- where_clause += " AND m.membership = ?"
- where_values.append(membership)
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ def get_rooms_for_user_with_stream_ordering(self, user_id):
+ """Returns a set of room_ids the user is currently joined to
- if user_id:
- where_clause += " AND m.user_id = ?"
- where_values.append(user_id)
+ Args:
+ user_id (str)
- sql = (
- "SELECT m.* FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND %(where)s"
- ) % {
- "where": where_clause,
- }
-
- txn.execute(sql, where_values)
- rows = self.cursor_to_dict(txn)
-
- return rows
-
- @cached(max_entries=500000, iterable=True)
- def get_rooms_for_user(self, user_id):
- return self.get_rooms_for_user_where_membership_is(
+ Returns:
+ Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+ the rooms the user is in currently, along with the stream ordering
+ of the most recent join for that user and room.
+ """
+ rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
+ defer.returnValue(frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
+ ))
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user_id, on_invalidate=None):
+ """Returns a set of room_ids the user is currently joined to
+ """
+ rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ user_id, on_invalidate=on_invalidate,
+ )
+ defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
- rooms = yield self.get_rooms_for_user(
+ room_ids = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
- for room in rooms:
+ for room_id in room_ids:
user_ids = yield self.get_users_in_room(
- room.room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
defer.returnValue(user_who_share_room)
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ event.room_id, state_group, context.current_state_ids,
+ event=event,
+ context=context,
+ )
+
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry,
+ )
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+ max_entries=100000)
+ def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+ cache_context, event=None, context=None):
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two current_states
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ users_in_room = {}
+ member_event_ids = [
+ e_id
+ for key, e_id in current_state_ids.iteritems()
+ if key[0] == EventTypes.Member
+ ]
+
+ if context is not None:
+ # If we have a context with a delta from a previous state group,
+ # check if we also have the result from the previous group in cache.
+ # If we do then we can reuse that result and simply update it with
+ # any membership changes in `delta_ids`
+ if context.prev_group and context.delta_ids:
+ prev_res = self._get_joined_users_from_context.cache.get(
+ (room_id, context.prev_group), None
+ )
+ if prev_res and isinstance(prev_res, dict):
+ users_in_room = dict(prev_res)
+ member_event_ids = [
+ e_id
+ for key, e_id in context.delta_ids.iteritems()
+ if key[0] == EventTypes.Member
+ ]
+ for etype, state_key in context.delta_ids:
+ users_in_room.pop(state_key, None)
+
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
+
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_cache(
+ member_event_ids,
+ allow_rejected=False,
+ update_metrics=False,
+ )
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+ display_name=to_ascii(
+ ev_entry.event.content.get("displayname", None)
+ ),
+ avatar_url=to_ascii(
+ ev_entry.event.content.get("avatar_url", None)
+ ),
+ )
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ rows = yield self._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=missing_member_event_ids,
+ retcols=('user_id', 'display_name', 'avatar_url',),
+ keyvalues={
+ "membership": Membership.JOIN,
+ },
+ batch_size=500,
+ desc="_get_joined_users_from_context",
+ )
+
+ users_in_room.update({
+ to_ascii(row["user_id"]): ProfileInfo(
+ avatar_url=to_ascii(row["avatar_url"]),
+ display_name=to_ascii(row["display_name"]),
+ )
+ for row in rows
+ })
+
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+ display_name=to_ascii(event.content.get("displayname", None)),
+ avatar_url=to_ascii(event.content.get("avatar_url", None)),
+ )
+
+ defer.returnValue(users_in_room)
+
+ @cachedInlineCallbacks(max_entries=10000)
+ def is_host_joined(self, room_id, host):
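+ """Check whether `host` currently has at least one user joined to the
+ room, according to the current state.
+
+ Returns:
+ Deferred[bool]
+ """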
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT state_key FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE membership = 'join'
+ AND type = 'm.room.member'
+ AND c.room_id = ?
+ AND state_key LIKE ?
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ defer.returnValue(False)
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ defer.returnValue(True)
+
+ @cachedInlineCallbacks()
+ def was_host_joined(self, room_id, host):
+ """Check whether the server is or ever was in the room.
+
+ Args:
+ room_id (str)
+ host (str)
+
+ Returns:
+ Deferred: Resolves to True if the host is/was in the room, otherwise
+ False.
+ """
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT user_id FROM room_memberships
+ WHERE room_id = ?
+ AND user_id LIKE ?
+ AND membership = 'join'
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ defer.returnValue(False)
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ defer.returnValue(True)
+
+ def get_joined_hosts(self, room_id, state_entry):
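+ """Get the set of server names with at least one user joined to the
+ room described by `state_entry`.
+
+ Returns:
+ Deferred[frozenset[str]]
+ """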
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry,
+ )
+
+ @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
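+ """Uncached implementation of get_joined_hosts; keyed on (room_id,
+ state_group) so that results are cached per state group.
+ """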
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two
+ # current_states with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ cache = self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
+
+ defer.returnValue(joined_hosts)
+
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
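+ """Get (or create) the _JoinedHostsCache for the given room."""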
+ return _JoinedHostsCache(self, room_id)
+
+ @cached()
+ def who_forgot_in_room(self, room_id):
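+ """Get the users who have forgotten the given room.
+
+ Returns:
+ Deferred[list[dict]]: rows with `user_id` and `event_id` keys for
+ memberships marked as forgotten.
+ """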
+ return self._simple_select_list(
+ table="room_memberships",
+ retcols=("user_id", "event_id"),
+ keyvalues={
+ "room_id": room_id,
+ "forgotten": 1,
+ },
+ desc="who_forgot"
+ )
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a batch of room membership events in the database.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ]
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+ # i.e., it's something that has just happened.
+ # The only current event that can also be an outlier is if it's an
+ # invite that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
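+ """Mark the invite for `user_id` in `room_id` as locally rejected,
+ recording the rejection in the local_invites table.
+ """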
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -368,124 +632,6 @@ class RoomMemberStore(SQLBaseStore):
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
-
- def get_joined_users_from_context(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_users_from_context(
- event.room_id, state_group, context.current_state_ids, event=event,
- )
-
- def get_joined_users_from_state(self, room_id, state_group, state_ids):
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_users_from_context(
- room_id, state_group, state_ids,
- )
-
- @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
- max_entries=100000)
- def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
- cache_context, event=None):
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- member_event_ids = [
- e_id
- for key, e_id in current_state_ids.iteritems()
- if key[0] == EventTypes.Member
- ]
-
- rows = yield self._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids,
- retcols=['user_id', 'display_name', 'avatar_url'],
- keyvalues={
- "membership": Membership.JOIN,
- },
- batch_size=500,
- desc="_get_joined_users_from_context",
- )
-
- users_in_room = {
- row["user_id"]: {
- "display_name": row["display_name"],
- "avatar_url": row["avatar_url"],
- }
- for row in rows
- }
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[event.state_key] = {
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
-
- defer.returnValue(users_in_room)
-
- def is_host_joined(self, room_id, host, state_group, state_ids):
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._is_host_joined(
- room_id, host, state_group, state_ids
- )
-
- @cachedInlineCallbacks(num_args=3)
- def _is_host_joined(self, room_id, host, state_group, current_state_ids):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- for (etype, state_key), event_id in current_state_ids.items():
- if etype == EventTypes.Member:
- try:
- if get_domain_from_id(state_key) != host:
- continue
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
-
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- defer.returnValue(True)
-
- defer.returnValue(False)
-
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -499,8 +645,9 @@ class RoomMemberStore(SQLBaseStore):
def add_membership_profile_txn(txn):
sql = ("""
- SELECT stream_ordering, event_id, events.room_id, content
+ SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
+ INNER JOIN event_json USING (event_id)
INNER JOIN room_memberships USING (event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND type = 'm.room.member'
@@ -521,8 +668,9 @@ class RoomMemberStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- content = json.loads(row["content"])
- except:
+ event_json = json.loads(row["json"])
+ content = event_json['content']
+ except Exception:
continue
display_name = content.get("displayname", None)
@@ -560,3 +708,71 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
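+
+ Returns:
+ Deferred[frozenset[str]]: the server names of the joined hosts.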
+ """
+ if state_entry.state_group == self.state_group:
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ with (yield self.linearizer.queue(())):
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
+ for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry,
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ def __len__(self):
+ return self._len
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
+import simplejson as json
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..e7351c3ae6 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -17,7 +17,7 @@ import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..6df57b5206 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -16,7 +16,7 @@ import logging
from synapse.storage.prepare_database import get_statements
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 5b7d8d1ab5..85bd1a2006 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -14,6 +14,8 @@
import logging
from synapse.config.appservice import load_appservices
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
- except:
+ except Exception:
# Maybe we already added the column? Hope so...
pass
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+ user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..fe6b7d196d 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
import logging
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..1e002f9db2 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -15,7 +15,7 @@
from synapse.storage.prepare_database import get_statements
import logging
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 784f3b348f..20ad8bd5a6 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref;
-- and is used incredibly rarely.
DROP INDEX IF EXISTS events_order_topo_stream_room;
+-- an equivalent index to this actually gets re-created in delta 41, because it
+-- turned out that deleting it wasn't a great plan :/. In any case, let's
+-- delete it here, and delta 41 will create a new one with an added UNIQUE
+-- constraint
DROP INDEX IF EXISTS event_search_ev_idx;
"""
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
* limitations under the License.
*/
- INSERT into background_updates (update_name, progress_json)
- VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+-- VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql
new file mode 100644
index 0000000000..3918f0b794
--- /dev/null
+++ b/synapse/storage/schema/delta/40/event_push_summary.sql
@@ -0,0 +1,37 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Aggregate of old notification counts that have been deleted out of the
+-- main event_push_actions table. This count does not include those that were
+-- highlights, as they remain in the event_push_actions table.
+CREATE TABLE event_push_summary (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notif_count BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id);
+
+
+-- The stream ordering up to which we have aggregated the event_push_actions
+-- table into event_push_summary
+CREATE TABLE event_push_summary_stream_ordering (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_ordering BIGINT NOT NULL,
+ CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql
new file mode 100644
index 0000000000..054a223f14
--- /dev/null
+++ b/synapse/storage/schema/delta/40/pushers.sql
@@ -0,0 +1,39 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS pushers2 (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag TEXT NOT NULL,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ lang TEXT,
+ data TEXT,
+ last_stream_ordering INTEGER,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey, user_name)
+);
+
+INSERT INTO pushers2 SELECT * FROM pushers;
+
+DROP TABLE pushers;
+
+ALTER TABLE pushers2 RENAME TO pushers;
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
index 34db0cf12b..b7bee8b692 100644
--- a/synapse/storage/schema/delta/23/refresh_tokens.sql
+++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,9 +13,5 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS refresh_tokens(
- id INTEGER PRIMARY KEY,
- token TEXT NOT NULL,
- user_id TEXT NOT NULL,
- UNIQUE (token)
-);
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('device_lists_stream_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql
new file mode 100644
index 0000000000..62f0b9892b
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_outbound_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id);
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
new file mode 100644
index 0000000000..5d9cfecf36
--- /dev/null
+++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_search_event_id_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql
new file mode 100644
index 0000000000..a194bf0238
--- /dev/null
+++ b/synapse/storage/schema/delta/41/ratelimit.sql
@@ -0,0 +1,22 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE ratelimit_override (
+ user_id TEXT NOT NULL,
+ messages_per_second BIGINT,
+ burst_count BIGINT
+);
+
+CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql
new file mode 100644
index 0000000000..d28851aff8
--- /dev/null
+++ b/synapse/storage/schema/delta/42/current_state_delta.sql
@@ -0,0 +1,26 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE current_state_delta_stream (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT, -- Is null if the key was removed
+ prev_event_id TEXT -- Is null if the key was added
+);
+
+CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql
new file mode 100644
index 0000000000..9ab8c14fa3
--- /dev/null
+++ b/synapse/storage/schema/delta/42/device_list_last_id.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Table of last stream_id that we sent to destination for user_id. This is
+-- used to fill out the `prev_id` fields of outbound device list updates.
+CREATE TABLE device_lists_outbound_last_success (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_lists_outbound_last_success
+ SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_pokes
+ WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
+ GROUP BY destination, user_id;
+
+CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
+ destination, user_id, stream_id
+);
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql
index bb225dafbf..b8821ac759 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
+++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,4 +14,4 @@
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
- ('refresh_tokens_device_index', '{}');
+ ('event_auth_state_only', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
new file mode 100644
index 0000000000..ea6a18196d
--- /dev/null
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+
+logger = logging.getLogger(__name__)
+
+
+BOTH_TABLES = """
+CREATE TABLE user_directory_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_directory (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
+
+CREATE TABLE users_in_pubic_room (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL -- A room_id that we know is public
+);
+
+CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
+CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
+"""
+
+
+POSTGRES_TABLE = """
+CREATE TABLE user_directory_search (
+ user_id TEXT NOT NULL,
+ vector tsvector
+);
+
+CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
+CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
+"""
+
+
+SQLITE_TABLE = """
+CREATE VIRTUAL TABLE user_directory_search
+ USING fts4 ( user_id, value );
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(BOTH_TABLES.splitlines()):
+ cur.execute(statement)
+
+ if isinstance(database_engine, PostgresEngine):
+ for statement in get_statements(POSTGRES_TABLE.splitlines()):
+ cur.execute(statement)
+ elif isinstance(database_engine, Sqlite3Engine):
+ for statement in get_statements(SQLITE_TABLE.splitlines()):
+ cur.execute(statement)
+ else:
+ raise Exception("Unrecognized database engine")
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql
new file mode 100644
index 0000000000..0e3cd143ff
--- /dev/null
+++ b/synapse/storage/schema/delta/43/blocked_rooms.sql
@@ -0,0 +1,21 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE blocked_rooms (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL -- Admin who blocked the room
+);
+
+CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql
new file mode 100644
index 0000000000..630907ec4f
--- /dev/null
+++ b/synapse/storage/schema/delta/43/quarantine_media.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
+ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/43/url_cache.sql
index 290bd6da86..45ebe020da 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device.sql
+++ b/synapse/storage/schema/delta/43/url_cache.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
* limitations under the License.
*/
-ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
new file mode 100644
index 0000000000..ee7062abe4
--- /dev/null
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Table keeping track of who shares a room with whom. We only keep track
+-- of this for local users, so `user_id` is always a local user (but we do
+-- keep track of which remote users share a room with them).
+CREATE TABLE users_who_share_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room
+);
+
+
+CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
+CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
+CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
+
+
+-- Make sure that we populate the table initially
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..b12f9b2ebf
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,41 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
+-- removed and replaced with 46/local_media_repository_url_idx.sql.
+--
+-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
+-- indices on expressions until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+ url TEXT,
+ response_code INTEGER,
+ etag TEXT,
+ expires_ts BIGINT,
+ og TEXT,
+ media_id TEXT,
+ download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+ SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql
new file mode 100644
index 0000000000..b2333848a0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/group_server.sql
@@ -0,0 +1,167 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups (
+ group_id TEXT NOT NULL,
+ name TEXT, -- the display name of the group
+ avatar_url TEXT,
+ short_description TEXT,
+ long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the user's membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ room_order BIGINT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone
+ UNIQUE (group_id, category_id, room_id, room_order),
+ CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ cat_order BIGINT NOT NULL,
+ UNIQUE (group_id, category_id, cat_order),
+ CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone
+ UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ user_order BIGINT NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the user should be shown to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ role_order BIGINT NOT NULL,
+ UNIQUE (group_id, role_id, role_order),
+ CHECK (role_order > 0)
+);
+
+
+-- The roles in a group
+CREATE TABLE group_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the role should be shown to everyone
+ UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL,
+ attestation_json TEXT NOT NULL
+);
+
+CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
+CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
+CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+-- The group membership for the HS's users
+CREATE TABLE local_group_membership (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+ membership TEXT NOT NULL,
+ is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
+ content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+ stream_id BIGINT NOT NULL,
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ content TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql
new file mode 100644
index 0000000000..e5ddc84df0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/profile_cache.sql
@@ -0,0 +1,28 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A subset of remote users whose profiles we have cached.
+-- Whether a user is in this table or not is defined by the storage function
+-- `is_subscribed_remote_profile_for_user`
+CREATE TABLE remote_profile_cache (
+ user_id TEXT NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ last_check BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
+CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
new file mode 100644
index 0000000000..68c48a89a9
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* we no longer use (or create) the refresh_tokens table */
+DROP TABLE IF EXISTS refresh_tokens;
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
new file mode 100644
index 0000000000..bb307889c1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop the unique constraint on deleted_pushers so that we can just insert
+-- into it rather than upserting.
+
+CREATE TABLE deleted_pushers2 (
+ stream_id BIGINT NOT NULL,
+ app_id TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ user_id TEXT NOT NULL
+);
+
+INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
+ SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
+
+DROP TABLE deleted_pushers;
+ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
+
+-- create the index after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
+
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql
new file mode 100644
index 0000000000..097679bc9a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/group_server.sql
@@ -0,0 +1,32 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups_new (
+ group_id TEXT NOT NULL,
+ name TEXT, -- the display name of the group
+ avatar_url TEXT,
+ short_description TEXT,
+ long_description TEXT,
+ is_public BOOL NOT NULL -- whether non-members can access group APIs
+);
+
+-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
+INSERT INTO groups_new
+ SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
+
+DROP TABLE groups;
+ALTER TABLE groups_new RENAME TO groups;
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
new file mode 100644
index 0000000000..bbfc7f5d1a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will recreate the
+-- local_media_repository_url_idx index.
+--
+-- We do this as a bg update not because it is a particularly onerous
+-- operation, but because we'd like it to be a partial index if possible, and
+-- the background_index_update code will understand whether we are on
+-- postgres or sqlite and behave accordingly.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_url_idx', '{}');
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
new file mode 100644
index 0000000000..cb0d5a2576
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- change the user_directory table to also cover global local user profiles
+-- rather than just profiles within specific rooms.
+
+CREATE TABLE user_directory2 (
+ user_id TEXT NOT NULL,
+ room_id TEXT,
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
+ SELECT user_id, room_id, display_name, avatar_url from user_directory;
+
+DROP TABLE user_directory;
+ALTER TABLE user_directory2 RENAME TO user_directory;
+
+-- create indexes after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql
new file mode 100644
index 0000000000..d9505f8da1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_typos.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this is just embarrassing :|
+ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
+
+-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
+-- so is hopefully not going to block anyone else for that long...
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
+CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
+DROP INDEX users_in_pubic_room_room_idx;
+DROP INDEX users_in_pubic_room_user_idx;
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
new file mode 100644
index 0000000000..31d7a817eb
--- /dev/null
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_search_postgres_gin', '{}');
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ notif SMALLINT NOT NULL,
+ highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
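The staging table above exists so that push actions can be written as soon as they are calculated, before the event itself has an events row; once the event is persisted, the staged rows are moved into event_push_actions inside the same transaction. A rough sqlite3 sketch of that two-phase write (the event_push_actions columns here are a simplified subset of the real table, and the move query is illustrative rather than Synapse's actual one):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE event_push_actions_staging "
        "(event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL,"
        " notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL)"
    )
    cur.execute(
        "CREATE TABLE event_push_actions "
        "(room_id TEXT, event_id TEXT, user_id TEXT, actions TEXT,"
        " notif SMALLINT, highlight SMALLINT)"
    )

    # Phase 1: the push rule evaluator stages the actions for an in-flight event.
    cur.execute(
        "INSERT INTO event_push_actions_staging VALUES (?, ?, ?, ?, ?)",
        ("$ev1:hs", "@alice:hs", '["notify"]', 1, 0),
    )
    conn.commit()

    # Phase 2: when the event is persisted, move the staged rows over and
    # delete them, all inside the persistence transaction.
    cur.execute(
        "INSERT INTO event_push_actions"
        " (room_id, event_id, user_id, actions, notif, highlight)"
        " SELECT ?, event_id, user_id, actions, notif, highlight"
        " FROM event_push_actions_staging WHERE event_id = ?",
        ("!room:hs", "$ev1:hs"),
    )
    cur.execute(
        "DELETE FROM event_push_actions_staging WHERE event_id = ?", ("$ev1:hs",)
    )
    conn.commit()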
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # if we already have some state groups, we want to start making new
+ # ones with a higher id.
+ cur.execute("SELECT max(id) FROM state_groups")
+ row = cur.fetchone()
+
+ if row[0] is None:
+ start_val = 1
+ else:
+ start_val = row[0] + 1
+
+ cur.execute(
+ "CREATE SEQUENCE state_group_id_seq START WITH %s",
+ (start_val, ),
+ )
+
+
+def run_upgrade(*args, **kwargs):
+ pass
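With the sequence in place, Postgres can hand out state group ids via nextval() instead of an in-process id generator, which matters once more than one process may create state groups. The diff later calls database_engine.get_next_state_group_id(txn) for this; the sketch below is only an illustration of what such a method might look like, not Synapse's engine code, and the non-Postgres branch is a naive stand-in:

    def get_next_state_group_id(txn, is_postgres):
        """Allocate the next state group id inside an open transaction."""
        if is_postgres:
            # Sequences are non-transactional, so concurrent writers can never
            # be handed the same id.
            txn.execute("SELECT nextval('state_group_id_seq')")
            return txn.fetchone()[0]
        else:
            # Naive fallback for engines without sequences; only safe if state
            # group creation is serialised by the caller.
            txn.execute("SELECT COALESCE(max(id), 0) + 1 FROM state_groups")
            return txn.fetchone()[0]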
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
new file mode 100644
index 0000000000..9248b0b24a
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_index', '{}');
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names..
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+ # remove duplicates from group_users & group_invites tables
+ cur.execute("""
+ DELETE FROM group_users WHERE %s NOT IN (
+ SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+ cur.execute("""
+ DELETE FROM group_invites WHERE %s NOT IN (
+ SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+
+ for statement in get_statements(FIX_INDEXES.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
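The dedup step relies on each engine exposing a physical row identifier (ctid on Postgres, rowid on SQLite) so that all but the first row per (group_id, user_id) can be deleted before the unique index is created; creating the index first would simply fail on the duplicates. A runnable sqlite3 sketch of the same dedup-then-index dance:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE group_users (group_id TEXT, user_id TEXT)")
    cur.executemany(
        "INSERT INTO group_users VALUES (?, ?)",
        [("+g:hs", "@alice:hs"), ("+g:hs", "@alice:hs"), ("+g:hs", "@bob:hs")],
    )

    # Keep only the lowest rowid for each (group_id, user_id) pair...
    cur.execute("""
        DELETE FROM group_users WHERE rowid NOT IN (
            SELECT min(rowid) FROM group_users GROUP BY group_id, user_id
        )
    """)
    # ...so that the unique index can now be built without tripping over duplicates.
    cur.execute("CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id)")
    print(cur.execute("SELECT count(*) FROM group_users").fetchone()[0])  # 2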
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql
new file mode 100644
index 0000000000..ce26eaf0c9
--- /dev/null
+++ b/synapse/storage/schema/delta/48/groups_joinable.sql
@@ -0,0 +1,22 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This isn't a real ENUM because sqlite doesn't support it
+ * and we use a plain TEXT column. The default of 'invite' ensures
+ * that existing rows, and rows inserted without an explicit value,
+ * are given the correct default policy.
+ */
+ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
file TEXT NOT NULL,
UNIQUE(version, file)
);
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+ module_name TEXT NOT NULL,
+ file TEXT NOT NULL,
+ UNIQUE(module_name, file)
+);
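applied_module_schemas mirrors applied_schema_deltas but is keyed by module name, so schema files shipped by dynamically loaded modules run exactly once per database. A hypothetical helper sketching how such bookkeeping might be used (this is not Synapse's prepare_database code; run_module_schema_once and its callers are assumptions):

    def run_module_schema_once(cur, module_name, file_name, sql):
        """Execute a module-provided schema file unless it was already applied."""
        cur.execute(
            "SELECT 1 FROM applied_module_schemas WHERE module_name = ? AND file = ?",
            (module_name, file_name),
        )
        if cur.fetchone():
            return  # applied on a previous startup; nothing to do
        cur.executescript(sql)  # sqlite3-specific; other engines split statements
        cur.execute(
            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?, ?)",
            (module_name, file_name),
        )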
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 8f2b3c4435..6ba3e59889 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,28 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import namedtuple
+import logging
+import re
+import simplejson as json
+
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import logging
-import re
-import ujson as json
-
logger = logging.getLogger(__name__)
+SearchEntry = namedtuple('SearchEntry', [
+ 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+ 'origin_server_ts',
+])
+
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+ EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, hs):
- super(SearchStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(SearchStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@@ -42,23 +49,35 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order
)
- self.register_background_update_handler(
+
+ # we used to have a background update to turn the GIN index into a
+ # GIST one; we no longer do that (obviously) because we actually want
+ # a GIN index. However, it's possible that some people might still have
+ # the background update queued, so we register a handler to clear the
+ # background update.
+ self.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
- self._background_reindex_gist_search
+ )
+
+ self.register_background_update_handler(
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+ self._background_reindex_gin_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
+ # we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+ "SELECT stream_ordering, event_id, room_id, type, json, "
+ " origin_server_ts FROM events"
+ " JOIN event_json USING (room_id, event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -67,6 +86,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+ # we could stream straight from the results into
+ # store_search_entries_txn with a generator function, but that
+ # would mean having two cursors open on the database at once.
+ # Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
if not rows:
return 0
@@ -79,9 +102,12 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
+ stream_ordering = row["stream_ordering"]
+ origin_server_ts = row["origin_server_ts"]
try:
- content = json.loads(row["content"])
- except:
+ event_json = json.loads(row["json"])
+ content = event_json["content"]
+ except Exception:
continue
if etype == "m.room.message":
@@ -93,6 +119,8 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name":
key = "content.name"
value = content["name"]
+ else:
+ raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -103,25 +131,16 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
- event_search_rows.append((event_id, room_id, key, value))
+ event_search_rows.append(SearchEntry(
+ key=key,
+ value=value,
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ origin_server_ts=origin_server_ts,
+ ))
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, vector)"
- " VALUES (?,?,?,to_tsvector('english', ?))"
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
- for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
- clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -145,25 +164,48 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result)
@defer.inlineCallbacks
- def _background_reindex_gist_search(self, progress, batch_size):
+ def _background_reindex_gin_search(self, progress, batch_size):
+ """This handles old synapses which used GIST indexes, if any;
+ converting them back to be GIN as per the actual schema.
+ """
+
def create_index(conn):
conn.rollback()
- conn.set_session(autocommit=True)
- c = conn.cursor()
- c.execute(
- "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
- " ON event_search USING GIST (vector)"
- )
+ # we have to set autocommit, because postgres refuses to
+ # CREATE INDEX CONCURRENTLY without it.
+ conn.set_session(autocommit=True)
- c.execute("DROP INDEX event_search_fts_idx")
+ try:
+ c = conn.cursor()
- conn.set_session(autocommit=False)
+ # if we skipped the conversion to GIST, we may already/still
+ # have an event_search_fts_idx; unfortunately postgres 9.4
+ # doesn't support CREATE INDEX IF NOT EXISTS, so we just catch
+ # the exception and ignore it.
+ import psycopg2
+ try:
+ c.execute(
+ "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+ " ON event_search USING GIN (vector)"
+ )
+ except psycopg2.ProgrammingError as e:
+ logger.warn(
+ "Ignoring error %r when trying to switch from GIST to GIN",
+ e
+ )
+
+ # we should now be able to delete the GIST index.
+ c.execute(
+ "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+ )
+ finally:
+ conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index)
- yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+ yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
defer.returnValue(1)
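The autocommit juggling in create_index is forced by Postgres: CREATE INDEX CONCURRENTLY refuses to run inside a transaction block, so the connection has to be flipped into autocommit for the duration and restored afterwards (hence the try/finally). A condensed standalone psycopg2 sketch of the same sequence, with a placeholder connection string:

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
    conn.rollback()                    # make sure no transaction is open
    conn.set_session(autocommit=True)  # CONCURRENTLY cannot run in a txn block
    try:
        cur = conn.cursor()
        try:
            cur.execute(
                "CREATE INDEX CONCURRENTLY event_search_fts_idx"
                " ON event_search USING GIN (vector)"
            )
        except psycopg2.ProgrammingError:
            pass  # the index already existed; nothing to do
        cur.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
    finally:
        conn.set_session(autocommit=False)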
@defer.inlineCallbacks
@@ -242,6 +284,85 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows)
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),),
+ )
+
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ entry.stream_ordering, entry.origin_server_ts,
+ ) for entry in entries)
+
+ # inserts to a GIN index are normally batched up into a pending
+ # list, and then all committed together once the list gets to a
+ # certain size. The trouble with that is that postgres (pre-9.5)
+ # uses work_mem to determine the length of the list, and work_mem
+ # is typically very large.
+ #
+ # We therefore reduce work_mem while we do the insert.
+ #
+ # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+ # so doesn't suffer the same problem, but changing work_mem will
+ # be harmless)
+ #
+ # Note that we don't need to worry about restoring it on
+ # exception, because exceptions will cause the transaction to be
+ # rolled back, including the effects of the SET command.
+ #
+ # Also: we use SET rather than SET LOCAL because there's lots of
+ # other stuff going on in this transaction, which wants to have the
+ # normal work_mem setting.
+
+ txn.execute("SET work_mem='256kB'")
+ txn.executemany(sql, args)
+ txn.execute("RESET work_mem")
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ ) for entry in entries)
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
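The work_mem clamp above only matters on the Postgres path: before 9.5 the GIN pending list is bounded by work_mem, so inserting with a temporarily tiny value keeps the pending list, and the stall when it is eventually flushed, small. Boiled down, the pattern is just a pair of statements around the batch insert; in the sketch below txn stands for any DB-API cursor on a Postgres connection, and the SET is undone by the rollback if the insert fails:

    def insert_with_small_gin_pending_list(txn, sql, args):
        """Batch-insert into a GIN-indexed table with work_mem clamped down."""
        txn.execute("SET work_mem='256kB'")
        txn.executemany(sql, args)
        txn.execute("RESET work_mem")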
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -407,7 +528,7 @@ class SearchStore(BackgroundUpdateStore):
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
- except:
+ except Exception:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927d7..9e6eaaa532 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
-class SignatureStore(SQLBaseStore):
- """Persistence for event signatures and hashes"""
-
+class SignatureWorkerStore(SQLBaseStore):
@cached()
def get_event_reference_hash(self, event_id):
- return self._get_event_reference_hashes_txn(event_id)
+ # This is a dummy function to allow get_event_reference_hashes
+ # to use its cache
+ raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
@@ -72,7 +72,11 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
- return {k: v for k, v in txn.fetchall()}
+ return {k: v for k, v in txn}
+
+
+class SignatureStore(SignatureWorkerStore):
+ """Persistence for event signatures and hashes"""
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 84482d8285..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,14 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches import intern_string
-from synapse.storage.engines import PostgresEngine
+from collections import namedtuple
+import logging
from twisted.internet import defer
-import logging
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.stringutils import to_ascii
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -28,45 +32,97 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class StateStore(SQLBaseStore):
- """ Keeps track of the state at a given event.
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the iterable flag when caching
+ """
+ __slots__ = []
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
+class StateGroupWorkerStore(SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- def __init__(self, hs):
- super(StateStore, self).__init__(hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
+ def __init__(self, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
+
+ @cached(max_entries=100000, iterable=True)
+ def get_current_state_ids(self, room_id):
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ deferred: dict of (type, state_key) -> event_id
+ """
+ def _get_current_state_ids_txn(txn):
+ txn.execute(
+ """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """,
+ (room_id,)
+ )
+
+ return {
+ (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+ }
+
+ return self.runInteraction(
+ "get_current_state_ids",
+ _get_current_state_ids_txn,
)
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
+
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+ def _get_state_group_delta_txn(txn):
+ prev_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self._simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcols=("type", "state_key", "event_id",)
+ )
+
+ return _GetStateGroupDelta(prev_group, {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in delta_ids
+ })
+ return self.runInteraction(
+ "get_state_group_delta",
+ _get_state_group_delta_txn,
)
@defer.inlineCallbacks
@@ -78,12 +134,26 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@defer.inlineCallbacks
+ def get_state_ids_for_group(self, state_group):
+ """Get the state IDs for the given state group
+
+ Args:
+ state_group (int)
+
+ Returns:
+ Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = yield self._get_state_for_groups((state_group,))
+
+ defer.returnValue(group_to_state[state_group])
+
+ @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
@@ -96,156 +166,21 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in group_to_ids.values()
- for ev_id in group_ids.values()
+ ev_id for group_ids in group_to_ids.itervalues()
+ for ev_id in group_ids.itervalues()
],
get_prev_content=False
)
defer.returnValue({
group: [
- state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ state_event_map[v] for v in event_id_map.itervalues()
+ if v in state_event_map
]
- for group, event_id_map in group_to_ids.items()
+ for group, event_id_map in group_to_ids.iteritems()
})
- def _have_persisted_state_group_txn(self, txn, state_group):
- txn.execute(
- "SELECT count(*) FROM state_groups WHERE id = ?",
- (state_group,)
- )
- row = txn.fetchone()
- return row and row[0]
-
- def _store_mult_state_groups_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- if context.current_state_ids is None:
- continue
-
- state_groups[event.event_id] = context.state_group
-
- if self._have_persisted_state_group_txn(txn, context.state_group):
- continue
-
- self._simple_insert_txn(
- txn,
- table="state_groups",
- values={
- "id": context.state_group,
- "room_id": event.room_id,
- "event_id": event.event_id,
- },
- )
-
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if context.prev_group:
- potential_hops = self._count_state_group_hops_txn(
- txn, context.prev_group
- )
- if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": context.state_group,
- "prev_state_group": context.prev_group,
- },
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": context.state_group,
- "room_id": event.room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in context.delta_ids.items()
- ],
- )
- else:
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": context.state_group,
- "room_id": event.room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in context.current_state_ids.items()
- ],
- )
-
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {
- "state_group": state_group_id,
- "event_id": event_id,
- }
- for event_id, state_group_id in state_groups.items()
- ],
- )
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = ("""
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """)
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
- @cached(num_args=2, max_entries=100000, iterable=True)
- def _get_state_group_from_group(self, group, types):
- raise NotImplementedError()
-
- @cachedList(cached_method_name="_get_state_group_from_group",
- list_name="groups", num_args=2, inlineCallbacks=True)
+ @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
@@ -305,6 +240,9 @@ class StateStore(SQLBaseStore):
(
"AND type = ? AND state_key = ?",
(etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
)
for etype, state_key in types
]
@@ -319,15 +257,24 @@ class StateStore(SQLBaseStore):
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
- rows = self.cursor_to_dict(txn)
- for row in rows:
- key = (row["type"], row["state_key"])
- results[group][key] = row["event_id"]
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
else:
+ where_args = []
+ where_clauses = []
+ wildcard_types = False
if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
+ for typ in types:
+ if typ[1] is None:
+ where_clauses.append("(type = ?)")
+ where_args.append(typ[0])
+ wildcard_types = True
+ else:
+ where_clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend([typ[0], typ[1]])
+ where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@@ -344,23 +291,30 @@ class StateStore(SQLBaseStore):
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
- args.extend(i for typ in types for i in typ)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? %s" % (where_clause,),
args
)
- rows = txn.fetchall()
- results[group].update({
- (typ, state_key): event_id
- for typ, state_key, event_id in rows
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
- })
+ )
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types):
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ types is not None and
+ not wildcard_types and
+ len(results[group]) == len(types)
+ ):
break
next_group = self._simple_select_one_onecol_txn(
@@ -393,21 +347,21 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
- [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in group_to_state[group].items()
+ for k, v in group_to_state[group].iteritems()
if v in state_event_map
}
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -430,12 +384,12 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -474,8 +428,8 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=10000)
- def _get_state_group_for_event(self, room_id, event_id):
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
@@ -517,20 +471,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
+
for typ, state_key in types:
+ key = (typ, state_key)
if state_key is None:
type_to_key[typ] = None
- missing_types.add((typ, state_key))
+ missing_types.add(key)
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict_ids:
- missing_types.add((typ, state_key))
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types.add(key)
sentinel = object()
@@ -544,10 +500,10 @@ class StateStore(SQLBaseStore):
return True
return False
- got_all = not (missing_types or types is None)
+ got_all = is_all or not missing_types
return {
- k: v for k, v in state_dict_ids.items()
+ k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1])
}, missing_types, got_all
@@ -561,7 +517,7 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all
@@ -578,7 +534,7 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict_ids
@@ -606,46 +562,247 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
- for group, group_state_dict in group_to_state_dict.items():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {
- (intern_string(etype), intern_string(state_key)): None
- for (etype, state_key) in types
- }
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
-
- state_dict.update({
- (intern_string(k[0]), intern_string(k[1])): v
- for k, v in group_state_dict.items()
- })
+ for group, group_state_dict in group_to_state_dict.iteritems():
+ state_dict = results[group]
+
+ state_dict.update(
+ ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
+ for k, v in group_state_dict.iteritems()
+ )
self._state_group_cache.update(
cache_seq_num,
key=group,
value=state_dict,
full=(types is None),
+ known_absent=types,
)
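The known_absent parameter replaces the old trick (visible in the deleted lines around this hunk) of caching explicit None values for every requested key and then filtering them back out on each read: the cache now remembers which keys were looked up and found missing, alongside the real entries. A toy illustration of the idea (this is not the real DictionaryCache API, just the shape of it):

    class TinyStateCache(object):
        """Caches present entries plus the set of keys known to be absent."""
        def __init__(self):
            self._entries = {}        # key -> event_id
            self._known_absent = set()

        def update(self, values, requested_keys):
            self._entries.update(values)
            # anything we asked the DB for but didn't get back is definitively absent
            self._known_absent.update(set(requested_keys) - set(values))

        def lookup(self, key):
            if key in self._entries:
                return True, self._entries[key]   # positive hit
            if key in self._known_absent:
                return True, None                 # negative hit: no DB query needed
            return False, None                    # miss: caller must hit the DB


    cache = TinyStateCache()
    cache.update(
        {("m.room.name", ""): "$name"},
        [("m.room.name", ""), ("m.room.topic", "")],
    )
    print(cache.lookup(("m.room.topic", "")))   # (True, None)  -- cached absence
    print(cache.lookup(("m.room.avatar", "")))  # (False, None) -- genuine miss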
- # Remove all the entries with None values. The None values were just
- # used for bookkeeping in the cache.
- for group, state_dict in results.items():
- results[group] = {
- key: event_id
- for key, event_id in state_dict.items()
- if event_id
- }
-
defer.returnValue(results)
- def get_next_state_group(self):
- return self._state_groups_id_gen.get_next()
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "id": state_group,
+ "room_id": room_id,
+ "event_id": event_id,
+ },
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+ # of deltas isn't too long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(
+ txn, prev_group
+ )
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self._simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in delta_ids.iteritems()
+ ],
+ )
+ else:
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in current_state_ids.iteritems()
+ ],
+ )
+
+ # Prefill the state group cache with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_state_ids),
+ full=True,
+ )
+
+ return state_group
+
+ return self.runInteraction("store_state_group", _store_state_group_txn)
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = ("""
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """)
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+
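_count_state_group_hops_txn is what keeps the delta chains bounded by MAX_STATE_DELTA_HOPS. On Postgres a single recursive CTE does the counting; the SQLite branch walks state_group_edges one row at a time because some distributions still ship an sqlite3 build without WITH RECURSIVE. A standalone sqlite3 sketch of the iterative walk against a toy edge table:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute(
        "CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT)"
    )
    cur.executemany(
        "INSERT INTO state_group_edges VALUES (?, ?)", [(4, 3), (3, 2), (2, 1)]
    )

    def count_hops(cur, state_group):
        count = 0
        next_group = state_group
        while next_group is not None:
            cur.execute(
                "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
                (next_group,),
            )
            row = cur.fetchone()
            next_group = row[0] if row else None
            if next_group is not None:
                count += 1
        return count

    print(count_hops(cur, 4))  # 3 hops back to the root snapshot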
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., if we get a message event
+ with only one parent), it inherits the state group from its parent.
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event within the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+ def __init__(self, db_conn, hs):
+ super(StateStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME,
+ self._background_index_state,
+ )
+ self.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+
+ def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {
+ "state_group": state_group_id,
+ "event_id": event_id,
+ }
+ for event_id, state_group_id in state_groups.iteritems()
+ ],
+ )
+
+ for event_id, state_group_id in state_groups.iteritems():
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
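The txn.call_after calls here (and in store_state_group above) defer cache work until the transaction has actually committed, so a rolled-back transaction can never leave a prefilled-but-wrong entry in a cache. A toy sketch of that pattern around a plain DB-API cursor (the wrapper below is an illustration, not Synapse's LoggingTransaction):

    import sqlite3

    class TxnWithCallbacks(object):
        """Wraps a cursor; queued callbacks only run once the txn has committed."""
        def __init__(self, cursor):
            self._cursor = cursor
            self._after_callbacks = []

        def execute(self, sql, args=()):
            return self._cursor.execute(sql, args)

        def call_after(self, fn, *args, **kwargs):
            self._after_callbacks.append((fn, args, kwargs))

        def run_after_commit(self):
            for fn, args, kwargs in self._after_callbacks:
                fn(*args, **kwargs)


    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_to_state_groups (event_id TEXT, state_group BIGINT)")
    cache = {}

    txn = TxnWithCallbacks(conn.cursor())
    txn.execute("INSERT INTO event_to_state_groups VALUES (?, ?)", ("$ev:hs", 42))
    txn.call_after(cache.__setitem__, "$ev:hs", 42)  # prefill only once the txn lands
    conn.commit()
    txn.run_after_commit()
    print(cache)  # {'$ev:hs': 42}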
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
@@ -727,7 +884,7 @@ class StateStore(SQLBaseStore):
# of keys
delta_state = {
- key: value for key, value in curr_state.items()
+ key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value
}
@@ -767,7 +924,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_state.items()
+ for key, state_id in delta_state.iteritems()
],
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 200d124632..f0784ba137 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,15 +35,20 @@ what sort order was used:
from twisted.internet import defer
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+
from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+import abc
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -143,81 +148,41 @@ def filter_to_clause(event_filter):
return " AND ".join(clauses), args
-class StreamStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
- # NB this lives here instead of appservice.py so we can reuse the
- # 'private' StreamToken class in this file.
- if limit:
- limit = max(limit, MAX_STREAM_SIZE)
- else:
- limit = MAX_STREAM_SIZE
-
- # From and to keys should be integers from ordering.
- from_id = RoomStreamToken.parse_stream_token(from_key)
- to_id = RoomStreamToken.parse_stream_token(to_key)
-
- if from_key == to_key:
- defer.returnValue(([], to_key))
- return
-
- # select all the events between from/to with a sensible limit
- sql = (
- "SELECT e.event_id, e.room_id, e.type, s.state_key, "
- "e.stream_ordering FROM events AS e "
- "LEFT JOIN state_events as s ON "
- "e.event_id = s.event_id "
- "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
- "ORDER BY stream_ordering ASC LIMIT %(limit)d "
- ) % {
- "limit": limit
- }
-
- def f(txn):
- # pull out all the events between the tokens
- txn.execute(sql, (from_id.stream, to_id.stream,))
- rows = self.cursor_to_dict(txn)
-
- # Logic:
- # - We want ALL events which match the AS room_id regex
- # - We want ALL events which match the rooms represented by the AS
- # room_alias regex
- # - We want ALL events for rooms that AS users have joined.
- # This is currently supported via get_app_service_rooms (which is
- # used for the Notifier listener rooms). We can't reasonably make a
- # SQL query for these room IDs, so we'll pull all the events between
- # from/to and filter in python.
- rooms_for_as = self._get_app_service_rooms_txn(txn, service)
- room_ids_for_as = [r.room_id for r in rooms_for_as]
-
- def app_service_interested(row):
- if row["room_id"] in room_ids_for_as:
- return True
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
+ which can be called in the initializer.
+ """
- if row["type"] == EventTypes.Member:
- if service.is_interested_in_user(row.get("state_key")):
- return True
- return False
+ __metaclass__ = abc.ABCMeta
- return [r for r in rows if app_service_interested(r)]
+ def __init__(self, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(db_conn, hs)
- rows = yield self.runInteraction("get_appservice_room_stream", f)
-
- ret = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
+ events_max = self.get_room_max_stream_ordering()
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._stream_order_on_start = self.get_room_max_stream_ordering()
- if rows:
- key = "s%d" % max(r["stream_ordering"] for r in rows)
- else:
- # Assume we didn't get anything because there was nothing to
- # get.
- key = to_key
+ @abc.abstractmethod
+ def get_room_max_stream_ordering(self):
+ raise NotImplementedError()
- defer.returnValue((ret, key))
+ @abc.abstractmethod
+ def get_room_min_stream_ordering(self):
+ raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
@@ -233,13 +198,14 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
- res = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.get_room_events_stream_for_room)(
+ for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
+ res = yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(
+ self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
- ]))
+ ], consumeErrors=True))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -381,88 +347,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
- args = [False, room_id]
- if direction == 'b':
- order = "DESC"
- bounds = upper_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, lower_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
- else:
- order = "ASC"
- bounds = lower_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, upper_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
-
- filter_clause, filter_args = filter_to_clause(event_filter)
-
- if filter_clause:
- bounds += " AND " + filter_clause
- args.extend(filter_args)
-
- if int(limit) > 0:
- args.append(int(limit))
- limit_str = " LIMIT ?"
- else:
- limit_str = ""
-
- sql = (
- "SELECT * FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s %(limit)s"
- ) % {
- "bounds": bounds,
- "order": order,
- "limit": limit_str
- }
-
- def f(txn):
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
-
- if rows:
- topo = rows[-1]["topological_ordering"]
- toke = rows[-1]["stream_ordering"]
- if direction == 'b':
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- toke -= 1
- next_token = str(RoomStreamToken(topo, toke))
- else:
- # TODO (erikj): We should work out what to do here instead.
- next_token = to_key if to_key else from_key
-
- return rows, next_token,
-
- rows, token = yield self.runInteraction("paginate_room_events", f)
-
- events = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
- )
-
- self._set_before_and_after(events, rows)
-
- defer.returnValue((events, token))
-
- @defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token, from_token
@@ -534,6 +418,33 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn
)
+ def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+ """Gets details of the first event in a room at or after a stream ordering
+
+ Args:
+ room_id (str):
+ stream_ordering (int):
+
+ Returns:
+ Deferred[(int, int, str)]:
+ (stream ordering, topological ordering, event_id)
+ """
+ def _f(txn):
+ sql = (
+ "SELECT stream_ordering, topological_ordering, event_id"
+ " FROM events"
+ " WHERE room_id = ? AND stream_ordering >= ?"
+ " AND NOT outlier"
+ " ORDER BY stream_ordering"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (room_id, stream_ordering, ))
+ return txn.fetchone()
+
+ return self.runInteraction(
+ "get_room_event_after_stream_ordering", _f,
+ )
+
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
"""Returns the current token for rooms stream.
@@ -542,7 +453,7 @@ class StreamStore(SQLBaseStore):
`room_id` causes it to return the current room specific topological
token.
"""
- token = yield self._stream_id_gen.get_current_token()
+ token = yield self.get_room_max_stream_ordering()
if room_id is None:
defer.returnValue("s%d" % (token,))
else:
@@ -552,12 +463,6 @@ class StreamStore(SQLBaseStore):
)
defer.returnValue("t%d-%d" % (topo, token))
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
-
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
@@ -829,3 +734,96 @@ class StreamStore(SQLBaseStore):
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+
+ def has_room_changed_since(self, room_id, stream_id):
+ return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+
+class StreamStore(StreamWorkerStore):
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def paginate_room_events(self, room_id, from_key, to_key=None,
+ direction='b', limit=-1, event_filter=None):
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ args = [False, room_id]
+ if direction == 'b':
+ order = "DESC"
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
+ if to_key:
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+ else:
+ order = "ASC"
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
+ if to_key:
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
+
+ if int(limit) > 0:
+ args.append(int(limit))
+ limit_str = " LIMIT ?"
+ else:
+ limit_str = ""
+
+ sql = (
+ "SELECT * FROM events"
+ " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+ " ORDER BY topological_ordering %(order)s,"
+ " stream_ordering %(order)s %(limit)s"
+ ) % {
+ "bounds": bounds,
+ "order": order,
+ "limit": limit_str
+ }
+
+ def f(txn):
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+
+ if rows:
+ topo = rows[-1]["topological_ordering"]
+ toke = rows[-1]["stream_ordering"]
+ if direction == 'b':
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ toke -= 1
+ next_token = str(RoomStreamToken(topo, toke))
+ else:
+ # TODO (erikj): We should work out what to do here instead.
+ next_token = to_key if to_key else from_key
+
+ return rows, next_token,
+
+ rows, token = yield self.runInteraction("paginate_room_events", f)
+
+ events = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(events, rows)
+
+ defer.returnValue((events, token))
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa59b..6671d3cfca 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,25 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-import ujson as json
+import simplejson as json
import logging
-logger = logging.getLogger(__name__)
+from six.moves import range
+logger = logging.getLogger(__name__)
-class TagsStore(SQLBaseStore):
- def get_max_account_data_stream_id(self):
- """Get the current max stream id for the private user data stream
-
- Returns:
- A deferred int.
- """
- return self._account_data_id_gen.get_current_token()
+class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@@ -95,7 +91,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
- for tag, content in txn.fetchall():
+ for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
@@ -104,7 +100,7 @@ class TagsStore(SQLBaseStore):
batch_size = 50
results = []
- for i in xrange(0, len(tag_ids), batch_size):
+ for i in range(0, len(tag_ids), batch_size):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
@@ -132,7 +128,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
- room_ids = [row[0] for row in txn.fetchall()]
+ room_ids = [row[0] for row in txn]
return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
@@ -170,6 +166,8 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
+
+class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 809fdd311f..f825264ea9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json
from collections import namedtuple
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, hs):
- super(TransactionStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(TransactionStore, self).__init__(db_conn, hs)
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
new file mode 100644
index 0000000000..d6e289ffbe
--- /dev/null
+++ b/synapse/storage/user_directory.py
@@ -0,0 +1,764 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id, get_localpart_from_id
+
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryStore(SQLBaseStore):
+ @cachedInlineCallbacks(cache_context=True)
+ def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ """Check if the room is either world_readable or publically joinable
+ """
+ current_state_ids = yield self.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_id:
+ join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ if join_rule_ev:
+ if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
+ defer.returnValue(True)
+
+ hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_id:
+ hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ if hist_vis_ev:
+ if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def add_users_to_public_room(self, room_id, user_ids):
+ """Add user to the list of users in public rooms
+
+ Args:
+ room_id (str): A room_id that all the users are in, and which is
+ world_readable or publicly joinable
+ user_ids (list(str)): Users to add
+ """
+ yield self._simple_insert_many(
+ table="users_in_public_rooms",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+ for user_id in user_ids
+ ],
+ desc="add_users_to_public_room"
+ )
+ for user_id in user_ids:
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def add_profiles_to_user_dir(self, room_id, users_with_profile):
+ """Add profiles to the user directory
+
+ Args:
+ room_id (str): A room_id that all users are joined to
+ users_with_profile (dict): Users to add to directory in the form of
+ mapping of user_id -> ProfileInfo
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the localpart most highly, then display name and finally
+ # server name
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ args = (
+ (
+ user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ profile.display_name,
+ )
+ for user_id, profile in users_with_profile.iteritems()
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT INTO user_directory_search(user_id, value)
+ VALUES (?,?)
+ """
+ args = (
+ (
+ user_id,
+ "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+ )
+ for user_id, p in users_with_profile.iteritems()
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ def _add_profiles_to_user_dir_txn(txn):
+ txn.executemany(sql, args)
+ self._simple_insert_many_txn(
+ txn,
+ table="user_directory",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ "display_name": profile.display_name,
+ "avatar_url": profile.avatar_url,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ ]
+ )
+ for user_id in users_with_profile:
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+
+ return self.runInteraction(
+ "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
+ )
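For reference, a small sketch (illustrative only, not part of the patch) of the parameter tuple the Postgres branch above binds for a single user, showing which field ends up in which tsvector weight class:

def postgres_search_args(user_id, display_name):
    localpart = user_id[1:].split(":", 1)[0]   # weight 'A' (highest)
    domain = user_id.split(":", 1)[1]          # weight 'D' (lowest)
    # display_name is COALESCEd in the SQL and gets weight 'B'
    return (user_id, localpart, domain, display_name)

print(postgres_search_args("@alice:example.com", "Alice"))
# ('@alice:example.com', 'alice', 'example.com', 'Alice')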
+
+ @defer.inlineCallbacks
+ def update_user_in_user_dir(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_user_dir",
+ )
+ self.get_user_in_directory.invalidate((user_id,))
+
+ def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
+ def _update_profile_in_user_dir_txn(txn):
+ new_entry = self._simple_upsert_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ insertion_values={"room_id": room_id},
+ values={"display_name": display_name, "avatar_url": avatar_url},
+ lock=False, # We're the only inserter
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the localpart most highly, then display name and finally
+ # server name
+ if new_entry:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id, get_localpart_from_id(user_id),
+ get_domain_from_id(user_id), display_name,
+ )
+ )
+ else:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ display_name, user_id,
+ )
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ value = "%s %s" % (user_id, display_name,) if display_name else user_id
+ self._simple_upsert_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ values={"value": value},
+ lock=False, # We're the only inserter
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.runInteraction(
+ "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_public_user_list(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_public_user_list",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_user_in_public_room.invalidate, (user_id,)
+ )
+ return self.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn,
+ )
+
+ @defer.inlineCallbacks
+ def remove_from_user_in_public_room(self, user_id):
+ yield self._simple_delete(
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id},
+ desc="remove_from_user_in_public_room",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def get_users_in_public_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ return self._simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_public_due_to_room",
+ )
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ user_ids_dir = yield self._simple_select_onecol(
+ table="user_directory",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_pub = yield self._simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share = yield self._simple_select_onecol(
+ table="users_who_share_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_dir)
+ user_ids.update(user_ids_pub)
+ user_ids.update(user_ids_share)
+
+ defer.returnValue(user_ids)
+
+ @defer.inlineCallbacks
+ def get_all_rooms(self):
+ """Get all room_ids we've ever known about, in ascending order of "size"
+ """
+ sql = """
+ SELECT room_id FROM current_state_events
+ GROUP BY room_id
+ ORDER BY count(*) ASC
+ """
+ rows = yield self._execute("get_all_rooms", None, sql)
+ defer.returnValue([room_id for room_id, in rows])
+
+ @defer.inlineCallbacks
+ def get_all_local_users(self):
+ """Get all local users
+ """
+ sql = """
+ SELECT name FROM users
+ """
+ rows = yield self._execute("get_all_local_users", None, sql)
+ defer.returnValue([name for name, in rows])
+
+ def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
+ """Insert entries into the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_tuples([(str, str)]): iterable of 2-tuples of user IDs.
+ """
+ def _add_users_who_share_room_txn(txn):
+ self._simple_insert_many_txn(
+ txn,
+ table="users_who_share_rooms",
+ values=[
+ {
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ "room_id": room_id,
+ "share_private": share_private,
+ }
+ for user_id, other_user_id in user_id_tuples
+ ],
+ )
+ for user_id, other_user_id in user_id_tuples:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "add_users_who_share_room", _add_users_who_share_room_txn
+ )
+
+ def update_users_who_share_room(self, room_id, share_private, user_id_sets):
+ """Updates entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_sets([(str, str)]): iterable of 2-tuples of user IDs.
+ """
+ def _update_users_who_share_room_txn(txn):
+ sql = """
+ UPDATE users_who_share_rooms
+ SET room_id = ?, share_private = ?
+ WHERE user_id = ? AND other_user_id = ?
+ """
+ txn.executemany(
+ sql,
+ (
+ (room_id, share_private, uid, oid)
+ for uid, oid in user_id_sets
+ )
+ )
+ for user_id, other_user_id in user_id_sets:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "update_users_who_share_room", _update_users_who_share_room_txn
+ )
+
+ def remove_user_who_share_room(self, user_id, other_user_id):
+ """Deletes entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ user_id (str): The user_id of a local user
+ other_user_id (str): The other user in the shared room
+ """
+ def _remove_user_who_share_room_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ )
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+
+ return self.runInteraction(
+ "remove_user_who_share_room", _remove_user_who_share_room_txn
+ )
+
+ @cached(max_entries=500000)
+ def get_if_users_share_a_room(self, user_id, other_user_id):
+ """Gets if users share a room.
+
+ Args:
+ user_id (str): Must be a local user_id
+ other_user_id (str)
+
+ Returns:
+ bool|None: None if they don't share a room, otherwise whether they
+ share a private room or not.
+ """
+ return self._simple_select_one_onecol(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ retcol="share_private",
+ allow_none=True,
+ desc="get_if_users_share_a_room",
+ )
+
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ def get_users_who_share_room_from_dir(self, user_id):
+ """Returns the set of users who share a room with `user_id`
+
+ Args:
+ user_id(str): Must be a local user
+
+ Returns:
+ dict: user_id -> share_private mapping
+ """
+ rows = yield self._simple_select_list(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("other_user_id", "share_private",),
+ desc="get_users_who_share_room_with_user",
+ )
+
+ defer.returnValue({
+ row["other_user_id"]: row["share_private"]
+ for row in rows
+ })
+
+ def get_users_in_share_dir_with_room_id(self, user_id, room_id):
+ """Get all user tuples that are in the users_who_share_rooms due to the
+ given room_id.
+
+ Returns:
+ [(user_id, other_user_id)]: where one of the two will match the given
+ user_id.
+ """
+ sql = """
+ SELECT user_id, other_user_id FROM users_who_share_rooms
+ WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
+ """
+ return self._execute(
+ "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_in_common_for_users(self, user_id, other_user_id):
+ """Given two user_ids find out the list of rooms they share.
+ """
+ sql = """
+ SELECT room_id FROM (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) AS f1 INNER JOIN (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) f2 USING (room_id)
+ """
+
+ rows = yield self._execute(
+ "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+ )
+
+ defer.returnValue([room_id for room_id, in rows])
+
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_public_rooms")
+ txn.execute("DELETE FROM users_who_share_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+ txn.call_after(self.get_user_in_public_room.invalidate_all)
+ txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
+ txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+ return self.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self._simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id", "display_name", "avatar_url",),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ @cached()
+ def get_user_in_public_room(self, user_id):
+ return self._simple_select_one(
+ table="users_in_public_rooms",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id",),
+ allow_none=True,
+ desc="get_user_in_public_room",
+ )
+
+ def get_user_directory_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="get_user_directory_stream_pos",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+ # We arbitrarily limit to 100 stream_id entries to ensure we don't
+ # select too many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+ # We arbitrarily limit to 100 entries to ensure we don't
+ # select too many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
+
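The first query above only exists to pick a cut-off stream id: it keeps the batch at roughly 100 rows while never splitting a single stream_id across batches. A hedged sketch of that capping logic against in-memory (stream_id, count) pairs rather than a database cursor:

def cap_stream_id(counts_by_stream_id, prev_stream_id, max_rows=100):
    total = 0
    max_stream_id = prev_stream_id
    for max_stream_id, count in counts_by_stream_id:
        total += count
        if total > max_rows:
            # include every row for this stream_id, then stop
            break
    return max_stream_id

print(cap_stream_id([(5, 40), (6, 50), (7, 30)], prev_stream_id=4))  # 7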
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ @defer.inlineCallbacks
+ def search_user_dir(self, user_id, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+
+ if self.hs.config.user_directory_search_all_users:
+ # make s.user_id null to keep the ordering algorithm happy
+ join_clause = """
+ CROSS JOIN (SELECT NULL as user_id) AS s
+ """
+ join_args = ()
+ where_clause = "1=1"
+ else:
+ join_clause = """
+ LEFT JOIN users_in_public_rooms AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ """
+ join_args = (user_id,)
+ where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
+ if isinstance(self.database_engine, PostgresEngine):
+ full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+
+ # We order by rank and then by whether they have profile info.
+ # The ranking algorithm is hand tweaked for "best" results. Broadly
+ # the idea is we give a higher weight to exact matches.
+ # The array of numbers is the set of weights for the various parts of
+ # the search: (domain, _, display name, localpart)
+ sql = """
+ SELECT d.user_id AS user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ %s
+ WHERE
+ %s
+ AND vector @@ to_tsquery('english', ?)
+ ORDER BY
+ (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+ * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (
+ 3 * ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ + ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ )
+ DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """ % (join_clause, where_clause)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ search_query = _parse_query_sqlite(search_term)
+
+ sql = """
+ SELECT d.user_id AS user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ %s
+ WHERE
+ %s
+ AND value MATCH ?
+ ORDER BY
+ rank(matchinfo(user_directory_search)) DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """ % (join_clause, where_clause)
+ args = join_args + (search_query, limit + 1)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ results = yield self._execute(
+ "search_user_dir", self.cursor_to_dict, sql, *args
+ )
+
+ limited = len(results) > limit
+
+ defer.returnValue({
+ "limited": limited,
+ "results": results,
+ })
+
+
+def _parse_query_sqlite(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to the database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+
+ We specifically add both a prefix and a non-prefix matching term so that
+ exact matches get ranked higher.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+
+
+def _parse_query_postgres(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to the database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+ exact = " & ".join("%s" % (result,) for result in results)
+ prefix = " & ".join("%s:*" % (result,) for result in results)
+
+ return both, exact, prefix
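A quick check of what the two parsers above produce for a sample search term (expected values worked out by hand from the code; this assumes the functions are importable, e.g. from synapse.storage.user_directory):

assert _parse_query_sqlite("Alice Ex") == "(Alice* OR Alice) & (Ex* OR Ex)"
assert _parse_query_postgres("Alice Ex") == (
    "(Alice:* | Alice) & (Ex:* | Ex)",  # prefix-or-exact, used for ranking
    "Alice & Ex",                       # exact terms only
    "Alice:* & Ex:*",                   # prefix terms only
)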
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 46cf93ff87..95031dc9ec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -30,6 +30,17 @@ class IdGenerator(object):
def _load_current_id(db_conn, table, column, step=1):
+ """
+
+ Args:
+ db_conn (object):
+ table (str):
+ column (str):
+ step (int):
+
+ Returns:
+ int
+ """
cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
+
+ Returns:
+ int
"""
with self._lock:
if self._unfinished_ids:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 4f089bfb94..ca78e551cb 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -80,13 +80,13 @@ class PaginationConfig(object):
from_tok = None # For backwards compat.
elif from_tok:
from_tok = StreamToken.from_string(from_tok)
- except:
+ except Exception:
raise SynapseError(400, "'from' paramater is invalid")
try:
if to_tok:
to_tok = StreamToken.from_string(to_tok)
- except:
+ except Exception:
raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None)
@@ -98,7 +98,7 @@ class PaginationConfig(object):
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
- except:
+ except Exception:
logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.")
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..f03ad99118 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -45,6 +45,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -65,6 +66,7 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
@@ -73,6 +75,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -93,5 +96,6 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..cc7c182a78 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -12,26 +12,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import string
from synapse.api.errors import SynapseError
from collections import namedtuple
-Requester = namedtuple("Requester", [
+class Requester(namedtuple("Requester", [
"user", "access_token_id", "is_guest", "device_id", "app_service",
-])
-"""
-Represents the user making a request
+])):
+ """
+ Represents the user making a request
-Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
- request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
-"""
+ Attributes:
+ user (UserID): id of the user making the request
+ access_token_id (int|None): *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest (bool): True if the user making this request is a guest user
+ device_id (str|None): device_id which was set at authentication time
+ app_service (ApplicationService|None): the AS requesting on behalf of the user
+ """
+
+ def serialize(self):
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Returns:
+ dict
+ """
+ return {
+ "user_id": self.user.to_string(),
+ "access_token_id": self.access_token_id,
+ "is_guest": self.is_guest,
+ "device_id": self.device_id,
+ "app_server_id": self.app_service.id if self.app_service else None,
+ }
+
+ @staticmethod
+ def deserialize(store, input):
+ """Converts a dict that was produced by `serialize` back into a
+ Requester.
+
+ Args:
+ store (DataStore): Used to convert AS ID to AS object
+ input (dict): A dict produced by `serialize`
+
+ Returns:
+ Requester
+ """
+ appservice = None
+ if input["app_server_id"]:
+ appservice = store.get_app_service_by_id(input["app_server_id"])
+
+ return Requester(
+ user=UserID.from_string(input["user_id"]),
+ access_token_id=input["access_token_id"],
+ is_guest=input["is_guest"],
+ device_id=input["device_id"],
+ app_service=appservice,
+ )
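A hedged usage sketch of the serialize/deserialize round trip added above. FakeStore is a stand-in for the real DataStore; only get_app_service_by_id is needed, and no appservice is involved in this sketch:

from synapse.types import Requester, create_requester

class FakeStore(object):
    def get_app_service_by_id(self, as_id):
        return None  # never reached here, since app_server_id is None

requester = create_requester("@alice:example.com", access_token_id=7)
wire = requester.serialize()                      # a plain, JSON-safe dict
again = Requester.deserialize(FakeStore(), wire)
assert again.user.to_string() == "@alice:example.com"
assert again.access_token_id == 7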
def create_requester(user_id, access_token_id=None, is_guest=False,
@@ -56,10 +96,17 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
def get_domain_from_id(string):
- try:
- return string.split(":", 1)[1]
- except IndexError:
+ idx = string.find(":")
+ if idx == -1:
+ raise SynapseError(400, "Invalid ID: %r" % (string,))
+ return string[idx + 1:]
+
+
+def get_localpart_from_id(string):
+ idx = string.find(":")
+ if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
+ return string[1:idx]
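Both helpers split on the first colon; get_localpart_from_id additionally drops the leading sigil. Expected behaviour, worked out from the code above (assumes the patched synapse.types is importable):

from synapse.types import get_domain_from_id, get_localpart_from_id

assert get_domain_from_id("@alice:example.com") == "example.com"
assert get_localpart_from_id("@alice:example.com") == "alice"
# an ID with no colon raises SynapseError(400, "Invalid ID: ...")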
class DomainSpecificString(
@@ -119,14 +166,10 @@ class DomainSpecificString(
try:
cls.from_string(s)
return True
- except:
+ except Exception:
return False
- __str__ = to_string
-
- @classmethod
- def create(cls, localpart, domain,):
- return cls(localpart=localpart, domain=domain)
+ __repr__ = to_string
class UserID(DomainSpecificString):
@@ -149,6 +192,43 @@ class EventID(DomainSpecificString):
SIGIL = "$"
+class GroupID(DomainSpecificString):
+ """Structure representing a group ID."""
+ SIGIL = "+"
+
+ @classmethod
+ def from_string(cls, s):
+ group_id = super(GroupID, cls).from_string(s)
+ if not group_id.localpart:
+ raise SynapseError(
+ 400,
+ "Group ID cannot be empty",
+ )
+
+ if contains_invalid_mxid_characters(group_id.localpart):
+ raise SynapseError(
+ 400,
+ "Group ID can only contain characters a-z, 0-9, or '=_-./'",
+ )
+
+ return group_id
+
+
+mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
+
+
+def contains_invalid_mxid_characters(localpart):
+ """Check for characters not allowed in an mxid or groupid localpart
+
+ Args:
+ localpart (basestring): the localpart to be checked
+
+ Returns:
+ bool: True if there are any naughty characters
+ """
+ return any(c not in mxid_localpart_allowed_characters for c in localpart)
+
+
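Some hedged examples of the localpart validation above (upper case and any punctuation outside '=_-./' are rejected):

from synapse.types import contains_invalid_mxid_characters

assert not contains_invalid_mxid_characters("alice.smith_42")
assert contains_invalid_mxid_characters("Alice")   # upper case is not allowed
assert contains_invalid_mxid_characters("bob!")    # '!' is not allowed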
class StreamToken(
namedtuple("Token", (
"room_key",
@@ -159,6 +239,7 @@ class StreamToken(
"push_rules_key",
"to_device_key",
"device_list_key",
+ "groups_key",
))
):
_SEPARATOR = "_"
@@ -171,7 +252,7 @@ class StreamToken(
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys)
- except:
+ except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
@@ -197,6 +278,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
+ or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):
@@ -216,9 +298,7 @@ class StreamToken(
return self
def copy_and_replace(self, key, new_value):
- d = self._asdict()
- d[key] = new_value
- return StreamToken(**d)
+ return self._replace(**{key: new_value})
StreamToken.START = StreamToken(
@@ -258,7 +338,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
- except:
+ except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -267,7 +347,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
- except:
+ except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..814a7bf71b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -24,11 +23,6 @@ import logging
logger = logging.getLogger(__name__)
-class DeferredTimedOutError(SynapseError):
- def __init__(self):
- super(SynapseError).__init__(504, "Timed out")
-
-
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
@@ -59,9 +53,9 @@ class Clock(object):
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
"""
- l = task.LoopingCall(f)
- l.start(msec / 1000.0, now=False)
- return l
+ call = task.LoopingCall(f)
+ call.start(msec / 1000.0, now=False)
+ return call
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
@@ -82,54 +76,6 @@ class Clock(object):
def cancel_call_later(self, timer, ignore_errs=False):
try:
timer.cancel()
- except:
+ except Exception:
if not ignore_errs:
raise
-
- def time_bound_deferred(self, given_deferred, time_out):
- if given_deferred.called:
- return given_deferred
-
- ret_deferred = defer.Deferred()
-
- def timed_out_fn():
- try:
- ret_deferred.errback(DeferredTimedOutError())
- except:
- pass
-
- try:
- given_deferred.cancel()
- except:
- pass
-
- timer = None
-
- def cancel(res):
- try:
- self.cancel_call_later(timer)
- except:
- pass
- return res
-
- ret_deferred.addBoth(cancel)
-
- def sucess(res):
- try:
- ret_deferred.callback(res)
- except:
- pass
-
- return res
-
- def err(res):
- try:
- ret_deferred.errback(res)
- except:
- pass
-
- given_deferred.addCallbacks(callback=sucess, errback=err)
-
- timer = self.call_later(time_out, timed_out_fn)
-
- return ret_deferred
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,16 +15,20 @@
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
from .logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
-from synapse.util import unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError
from contextlib import contextmanager
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -53,6 +57,11 @@ class ObservableDeferred(object):
Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred.
+
+ NB that it does not attempt to do anything with logcontexts; in general
+ you should probably wrap the deferreds returned by `observe` in
+ make_deferred_yieldable, and ensure that the original deferred runs its
+ callbacks in the sentinel logcontext.
"""
__slots__ = ["_deferred", "_observers", "_result"]
@@ -68,7 +77,7 @@ class ObservableDeferred(object):
try:
# TODO: Handle errors here.
self._observers.pop().callback(r)
- except:
+ except Exception:
pass
return r
@@ -78,7 +87,7 @@ class ObservableDeferred(object):
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
- except:
+ except Exception:
pass
if consumeErrors:
@@ -89,6 +98,11 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
def observe(self):
+ """Observe the underlying deferred.
+
+ Can return either a deferred if the underlying deferred is still pending
+ (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ """
if not self._result:
d = defer.Deferred()
@@ -101,7 +115,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return defer.succeed(res) if success else defer.fail(res)
+ return res if success else defer.fail(res)
def observers(self):
return self._observers
@@ -146,13 +160,13 @@ def concurrently_execute(func, args, limit):
def _concurrently_execute_inner():
try:
while True:
- yield func(it.next())
+ yield func(next(it))
except StopIteration:
pass
- return preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(_concurrently_execute_inner)()
- for _ in xrange(limit)
+ return logcontext.make_deferred_yieldable(defer.gatherResults([
+ run_in_background(_concurrently_execute_inner)
+ for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
@@ -195,10 +209,29 @@ class Linearizer(object):
try:
with PreserveLoggingContext():
yield current_defer
- except:
+ except Exception:
logger.exception("Unexpected exception in Linearizer")
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+ logger.info("Acquired linearizer lock %r for key %r", self.name,
+ key)
+
+ # if the code holding the lock completes synchronously, then it
+ # will recursively run the next claimant on the list. That can
+ # relatively rapidly lead to stack exhaustion. This is essentially
+ # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+ #
+ # In order to break the cycle, we add a cheeky sleep(0) here to
+ # ensure that we fall back to the reactor between each iteration.
+ #
+ # (There's no particular need for it to happen before we return
+ # the context manager, but it needs to happen while we hold the
+ # lock, and the context manager's exit code must be synchronous,
+ # so actually this is the only sensible place.)
+ yield run_on_reactor()
+
+ else:
+ logger.info("Acquired uncontended linearizer lock %r for key %r",
+ self.name, key)
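The comment above describes a generic Twisted hazard, illustrated standalone below (this is not the Linearizer itself): when each waiter's callback synchronously fires the next waiter, the whole chain runs in one nested call stack, which is why the code bounces through the reactor between iterations.

from twisted.internet import defer

waiters = [defer.Deferred() for _ in range(5)]

def release_next(result, i):
    if i + 1 < len(waiters):
        waiters[i + 1].callback(None)  # fires synchronously: one more frame
    return result

for i, d in enumerate(waiters):
    d.addCallback(release_next, i)

# Firing the first waiter cascades through all of them in a single call
# stack; with many thousands of waiters this exhausts the stack unless a
# reactor hop (sleep(0) / callLater(0)) is inserted between releases.
waiters[0].callback(None)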
@contextmanager
def _ctx_manager():
@@ -206,7 +239,8 @@ class Linearizer(object):
yield
finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
@@ -248,8 +282,13 @@ class Limiter(object):
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
+
+ logger.info("Waiting to acquire limiter lock for key %r", key)
with PreserveLoggingContext():
yield new_defer
+ logger.info("Acquired limiter lock for key %r", key)
+ else:
+ logger.info("Acquired uncontended limiter lock for key %r", key)
entry[0] += 1
@@ -258,16 +297,21 @@ class Limiter(object):
try:
yield
finally:
+ logger.info("Releasing limiter lock for key %r", key)
+
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
entry[0] -= 1
- try:
- entry[1].pop(0).callback(None)
- except IndexError:
- # If nothing else is executing for this key then remove it
- # from the map
- if entry[0] == 0:
- self.key_to_defer.pop(key, None)
+
+ if entry[1]:
+ next_def = entry[1].pop(0)
+
+ with PreserveLoggingContext():
+ next_def.callback(None)
+ elif entry[0] == 0:
+ # We were the last thing for this key: remove it from the
+ # map.
+ del self.key_to_defer[key]
defer.returnValue(_ctx_manager())
@@ -311,7 +355,7 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
- yield curr_writer
+ yield make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@@ -340,7 +384,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
- yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
+ yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
@@ -352,3 +396,68 @@ class ReadWriteLock(object):
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+ """
+ This error is raised by default when a L{Deferred} times out.
+ """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+ """
+ Add a timeout to a deferred by scheduling it to be cancelled after
+ timeout seconds.
+
+ This is essentially a backport of deferred.addTimeout, which was introduced
+ in twisted 16.5.
+
+ If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+ unless a cancelable function was passed to its initialization or unless
+ a different on_timeout_cancel callable is provided.
+
+ Args:
+ deferred (defer.Deferred): deferred to be timed out
+ timeout (Number): seconds to time out after
+
+ on_timeout_cancel (callable): A callable which is called immediately
+ after the deferred times out, and not if this deferred is
+ otherwise cancelled before the timeout.
+
+ It takes an arbitrary value, which is the value of the deferred at
+ that exact point in time (probably a CancelledError Failure), and
+ the timeout.
+
+ The default callable (if none is provided) will translate a
+ CancelledError Failure into a DeferredTimeoutError.
+ """
+ timed_out = [False]
+
+ def time_it_out():
+ timed_out[0] = True
+ deferred.cancel()
+
+ delayed_call = reactor.callLater(timeout, time_it_out)
+
+ def convert_cancelled(value):
+ if timed_out[0]:
+ to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+ return to_call(value, timeout)
+ return value
+
+ deferred.addBoth(convert_cancelled)
+
+ def cancel_timeout(result):
+ # stop the pending call to cancel the deferred if it's been fired
+ if delayed_call.active():
+ delayed_call.cancel()
+ return result
+
+ deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise DeferredTimeoutError(timeout, "Deferred")
+ return value
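A hedged usage sketch for the helper above (assumes a running reactor, e.g. under twisted.trial; the 10-second timeout is arbitrary):

from twisted.internet import defer
from synapse.util.async import add_timeout_to_deferred, DeferredTimeoutError

def fetch_with_timeout():
    d = defer.Deferred()              # stands in for a slow network call
    add_timeout_to_deferred(d, 10)    # cancel + DeferredTimeoutError after 10s

    def on_timeout(failure):
        failure.trap(DeferredTimeoutError)
        return None                   # fall back to a default value
    d.addErrback(on_timeout)
    return d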
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8a7774a88e..4adae96681 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,12 +14,9 @@
# limitations under the License.
import synapse.metrics
-from lrucache import LruCache
import os
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-DEBUG_CACHES = False
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
@@ -40,10 +37,6 @@ def register_cache(name, cache):
)
-_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
-_stirng_cache_metrics = register_cache("string_cache", _string_cache)
-
-
KNOWN_KEYS = {
key: key for key in
(
@@ -67,14 +60,16 @@ KNOWN_KEYS = {
def intern_string(string):
- """Takes a (potentially) unicode string and interns using custom cache
+ """Takes a (potentially) unicode string and interns it if it's ascii
"""
- new_str = _string_cache.setdefault(string, string)
- if new_str is string:
- _stirng_cache_metrics.inc_hits()
- else:
- _stirng_cache_metrics.inc_misses()
- return new_str
+ if string is None:
+ return None
+
+ try:
+ string = string.encode("ascii")
+ return intern(string)
+ except UnicodeEncodeError:
+ return string
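Expected behaviour of the new intern_string (Python 2 semantics, as in the rest of this codebase), worked out from the code above:

from synapse.util.caches import intern_string

assert intern_string(None) is None
assert intern_string(u"event_id") == "event_id"    # interned as ascii bytes
assert intern_string(u"na\xefve") == u"na\xefve"   # non-ascii: returned unchanged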
def intern_dict(dictionary):
@@ -87,13 +82,9 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
- intern_str_keys = ("event_id", "room_id")
- intern_unicode_keys = ("sender", "user_id", "type", "state_key")
-
- if key in intern_str_keys:
- return intern(value.encode('ascii'))
+ intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
- if key in intern_unicode_keys:
+ if key in intern_keys:
return intern_string(value)
return value
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..68285a7594 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,19 +16,17 @@
import logging
from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
+from synapse.util.stringutils import to_ascii
-from . import DEBUG_CACHES, register_cache
+from . import register_cache
from twisted.internet import defer
from collections import namedtuple
-import os
import functools
import inspect
import threading
@@ -39,17 +38,13 @@ logger = logging.getLogger(__name__)
_CacheSentinel = object()
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class CacheEntry(object):
__slots__ = [
- "deferred", "sequence", "callbacks", "invalidated"
+ "deferred", "callbacks", "invalidated"
]
- def __init__(self, deferred, sequence, callbacks):
+ def __init__(self, deferred, callbacks):
self.deferred = deferred
- self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
@@ -67,7 +62,6 @@ class Cache(object):
"max_entries",
"name",
"keylen",
- "sequence",
"thread",
"metrics",
"_pending_deferred_cache",
@@ -79,15 +73,18 @@ class Cache(object):
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
- size_callback=(lambda d: len(d.result)) if iterable else None,
+ size_callback=(lambda d: len(d)) if iterable else None,
+ evicted_callback=self._on_evicted,
)
self.name = name
self.keylen = keylen
- self.sequence = 0
self.thread = None
self.metrics = register_cache(name, self.cache)
+ def _on_evicted(self, evicted_count):
+ self.metrics.inc_evictions(evicted_count)
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -98,21 +95,34 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel, callback=None):
+ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+ """Looks the key up in the caches.
+
+ Args:
+ key(tuple)
+ default: What is returned if key is not in the caches. If not
+ specified then the function raises KeyError instead
+ callback(fn): Gets called when the entry in the cache is invalidated
+ update_metrics (bool): whether to update the cache hit rate metrics
+
+ Returns:
+ Either a Deferred or the raw result
+ """
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- if val.sequence == self.sequence:
- val.callbacks.update(callbacks)
+ val.callbacks.update(callbacks)
+ if update_metrics:
self.metrics.inc_hits()
- return val.deferred
+ return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
- self.metrics.inc_misses()
+ if update_metrics:
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@@ -124,12 +134,9 @@ class Cache(object):
self.check_thread()
entry = CacheEntry(
deferred=value,
- sequence=self.sequence,
callbacks=callbacks,
)
- entry.callbacks.update(callbacks)
-
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
@@ -137,13 +144,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
- if self.sequence == entry.sequence:
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- self.cache.set(key, entry.deferred, entry.callbacks)
- else:
- entry.invalidate()
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ self.cache.set(key, result, entry.callbacks)
else:
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
entry.invalidate()
return result
@@ -155,29 +174,29 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
+ self.cache.pop(key, None)
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
if entry:
entry.invalidate()
- self.cache.pop(key, None)
-
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
- self.sequence += 1
self.cache.del_multi(key)
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
@@ -185,11 +204,73 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
- self.sequence += 1
self.cache.clear()
+ for entry in self._pending_deferred_cache.itervalues():
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
+
+
+class _CacheDescriptorBase(object):
+ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+ arg_spec = inspect.getargspec(orig)
+ all_args = arg_spec.args
-class CacheDescriptor(object):
+ if "cache_context" in all_args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ if num_args is None:
+ num_args = len(all_args) - 1
+ if cache_context:
+ num_args -= 1
+
+ if len(all_args) < num_args + 1:
+ raise Exception(
+ "Not enough explicit positional arguments to key off for %r: "
+ "got %i args, but wanted %i. (@cached cannot key off *args or "
+ "**kwargs)"
+ % (orig.__name__, len(all_args), num_args)
+ )
+
+ self.num_args = num_args
+
+ # list of the names of the args used as the cache key
+ self.arg_names = all_args[1:num_args + 1]
+
+ # self.arg_defaults is a map of arg name to its default value for each
+ # argument that has a default value
+ if arg_spec.defaults:
+ self.arg_defaults = dict(zip(
+ all_args[-len(arg_spec.defaults):],
+ arg_spec.defaults
+ ))
+ else:
+ self.arg_defaults = {}
+
+ if "cache_context" in self.arg_names:
+ raise Exception(
+ "cache_context arg cannot be included among the cache keys"
+ )
+
+ self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +298,24 @@ class CacheDescriptor(object):
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
+ Args:
+ num_args (int): number of positional arguments (excluding ``self`` and
+ ``cache_context``) to use as cache keys. Defaults to all named
+ args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False):
- max_entries = int(max_entries * CACHE_SIZE_FACTOR)
- self.orig = orig
+ super(CacheDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
+ max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.max_entries = max_entries
- self.num_args = num_args
self.tree = tree
-
self.iterable = iterable
- all_args = inspect.getargspec(orig)
- self.arg_names = all_args.args[1:num_args + 1]
-
- if "cache_context" in all_args.args:
- if not cache_context:
- raise ValueError(
- "Cannot have a 'cache_context' arg without setting"
- " cache_context=True"
- )
- try:
- self.arg_names.remove("cache_context")
- except ValueError:
- pass
- elif cache_context:
- raise ValueError(
- "Cannot have cache_context=True without having an arg"
- " named `cache_context`"
- )
-
- self.add_cache_context = cache_context
-
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwargs)"
- % (orig.__name__,)
- )
-
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
@@ -272,18 +325,47 @@ class CacheDescriptor(object):
iterable=self.iterable,
)
+ def get_cache_key_gen(args, kwargs):
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ We loop through each arg name, looking up if its in the `kwargs`,
+ otherwise using the next argument in `args`. If there are no more
+ args then we try looking the arg name up in the defaults
+ """
+ pos = 0
+ for nm in self.arg_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield self.arg_defaults[nm]
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+ if self.num_args == 1:
+ nm = self.arg_names[0]
+
+ def get_cache_key(args, kwargs):
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return self.arg_defaults[nm]
+ else:
+ def get_cache_key(args, kwargs):
+ return tuple(get_cache_key_gen(args, kwargs))
+
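A standalone re-creation (not importing Synapse) of the cache-key lookup order defined in get_cache_key_gen above: kwargs first, then positional args, then the argument defaults, with single-argument caches using the bare value instead of a 1-tuple:

def make_key(arg_names, arg_defaults, args, kwargs):
    key = []
    pos = 0
    for name in arg_names:
        if name in kwargs:
            key.append(kwargs[name])
        elif pos < len(args):
            key.append(args[pos])
            pos += 1
        else:
            key.append(arg_defaults[name])
    return key[0] if len(arg_names) == 1 else tuple(key)

# e.g. a method cached on (event_id, allow_none=False):
print(make_key(["event_id", "allow_none"], {"allow_none": False},
               ("$ev:hs",), {}))   # ('$ev:hs', False)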
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
- # Add temp cache_context so inspect.getcallargs doesn't explode
- if self.add_cache_context:
- kwargs["cache_context"] = None
-
- arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+ cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
@@ -293,26 +375,14 @@ class CacheDescriptor(object):
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
- observer = cached_result_d.observe()
- if DEBUG_CACHES:
- @defer.inlineCallbacks
- def check_result(cached_result):
- actual_result = yield self.function_to_call(obj, *args, **kwargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, cache_key,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
- observer.addCallback(check_result)
-
- return preserve_context_over_deferred(observer)
+ if isinstance(cached_result_d, ObservableDeferred):
+ observer = cached_result_d.observe()
+ else:
+ observer = cached_result_d
+
except KeyError:
ret = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
obj, *args, **kwargs
)
@@ -322,64 +392,72 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, ret, callback=invalidate_callback)
+ # If our cache_key is a string, try to convert to ascii to save
+ # a bit of space in large caches
+ if isinstance(cache_key, basestring):
+ cache_key = to_ascii(cache_key)
- return preserve_context_over_deferred(ret.observe())
+ result_d = ObservableDeferred(ret, consumeErrors=True)
+ cache.set(cache_key, result_d, callback=invalidate_callback)
+ observer = result_d.observe()
+
+ if isinstance(observer, defer.Deferred):
+ return logcontext.make_deferred_yieldable(observer)
+ else:
+ return observer
+
+ if self.num_args == 1:
+ wrapped.invalidate = lambda key: cache.invalidate(key[0])
+ wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+ else:
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
- wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
- wrapped.invalidate_many = cache.invalidate_many
- wrapped.prefill = cache.prefill
wrapped.cache = cache
+ wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped fucntion.
+ the list of missing keys to the wrapped function.
+
+ Once wrapped, the function returns either a Deferred which resolves to
+ the list of results, or (if all results were cached), just the list of
+ results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False):
"""
Args:
orig (function)
- method_name (str); The name of the chached method.
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
- num_args (int)
+ num_args (int): number of positional arguments (excluding ``self``,
+ but including list_name) to use as cache keys. Defaults to all
+ named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
- self.orig = orig
+ super(CacheListDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
- self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
-
self.cached_method_name = cached_method_name
self.sentinel = object()
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
- % (orig.__name__,)
- )
-
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
@@ -387,8 +465,9 @@ class CacheListDescriptor(object):
)
def __get__(self, obj, objtype=None):
-
- cache = getattr(obj, self.cached_method_name).cache
+ cached_method = getattr(obj, self.cached_method_name)
+ cache = cached_method.cache
+ num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
@@ -405,13 +484,26 @@ class CacheListDescriptor(object):
results = {}
cached_defers = {}
missing = []
- for arg in list_args:
+
+ # If the cache takes a single arg then that is used as the key,
+ # otherwise a tuple is used.
+ if num_args == 1:
+ def cache_get(arg):
+ return cache.get(arg, callback=invalidate_callback)
+ else:
key = list(keyargs)
- key[self.list_pos] = arg
+ def cache_get(arg):
+ key[self.list_pos] = arg
+ return cache.get(tuple(key), callback=invalidate_callback)
+
+ for arg in list_args:
try:
- res = cache.get(tuple(key), callback=invalidate_callback)
- if not res.has_succeeded():
+ res = cache_get(arg)
+
+ if not isinstance(res, ObservableDeferred):
+ results[arg] = res
+ elif not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
@@ -425,8 +517,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
**args_to_call
)
@@ -435,23 +526,33 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- with PreserveLoggingContext():
- observer = ret_d.observe()
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
+ if num_args == 1:
+ cache.set(
+ arg, observer,
+ callback=invalidate_callback
+ )
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, arg)
+ else:
+ key = list(keyargs)
+ key[self.list_pos] = arg
+ cache.set(
+ tuple(key), observer,
+ callback=invalidate_callback
+ )
+
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@@ -463,7 +564,7 @@ class CacheListDescriptor(object):
results.update(res)
return results
- return preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
@@ -487,7 +588,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
@@ -499,8 +600,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
- iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+ cache_context=False, iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -512,7 +613,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
)
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +626,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
- num_args (int): Number of arguments to use as the key in the cache.
+ num_args (int): Number of arguments to use as the key in the cache
+ (including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
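
To make the new keying behaviour concrete, here is a minimal standalone sketch (not Synapse's actual descriptor) of a memoiser that derives its key from the wrapped function's named arguments and, when there is only one key argument, uses the bare value rather than a 1-tuple, mirroring the memory optimisation introduced above. It assumes Python 3's inspect.getfullargspec and a single shared cache; all names are illustrative.

    import functools
    import inspect


    def simple_cached(func):
        """Toy memoiser: keys on all named args, unwrapping single-arg keys."""
        arg_names = inspect.getfullargspec(func).args[1:]  # drop 'self'
        cache = {}

        @functools.wraps(func)
        def wrapped(self, *args, **kwargs):
            if len(arg_names) == 1:
                # Single key argument: use the bare value, not a 1-tuple.
                nm = arg_names[0]
                key = kwargs[nm] if nm in kwargs else args[0]
            else:
                bound = dict(zip(arg_names, args))
                bound.update(kwargs)
                key = tuple(bound[name] for name in arg_names)
            if key not in cache:
                cache[key] = func(self, *args, **kwargs)
            return cache[key]

        wrapped.cache = cache
        return wrapped


    class RoomStore(object):
        @simple_cached
        def get_room(self, room_id):
            return {"room_id": room_id}

Calling RoomStore().get_room("!abc:example.com") stores the result under the bare string rather than a one-element tuple, which is the same space saving the real descriptor is after.
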
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..1709e8b429 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
logger = logging.getLogger(__name__)
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+ """Returned when getting an entry from the cache
+
+ Attributes:
+ full (bool): Whether the cache has the full dict or just some keys.
+ If not full then not all requested keys will necessarily be present
+ in `value`
+ known_absent (set): Keys that were looked up in the dict and were not
+ there.
+ value (dict): The full or partial dict value
+ """
def __len__(self):
return len(self.value)
@@ -58,21 +68,31 @@ class DictionaryCache(object):
)
def get(self, key, dict_keys=None):
+ """Fetch an entry out of the cache
+
+ Args:
+ key
+ dict_keys (list): If given a set of keys then return only those keys
+ that exist in the cache.
+
+ Returns:
+ DictionaryEntry
+ """
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
self.metrics.inc_hits()
if dict_keys is None:
- return DictionaryEntry(entry.full, dict(entry.value))
+ return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
else:
- return DictionaryEntry(entry.full, {
+ return DictionaryEntry(entry.full, entry.known_absent, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
self.metrics.inc_misses()
- return DictionaryEntry(False, {})
+ return DictionaryEntry(False, set(), {})
def invalidate(self, key):
self.check_thread()
@@ -87,19 +107,38 @@ class DictionaryCache(object):
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, full=False):
+ def update(self, sequence, key, value, full=False, known_absent=None):
+ """Updates the entry in the cache
+
+ Args:
+ sequence
+ key
+ value (dict): The value to update the cache with.
+ full (bool): Whether the given value is the full dict, or just a
+ partial subset thereof. If not full then any existing entries
+ for the key will be updated.
+ known_absent (set): Set of keys that we know don't exist in the full
+ dict.
+ """
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
+ if known_absent is None:
+ known_absent = set()
if full:
- self._insert(key, value)
+ self._insert(key, value, known_absent)
else:
- self._update_or_insert(key, value)
+ self._update_or_insert(key, value, known_absent)
+
+ def _update_or_insert(self, key, value, known_absent):
+ # We pop and reinsert as we need to tell the cache the size may have
+ # changed
- def _update_or_insert(self, key, value):
- entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+ entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
+ entry.known_absent.update(known_absent)
+ self.cache[key] = entry
- def _insert(self, key, value):
- self.cache[key] = DictionaryEntry(True, value)
+ def _insert(self, key, value, known_absent):
+ self.cache[key] = DictionaryEntry(True, known_absent, value)
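
A quick standalone illustration of why the known_absent set is useful: it lets a caller probing a partial entry distinguish "cached as missing" from "not cached at all". The entry shape mirrors the namedtuple above; the event-type keys are just examples.

    from collections import namedtuple

    DictionaryEntry = namedtuple("DictionaryEntry", ("full", "known_absent", "value"))


    def lookup(entry, dict_key):
        """Three possible outcomes when probing a (possibly partial) entry."""
        if dict_key in entry.value:
            return ("hit", entry.value[dict_key])
        if entry.full or dict_key in entry.known_absent:
            return ("known-missing", None)
        return ("unknown", None)  # would have to hit the database


    entry = DictionaryEntry(full=False, known_absent={"m.room.topic"},
                            value={"m.room.name": "Lounge"})
    assert lookup(entry, "m.room.name") == ("hit", "Lounge")
    assert lookup(entry, "m.room.topic") == ("known-missing", None)
    assert lookup(entry, "m.room.avatar") == ("unknown", None)
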
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2987c38a2d..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- self._size_estimate -= len(value.value)
+ removed_len = len(value.value)
+ self.metrics.inc_evictions(removed_len)
+ self._size_estimate -= removed_len
+ else:
+ self.metrics.inc_evictions()
def __getitem__(self, key):
try:
@@ -94,12 +98,22 @@ class ExpiringCache(object):
return entry.value
+ def __contains__(self, key):
+ return key in self._cache
+
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
+ def setdefault(self, key, value):
+ try:
+ return self[key]
+ except KeyError:
+ self[key] = value
+ return value
+
def _prune_cache(self):
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
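
The setdefault contract added here is the same as dict.setdefault: return the existing entry if there is one, otherwise insert and return the supplied value. A tiny dict-backed sketch (not the real ExpiringCache, which also tracks expiry, size and eviction metrics):

    class TinyCache(object):
        """Minimal dict-backed cache showing __contains__ and setdefault."""

        def __init__(self):
            self._cache = {}

        def __contains__(self, key):
            return key in self._cache

        def __getitem__(self, key):
            return self._cache[key]

        def __setitem__(self, key, value):
            self._cache[key] = value

        def setdefault(self, key, value):
            try:
                return self[key]          # existing entry wins
            except KeyError:
                self[key] = value
                return value


    c = TinyCache()
    assert c.setdefault("a", 1) == 1      # inserted
    assert c.setdefault("a", 2) == 1      # existing value returned
    assert "a" in c
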
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+ def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+ evicted_callback=None):
+ """
+ Args:
+ max_size (int):
+
+ keylen (int):
+
+ cache_type (type):
+ type of underlying cache to be used. Typically one of dict
+ or TreeCache.
+
+ size_callback (func(V) -> int | None):
+
+ evicted_callback (func(int)|None):
+ if not None, called on eviction with the size of the evicted
+ entry
+ """
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
- delete_node(todelete)
+ evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
+ if evicted_callback:
+ evicted_callback(evicted_len)
def synchronized(f):
@wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ deleted_len = 1
if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
+ deleted_len = size_callback(node.value)
+ cached_cache_len[0] -= deleted_len
for cb in node.callbacks:
cb()
node.callbacks.clear()
+ return deleted_len
@synchronized
def cache_get(key, default=None, callbacks=[]):
@@ -132,14 +154,21 @@ class LruCache(object):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
- if value != node.value:
+ # We sometimes store large objects, e.g. dicts, which cause
+ # the inequality check to take a long time. So let's only do
+ # the check if we have some callbacks to call.
+ if node.callbacks and value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
- if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
- cached_cache_len[0] += size_callback(value)
+ # We don't bother to protect this by value != node.value as
+ # generally size_callback will be cheap compared with equality
+ # checks. (For example, taking the size of two dicts is quicker
+ # than comparing them for equality.)
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+ cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
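
To show how size_callback and the new evicted_callback interact, here is a standalone toy LRU built on OrderedDict (the real LruCache uses an intrusive linked list and per-node invalidation callbacks, which this sketch omits):

    from collections import OrderedDict


    class TinyLru(object):
        """Toy LRU illustrating size_callback / evicted_callback."""

        def __init__(self, max_size, size_callback=None, evicted_callback=None):
            self._data = OrderedDict()
            self._max_size = max_size
            self._size_callback = size_callback or (lambda v: 1)
            self._evicted_callback = evicted_callback
            self._len = 0

        def set(self, key, value):
            if key in self._data:
                self._len -= self._size_callback(self._data.pop(key))
            self._data[key] = value
            self._len += self._size_callback(value)
            while self._len > self._max_size:
                _, old = self._data.popitem(last=False)   # evict oldest
                evicted_len = self._size_callback(old)
                self._len -= evicted_len
                if self._evicted_callback:
                    self._evicted_callback(evicted_len)


    evictions = []
    lru = TinyLru(max_size=3, size_callback=len,
                  evicted_callback=evictions.append)
    lru.set("a", [1, 2])
    lru.set("b", [3, 4])      # total size 4 > 3, so "a" (size 2) is evicted
    assert evictions == [2]
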
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..7f79333e96 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from twisted.internet import defer
from synapse.util.async import ObservableDeferred
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
class ResponseCache(object):
@@ -24,20 +31,68 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self, hs, timeout_ms=0):
+ def __init__(self, hs, name, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
+ self._name = name
+ self._metrics = cache_metrics.register_cache(
+ "response_cache",
+ size_callback=lambda: self.size(),
+ cache_name=name,
+ )
+
+ def size(self):
+ return len(self.pending_result_cache)
+
def get(self, key):
+ """Look up the given key.
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if the request has completed, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ If there is no entry for the key, returns None. It is worth noting that
+ this means there is no way to distinguish a completed result of None
+ from an absent cache entry.
+
+ Args:
+ key (hashable):
+
+ Returns:
+ twisted.internet.defer.Deferred|None|E: None if there is no entry
+ for this key; otherwise either a deferred result or the result
+ itself.
+ """
result = self.pending_result_cache.get(key)
if result is not None:
+ self._metrics.inc_hits()
return result.observe()
else:
+ self._metrics.inc_misses()
return None
def set(self, key, deferred):
+ """Set the entry for the given key to the given deferred.
+
+ *deferred* should run its callbacks in the sentinel logcontext (ie,
+ you should wrap normal synapse deferreds with
+ logcontext.run_in_background).
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if *deferred* was already complete, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ Args:
+ key (hashable):
+ deferred (twisted.internet.defer.Deferred[T]):
+
+ Returns:
+ twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+ result.
+ """
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -53,3 +108,52 @@ class ResponseCache(object):
result.addBoth(remove)
return result.observe()
+
+ def wrap(self, key, callback, *args, **kwargs):
+ """Wrap together a *get* and *set* call, taking care of logcontexts
+
+ First looks up the key in the cache, and if it is present makes it
+ follow the synapse logcontext rules and returns it.
+
+ Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+ follow the synapse logcontext rules, and adds the result to the cache.
+
+ Example usage:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ # etc
+ defer.returnValue(result)
+
+ result = yield response_cache.wrap(
+ key,
+ handle_request,
+ request,
+ )
+
+ Args:
+ key (hashable): key to get/set in the cache
+
+ callback (callable): function to call if the key is not found in
+ the cache
+
+ *args: positional parameters to pass to the callback, if it is used
+
+ **kwargs: named parameters to pass to the callback, if it is used
+
+ Returns:
+ twisted.internet.defer.Deferred: yieldable result
+ """
+ result = self.get(key)
+ if not result:
+ logger.info("[%s]: no cached result for [%s], calculating new one",
+ self._name, key)
+ d = run_in_background(callback, *args, **kwargs)
+ result = self.set(key, d)
+ elif not isinstance(result, defer.Deferred) or result.called:
+ logger.info("[%s]: using completed cached result for [%s]",
+ self._name, key)
+ else:
+ logger.info("[%s]: using incomplete cached result for [%s]",
+ self._name, key)
+ return make_deferred_yieldable(result)
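
The essence of wrap() is request deduplication: concurrent callers for the same key share one pending computation. A stripped-down sketch of that pattern follows; unlike the real class it hands the shared Deferred straight back (the real code wraps it in an ObservableDeferred so each caller's callbacks see the original result), and it ignores timeouts, metrics and logcontexts.

    from twisted.internet import defer


    class TinyResponseCache(object):
        """Toy get/set/wrap: concurrent callers for the same key share one
        pending Deferred instead of recomputing the response."""

        def __init__(self):
            self._pending = {}

        def wrap(self, key, callback, *args, **kwargs):
            if key not in self._pending:
                d = defer.maybeDeferred(callback, *args, **kwargs)
                self._pending[key] = d

                def remove(result):
                    self._pending.pop(key, None)
                    return result

                d.addBoth(remove)
            return self._pending[key]
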
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b72bb0ff02..941d873ab8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,20 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import register_cache
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
from blist import sorteddict
import logging
-import os
logger = logging.getLogger(__name__)
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
@@ -50,7 +46,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
- assert type(stream_pos) is int
+ assert type(stream_pos) is int or type(stream_pos) is long
if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
@@ -89,6 +85,21 @@ class StreamChangeCache(object):
return result
+ def has_any_entity_changed(self, stream_pos):
+ """Returns if any entity has changed
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ self.metrics.inc_hits()
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ return i < len(keys)
+ else:
+ self.metrics.inc_misses()
+ return True
+
def get_all_entities_changed(self, stream_pos):
"""Returns all entites that have had new things since the given
position. If the position is too old it will return None.
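
has_any_entity_changed() relies on the cache being keyed by stream position in sorted order, so a single bisect answers "is there any change after this position?". A standalone sketch using the standard library's bisect (the real cache uses blist.sorteddict); the entity IDs are made up.

    import bisect


    class TinyStreamChangeCache(object):
        """Toy stream-change tracker: sorted positions plus bisect."""

        def __init__(self, earliest_known_stream_pos):
            self._earliest = earliest_known_stream_pos
            self._positions = []   # sorted stream positions with changes
            self._entities = []    # entity changed at the matching position

        def entity_has_changed(self, entity, stream_pos):
            idx = bisect.bisect_right(self._positions, stream_pos)
            self._positions.insert(idx, stream_pos)
            self._entities.insert(idx, entity)

        def has_any_entity_changed(self, stream_pos):
            if stream_pos < self._earliest:
                # Older than anything we remember: assume something changed.
                return True
            # Any recorded change strictly after stream_pos?
            return bisect.bisect_right(self._positions, stream_pos) < len(self._positions)


    cache = TinyStreamChangeCache(earliest_known_stream_pos=100)
    cache.entity_has_changed("@alice:example.com", 105)
    assert cache.has_any_entity_changed(101)
    assert not cache.has_any_entity_changed(105)
    assert cache.has_any_entity_changed(50)   # too old to know; assume changed
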
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e68f94ce77..734331caaa 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,32 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_fn
-)
+from twisted.internet import defer
from synapse.util import unwrapFirstError
-
-import logging
-
+from synapse.util.logcontext import PreserveLoggingContext
logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id):
- return preserve_context_over_fn(
- distributor.fire,
- "user_left_room", user=user, room_id=room_id
- )
+ with PreserveLoggingContext():
+ distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id):
- return preserve_context_over_fn(
- distributor.fire,
- "user_joined_room", user=user, room_id=room_id
- )
+ with PreserveLoggingContext():
+ distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object):
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..3380970e4e
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+from six.moves import queue
+
+
+class BackgroundFileConsumer(object):
+ """A consumer that writes to a file like object. Supports both push
+ and pull producers
+
+ Args:
+ file_obj (file): The file like object to write to. Closed when
+ finished.
+ """
+
+ # For PushProducers pause if we have this many unwritten slices
+ _PAUSE_ON_QUEUE_SIZE = 5
+ # And resume once the size of the queue is less than this
+ _RESUME_ON_QUEUE_SIZE = 2
+
+ def __init__(self, file_obj):
+ self._file_obj = file_obj
+
+ # Producer we're registered with
+ self._producer = None
+
+ # True if PushProducer, false if PullProducer
+ self.streaming = False
+
+ # For PushProducers, indicates whether we've paused the producer and
+ # need to call resumeProducing before we get more data.
+ self._paused_producer = False
+
+ # Queue of slices of bytes to be written. When producer calls
+ # unregister a final None is sent.
+ self._bytes_queue = queue.Queue()
+
+ # Deferred that is resolved when finished writing
+ self._finished_deferred = None
+
+ # If the _writer thread throws an exception it gets stored here.
+ self._write_exception = None
+
+ def registerProducer(self, producer, streaming):
+ """Part of IConsumer interface
+
+ Args:
+ producer (IProducer)
+ streaming (bool): True if push based producer, False if pull
+ based.
+ """
+ if self._producer:
+ raise Exception("registerProducer called twice")
+
+ self._producer = producer
+ self.streaming = streaming
+ self._finished_deferred = run_in_background(
+ threads.deferToThread, self._writer
+ )
+ if not streaming:
+ self._producer.resumeProducing()
+
+ def unregisterProducer(self):
+ """Part of IProducer interface
+ """
+ self._producer = None
+ if not self._finished_deferred.called:
+ self._bytes_queue.put_nowait(None)
+
+ def write(self, bytes):
+ """Part of IProducer interface
+ """
+ if self._write_exception:
+ raise self._write_exception
+
+ if self._finished_deferred.called:
+ raise Exception("consumer has closed")
+
+ self._bytes_queue.put_nowait(bytes)
+
+ # If this is a PushProducer and the queue is getting behind
+ # then we pause the producer.
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+ self._paused_producer = True
+ self._producer.pauseProducing()
+
+ def _writer(self):
+ """This is run in a background thread to write to the file.
+ """
+ try:
+ while self._producer or not self._bytes_queue.empty():
+ # If we've paused the producer check if we should resume the
+ # producer.
+ if self._producer and self._paused_producer:
+ if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+ reactor.callFromThread(self._resume_paused_producer)
+
+ bytes = self._bytes_queue.get()
+
+ # If we get a None (or empty list) then that's a signal used
+ # to indicate we should check if we should stop.
+ if bytes:
+ self._file_obj.write(bytes)
+
+ # If it's a pull producer then we need to explicitly ask for
+ # more stuff.
+ if not self.streaming and self._producer:
+ reactor.callFromThread(self._producer.resumeProducing)
+ except Exception as e:
+ self._write_exception = e
+ raise
+ finally:
+ self._file_obj.close()
+
+ def wait(self):
+ """Returns a deferred that resolves when finished writing to file
+ """
+ return make_deferred_yieldable(self._finished_deferred)
+
+ def _resume_paused_producer(self):
+ """Gets called if we should resume producing after being paused
+ """
+ if self._paused_producer and self._producer:
+ self._paused_producer = False
+ self._producer.resumeProducing()
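
The core of BackgroundFileConsumer is a bounded queue between the reactor thread and a writer thread, with pause/resume thresholds for push producers and a None sentinel to shut the writer down. A simplified, twisted-free sketch of just that mechanism (the real class also integrates with IProducer/IConsumer and logcontexts; the class and constant names here are illustrative):

    import threading

    from six.moves import queue

    PAUSE_ON_QUEUE_SIZE = 5    # same thresholds as the class above
    RESUME_ON_QUEUE_SIZE = 2


    class TinyFileConsumer(object):
        """Sketch of the queue + writer-thread mechanism, without twisted."""

        def __init__(self, file_obj):
            self._file_obj = file_obj
            self._queue = queue.Queue()
            self.paused = False
            self._thread = threading.Thread(target=self._writer)
            self._thread.start()

        def write(self, data):
            self._queue.put(data)
            if self._queue.qsize() >= PAUSE_ON_QUEUE_SIZE:
                self.paused = True       # real code: producer.pauseProducing()

        def close(self):
            self._queue.put(None)        # sentinel: no more data is coming
            self._thread.join()

        def _writer(self):
            while True:
                if self.paused and self._queue.qsize() <= RESUME_ON_QUEUE_SIZE:
                    self.paused = False  # real code: producer.resumeProducing()
                data = self._queue.get()
                if data is None:
                    break
                self._file_obj.write(data)
            self._file_obj.close()
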
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 6322f0f55c..f497b51f4a 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,6 +14,7 @@
# limitations under the License.
from frozendict import frozendict
+import simplejson as json
def freeze(o):
@@ -49,3 +50,21 @@ def unfreeze(o):
pass
return o
+
+
+def _handle_frozendict(obj):
+ """Helper for EventEncoder. Makes frozendicts serializable by returning
+ the underlying dict
+ """
+ if type(obj) is frozendict:
+ # fishing the protected dict out of the object is a bit nasty,
+ # but we don't really want the overhead of copying the dict.
+ return obj._dict
+ raise TypeError('Object of type %s is not JSON serializable' %
+ obj.__class__.__name__)
+
+
+# A JSONEncoder which is capable of encoding frozendicts without barfing
+frozendict_json_encoder = json.JSONEncoder(
+ default=_handle_frozendict,
+)
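
The trick here is the default= hook on JSONEncoder, which is called for any object the encoder cannot serialise natively. A self-contained demonstration with a stand-in class (so it runs without frozendict installed) and the standard library's json module:

    import json


    class FakeFrozenDict(object):
        """Stand-in for frozendict, just to demonstrate the default= hook."""
        def __init__(self, d):
            self._dict = dict(d)


    def _handle_frozen(obj):
        if isinstance(obj, FakeFrozenDict):
            return obj._dict
        raise TypeError("Object of type %s is not JSON serializable"
                        % obj.__class__.__name__)


    encoder = json.JSONEncoder(default=_handle_frozen)
    print(encoder.encode({"content": FakeFrozenDict({"body": "hi"})}))
    # -> {"content": {"body": "hi"}}
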
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 45be47159a..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
import logging
@@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
+ # twisted requires all resources to be bytes
+ full_path = full_path.encode("utf-8")
+
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
+ for path_seg in full_path.split(b'/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
+ child_resource = NoResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
+ last_path_seg = full_path.split(b'/')[-1]
# if there is already a resource here, thieve its children and
# replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6c83eb213d..e086e12213 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
from twisted.internet import defer
import threading
@@ -32,7 +42,7 @@ try:
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
-except:
+except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage by returning None.
def get_thread_resource_usage():
@@ -42,13 +52,17 @@ except:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
+
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag", "alive",
+ "previous_context", "name", "ru_stime", "ru_utime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "usage_start", "usage_end",
+ "main_thread", "alive",
+ "request", "tag",
]
thread_local = threading.local()
@@ -73,8 +87,12 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def add_database_scheduled(self, sched_ms):
+ pass
+
def __nonzero__(self):
return False
+ __bool__ = __nonzero__ # python3
sentinel = Sentinel()
@@ -84,9 +102,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
- self.db_txn_duration = 0.
+
+ # ms spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_ms = 0
+
+ # ms spent waiting for db txns to be scheduled
+ self.db_sched_duration_ms = 0
+
self.usage_start = None
+ self.usage_end = None
self.main_thread = threading.current_thread()
+ self.request = None
self.tag = ""
self.alive = True
@@ -95,7 +121,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -145,11 +175,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
- """Copy fields from this context to the record"""
- for key, value in self.__dict__.items():
- setattr(record, key, value)
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
+ # all we need to copy
+ record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@@ -184,7 +216,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ self.db_txn_duration_ms += duration_ms
+
+ def add_database_scheduled(self, sched_ms):
+ """Record a use of the database pool
+
+ Args:
+ sched_ms (int): number of milliseconds it took us to get a
+ connection
+ """
+ self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):
@@ -251,80 +292,94 @@ class PreserveLoggingContext(object):
)
-class _PreservingContextDeferred(defer.Deferred):
- """A deferred that ensures that all callbacks and errbacks are called with
- the given logging context.
- """
- def __init__(self, context):
- self._log_context = context
- defer.Deferred.__init__(self)
-
- def addCallbacks(self, callback, errback=None,
- callbackArgs=None, callbackKeywords=None,
- errbackArgs=None, errbackKeywords=None):
- callback = self._wrap_callback(callback)
- errback = self._wrap_callback(errback)
- return defer.Deferred.addCallbacks(
- self, callback,
- errback=errback,
- callbackArgs=callbackArgs,
- callbackKeywords=callbackKeywords,
- errbackArgs=errbackArgs,
- errbackKeywords=errbackKeywords,
- )
+def preserve_fn(f):
+ """Function decorator which wraps the function with run_in_background"""
+ def g(*args, **kwargs):
+ return run_in_background(f, *args, **kwargs)
+ return g
- def _wrap_callback(self, f):
- def g(res, *args, **kwargs):
- with PreserveLoggingContext(self._log_context):
- res = f(res, *args, **kwargs)
- return res
- return g
+def run_in_background(f, *args, **kwargs):
+ """Calls a function, ensuring that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the function completes.
-def preserve_context_over_fn(fn, *args, **kwargs):
- """Takes a function and invokes it with the given arguments, but removes
- and restores the current logging context while doing so.
+ Useful for wrapping functions that return a deferred which you don't yield
+ on (for instance because you want to pass it to deferred.gatherResults()).
- If the result is a deferred, call preserve_context_over_deferred before
- returning it.
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
"""
- with PreserveLoggingContext():
- res = fn(*args, **kwargs)
-
- if isinstance(res, defer.Deferred):
- return preserve_context_over_deferred(res)
- else:
+ current = LoggingContext.current_context()
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if not isinstance(res, defer.Deferred):
return res
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
-def preserve_context_over_deferred(deferred, context=None):
- """Given a deferred wrap it such that any callbacks added later to it will
- be invoked with the current context.
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
+ return res
+
+
+def make_deferred_yieldable(deferred):
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to run_in_background.)
"""
- if context is None:
- context = LoggingContext.current_context()
- d = _PreservingContextDeferred(context)
- deferred.chainDeferred(d)
- return d
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
-def preserve_fn(f):
- """Ensures that function is called with correct context and that context is
- restored after return. Useful for wrapping functions that return a deferred
- which you don't yield on.
- """
- current = LoggingContext.current_context()
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
+ return deferred
- def g(*args, **kwargs):
- with PreserveLoggingContext(current):
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred):
- return preserve_context_over_deferred(
- res, context=LoggingContext.sentinel
- )
- else:
- return res
- return g
+
+def _set_context_cb(result, context):
+ """A callback function which just sets the logging context"""
+ LoggingContext.set_current_context(context)
+ return result
# modules to ignore in `logcontext_tracer`
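
A hedged usage sketch of the two helpers above, showing the intended pairing: run_in_background to kick off work you are not yielding on immediately, and make_deferred_yieldable when you finally wait for it. store.get_user is a hypothetical method used purely for illustration.

    from twisted.internet import defer

    from synapse.util.logcontext import make_deferred_yieldable, run_in_background


    @defer.inlineCallbacks
    def get_two_users(store):
        # Start both lookups without yielding on them individually...
        d1 = run_in_background(store.get_user, "@alice:example.com")
        d2 = run_in_background(store.get_user, "@bob:example.com")

        # ...then wait for both, restoring our logcontext when they complete.
        alice, bob = yield make_deferred_yieldable(
            defer.gatherResults([d1, d2], consumeErrors=True)
        )
        defer.returnValue((alice, bob))
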
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
new file mode 100644
index 0000000000..3e42868ea9
--- /dev/null
+++ b/synapse/util/logformatter.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from six import StringIO
+import logging
+import traceback
+
+
+class LogFormatter(logging.Formatter):
+ """Log formatter which gives more detail for exceptions
+
+ This is the same as the standard log formatter, except that when logging
+ exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+ sequence that led up to the point at which the exception was caught.
+ (Normally only stack frames between the point the exception was raised and
+ where it was caught are logged).
+ """
+ def __init__(self, *args, **kwargs):
+ super(LogFormatter, self).__init__(*args, **kwargs)
+
+ def formatException(self, ei):
+ sio = StringIO()
+ (typ, val, tb) = ei
+
+ # log the stack above the exception capture point if possible, but
+ # check that we actually have an f_back attribute to work around
+ # https://twistedmatrix.com/trac/ticket/9305
+
+ if tb and hasattr(tb.tb_frame, 'f_back'):
+ sio.write("Capture point (most recent call last):\n")
+ traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+ traceback.print_exception(typ, val, tb, None, sio)
+ s = sio.getvalue()
+ sio.close()
+ if s[-1:] == "\n":
+ s = s[:-1]
+ return s
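
A short usage sketch, assuming the module above is importable as synapse.util.logformatter: the formatter drops into the standard logging machinery like any other Formatter, and exc_info=1 is what routes through formatException().

    import logging
    import sys

    from synapse.util.logformatter import LogFormatter

    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(LogFormatter("%(asctime)s - %(levelname)s - %(message)s"))

    logger = logging.getLogger("demo")
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    try:
        raise ValueError("boom")
    except ValueError:
        # exc_info=1 makes the formatter print the capture-point stack as
        # well as the usual traceback.
        logger.warning("something went wrong", exc_info=1)
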
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
-block_timer = metrics.register_distribution(
- "block_timer",
- labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+ "block_count",
+ labels=["block_name"],
+ alternative_names=(
+ # the following are all deprecated aliases for the same metric
+ metrics.name_prefix + x for x in (
+ "_block_timer:count",
+ "_block_ru_utime:count",
+ "_block_ru_stime:count",
+ "_block_db_txn_count:count",
+ "_block_db_txn_duration:count",
+ )
+ )
+)
+
+block_timer = metrics.register_counter(
+ "block_time_seconds",
+ labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_timer:total",
+ ),
)
-block_ru_utime = metrics.register_distribution(
- "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+ "block_ru_utime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_utime:total",
+ ),
)
-block_ru_stime = metrics.register_distribution(
- "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+ "block_ru_stime_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_ru_stime:total",
+ ),
)
-block_db_txn_count = metrics.register_distribution(
- "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+ "block_db_txn_count", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_db_txn_count:total",
+ ),
)
-block_db_txn_duration = metrics.register_distribution(
- "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+ "block_db_txn_duration_seconds", labels=["block_name"],
+ alternative_names=(
+ metrics.name_prefix + "_block_db_txn_duration:total",
+ ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+ "block_db_sched_duration_seconds", labels=["block_name"],
)
@@ -64,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "ru_stime",
+ "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "created_context",
]
def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+ self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
+
+ block_counter.inc(self.name)
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
+ (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+ self.name
+ )
+ block_db_sched_duration.inc_by(
+ (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+ self.name
)
if self.created_context:
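
The unit change above (logcontexts accumulate milliseconds, while the counters are reported in seconds) is easy to see in a small sketch: a toy timing context manager that counts block entries and accumulates elapsed seconds per block name. The dictionaries stand in for the registered counters; none of this is Synapse's actual metrics API.

    import time
    from collections import defaultdict

    block_count = defaultdict(int)           # stand-ins for block_counter and
    block_time_seconds = defaultdict(float)  # block_timer above


    class TinyMeasure(object):
        """Minimal Measure-alike: count entries, accumulate seconds."""

        def __init__(self, name):
            self.name = name

        def __enter__(self):
            self.start_ms = time.time() * 1000
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            duration_ms = time.time() * 1000 - self.start_ms
            block_count[self.name] += 1
            block_time_seconds[self.name] += duration_ms / 1000.


    with TinyMeasure("check_host_in_room"):
        time.sleep(0.01)

    print(block_count["check_host_in_room"],
          block_time_seconds["check_host_in_room"])
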
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+ """ Loads a module with its config
+ Takes a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+
+ Returns
+ Tuple of (provider class, parsed config object)
+ """
+ # We need to import the module, and then pick the class out of
+ # that, so we split based on the last dot.
+ module, clz = provider['module'].rsplit(".", 1)
+ module = importlib.import_module(module)
+ provider_class = getattr(module, clz)
+
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
+
+ return provider_class, provider_config
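
The rsplit/import_module/getattr pattern is easy to verify against the standard library; here is a stripped-down version without the config-parsing step (load_class is just an illustrative name):

    import importlib


    def load_class(dotted_path):
        """Same import trick as above: split on the last dot, import the
        module, then pull the class off it."""
        module_name, class_name = dotted_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)


    # e.g. resolve a class from the standard library:
    JSONEncoder = load_class("json.JSONEncoder")
    assert JSONEncoder.__name__ == "JSONEncoder"
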
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
new file mode 100644
index 0000000000..607161e7f0
--- /dev/null
+++ b/synapse/util/msisdn.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import phonenumbers
+from synapse.api.errors import SynapseError
+
+
+def phone_number_to_msisdn(country, number):
+ """
+ Takes an ISO-3166-1 2 letter country code and phone number and
+ returns an msisdn representing the canonical version of that
+ phone number.
+ Args:
+ country (str): ISO-3166-1 2 letter country code
+ number (str): Phone number in a national or international format
+
+ Returns:
+ (str) The canonical form of the phone number, as an msisdn
+ Raises:
+ SynapseError if the number could not be parsed.
+ """
+ try:
+ phoneNumber = phonenumbers.parse(number, country)
+ except phonenumbers.NumberParseException:
+ raise SynapseError(400, "Unable to parse phone number")
+ return phonenumbers.format_number(
+ phoneNumber, phonenumbers.PhoneNumberFormat.E164
+ )[1:]
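
For reference, a small example of what the phonenumbers calls above produce; the number comes from the UK range reserved for fiction, and the trailing [1:] is what strips the leading '+' to give the msisdn form:

    import phonenumbers

    parsed = phonenumbers.parse("020 7946 0018", "GB")
    e164 = phonenumbers.format_number(parsed, phonenumbers.PhoneNumberFormat.E164)

    print(e164)        # +442079460018
    print(e164[1:])    # 442079460018  <- what phone_number_to_msisdn returns
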
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..18424f6c36 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
import collections
import contextlib
@@ -150,7 +150,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 153ef001ad..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import synapse.util.logcontext
from twisted.internet import defer
from synapse.api.errors import CodeMessageException
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
+ """Raised by the limiter (and federation client) to indicate that we are
+ are deliberately not attempting to contact a given server.
+
+ Args:
+ retry_last_ts (int): the unix ts in milliseconds of our last attempt
+ to contact the server. 0 indicates that the last attempt was
+ successful or that we've never actually attempted to connect.
+ retry_interval (int): the time in milliseconds to wait until the next
+ attempt.
+ destination (str): the domain in question
+ """
+
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
@@ -35,7 +47,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+ **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -43,6 +56,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500)
+ Args:
+ destination (str): name of homeserver
+ clock (synapse.util.clock): timing source
+ store (synapse.storage.transactions.TransactionStore): datastore
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway. We will still update the next
+ retry_interval on success/failure.
+
Example usage:
try:
@@ -66,7 +87,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
now = int(clock.time_msec())
- if retry_last_ts + retry_interval > now:
+ if not ignore_backoff and retry_last_ts + retry_interval > now:
raise NotRetryingDestination(
retry_last_ts=retry_last_ts,
retry_interval=retry_interval,
@@ -124,7 +145,13 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False
- if exc_type is not None and issubclass(exc_type, CodeMessageException):
+ if exc_type is None:
+ valid_err_code = True
+ elif not issubclass(exc_type, Exception):
+ # avoid treating exceptions which don't derive from Exception as
+ # failures; this is mostly so as not to catch defer._DefGen.
+ valid_err_code = True
+ elif issubclass(exc_type, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +169,13 @@ class RetryDestinationLimiter(object):
else:
valid_err_code = False
- if exc_type is None or valid_err_code:
+ if valid_err_code:
# We connected successfully.
if not self.retry_interval:
return
+ logger.debug("Connection to %s was successful; clearing backoff",
+ self.destination)
retry_last_ts = 0
self.retry_interval = 0
else:
@@ -160,6 +189,10 @@ class RetryDestinationLimiter(object):
else:
self.retry_interval = self.min_retry_interval
+ logger.debug(
+ "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+ self.destination, exc_type, exc_val, self.retry_interval
+ )
retry_last_ts = int(self.clock.time_msec())
@defer.inlineCallbacks
@@ -168,9 +201,10 @@ class RetryDestinationLimiter(object):
yield self.store.set_destination_retry_timings(
self.destination, retry_last_ts, self.retry_interval
)
- except:
+ except Exception:
logger.exception(
- "Failed to store set_destination_retry_timings",
+ "Failed to store destination_retry_timings",
)
- store_retry_timings()
+ # we deliberately do this in the background.
+ synapse.util.logcontext.run_in_background(store_retry_timings)
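
The retry_interval bookkeeping above follows a familiar exponential-backoff shape: reset to zero on success, otherwise start at a minimum and multiply on each consecutive failure up to a cap. A sketch of that progression with illustrative constants (not necessarily Synapse's configured values):

    MIN_RETRY_INTERVAL_MS = 10 * 60 * 1000         # illustrative values only
    MAX_RETRY_INTERVAL_MS = 24 * 60 * 60 * 1000
    RETRY_MULTIPLIER = 5


    def next_retry_interval(current_ms):
        """0 (i.e. the last attempt succeeded) starts at the minimum;
        otherwise multiply, capped at the maximum."""
        if not current_ms:
            return MIN_RETRY_INTERVAL_MS
        return min(current_ms * RETRY_MULTIPLIER, MAX_RETRY_INTERVAL_MS)


    interval = 0
    for _ in range(4):
        interval = next_retry_interval(interval)
        print(interval)   # 600000, 3000000, 15000000, 75000000 (ms)
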
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index a100f151d4..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
import random
import string
+from six.moves import range
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+ return ''.join(random.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in xrange(length)
+ random.choice(_string_with_symbols) for _ in range(length)
)
@@ -40,3 +41,17 @@ def is_ascii(s):
return False
else:
return True
+
+
+def to_ascii(s):
+ """Converts a string to ascii if it is ascii, otherwise leave it alone.
+
+ If given None then will return None.
+ """
+ if s is None:
+ return None
+
+ try:
+ return s.encode("ascii")
+ except UnicodeEncodeError:
+ return s
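
Behind to_ascii is the fact that str.encode("ascii") either succeeds cleanly or raises UnicodeEncodeError, so pure-ASCII cache keys can be stored in a more compact form while anything else is left untouched. Shown here with Python 3 semantics, where the success case yields bytes:

    # ASCII-only text encodes cleanly...
    assert u"!room:example.com".encode("ascii") == b"!room:example.com"

    # ...while anything else raises, which is why to_ascii() falls back to
    # returning the original string unchanged.
    try:
        u"caf\u00e9".encode("ascii")
    except UnicodeEncodeError:
        pass
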
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+ """Checks whether a given format of 3PID is allowed to be used on this HS
+
+ Args:
+ hs (synapse.server.HomeServer): server
+ medium (str): 3pid medium - e.g. email, msisdn
+ address (str): address within that medium (e.g. "wotan@matrix.org")
+ msisdns need to first have been canonicalised
+ Returns:
+ bool: whether the 3PID medium/address is allowed to be added to this HS
+ """
+
+ if hs.config.allowed_local_3pids:
+ for constraint in hs.config.allowed_local_3pids:
+ logger.debug(
+ "Checking 3PID %s (%s) against %s (%s)",
+ address, medium, constraint['pattern'], constraint['medium'],
+ )
+ if (
+ medium == constraint['medium'] and
+ re.match(constraint['pattern'], address)
+ ):
+ return True
+ else:
+ return True
+
+ return False
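
To make the constraint check concrete, here is the same logic run against a hypothetical allowed_local_3pids list in the shape the homeserver config uses (a medium plus a regex pattern per entry); the domains and msisdn prefix are made up:

    import re

    allowed_local_3pids = [
        {"medium": "email", "pattern": r".*@example\.com$"},
        {"medium": "msisdn", "pattern": r"44\d{10}$"},
    ]


    def is_allowed(medium, address):
        """Same check as check_3pid_allowed, against a literal config list."""
        if not allowed_local_3pids:
            return True
        return any(
            medium == constraint["medium"] and
            re.match(constraint["pattern"], address)
            for constraint in allowed_local_3pids
        )


    assert is_allowed("email", "wotan@example.com")
    assert not is_allowed("email", "wotan@elsewhere.org")
    assert is_allowed("msisdn", "441632960123")
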
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7412fc57a4..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six.moves import range
+
class _Entry(object):
__slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
- _Entry(key) for key in xrange(last_key, then_key + 1)
+ _Entry(key) for key in range(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
@@ -91,7 +93,4 @@ class WheelTimer(object):
return ret
def __len__(self):
- l = 0
- for entry in self.entries:
- l += len(entry.queue)
- return l
+ return sum(len(entry.queue) for entry in self.entries)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 199b16d827..aaca2c584c 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import logging
@@ -43,7 +43,8 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
-def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
+def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
+ always_include_ids=frozenset()):
""" Returns dict of user_id -> list of events that user is allowed to
see.
@@ -54,9 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
* the user has not been a member of the room since the
given events
events ([synapse.events.EventBase]): list of events to filter
+ always_include_ids (set(event_id)): set of event ids to specifically
+ include (unless sender is ignored)
"""
- forgotten = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(store.who_forgot_in_room)(
+ forgotten = yield make_deferred_yieldable(defer.gatherResults([
+ defer.maybeDeferred(
+ preserve_fn(store.who_forgot_in_room),
room_id,
)
for room_id in frozenset(e.room_id for e in events)
@@ -90,6 +94,9 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if not event.is_state() and event.sender in ignore_list:
return False
+ if event.event_id in always_include_ids:
+ return True
+
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
@@ -134,6 +141,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
+ # Always allow the user to see their own leave events, otherwise
+ # they won't see the room disappear if they reject the invite
+ if membership == "leave" and (
+ prev_membership == "join" or prev_membership == "invite"
+ ):
+ return True
+
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:
@@ -181,26 +195,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
@defer.inlineCallbacks
-def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
- user_ids = set(u[0] for u in user_tuples)
- event_id_to_state = {}
- for event_id, context in event_id_to_context.items():
- state = yield store.get_events([
- e_id
- for key, e_id in context.current_state_ids.iteritems()
- if key == (EventTypes.RoomHistoryVisibility, "")
- or (key[0] == EventTypes.Member and key[1] in user_ids)
- ])
- event_id_to_state[event_id] = state
-
- res = yield filter_events_for_clients(
- store, user_tuples, events, event_id_to_state
- )
- defer.returnValue(res)
-
-
-@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False):
+def filter_events_for_client(store, user_id, events, is_peeking=False,
+ always_include_ids=frozenset()):
"""
Check which events a user is allowed to see
@@ -224,6 +220,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False):
types=types
)
res = yield filter_events_for_clients(
- store, [(user_id, is_peeking)], events, event_id_to_state
+ store, [(user_id, is_peeking)], events, event_id_to_state,
+ always_include_ids=always_include_ids,
)
defer.returnValue(res.get(user_id, []))