diff --git a/synapse/__init__.py b/synapse/__init__.py
index bf9e810da6..4f95778eea 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.99.4"
+__version__ = "0.99.5.1"
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6b347b1749..ee129c8689 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
+ Encryption = "m.room.encryption"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 485b3d0237..4085bd10b9 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -77,24 +77,20 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
)
- EVENTID_NOSLASH_TEST = RoomVersion(
- "eventid-noslash-test",
- RoomDisposition.UNSTABLE,
+ V4 = RoomVersion(
+ "4",
+ RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
)
-# the version we will give rooms which are created on this server
-DEFAULT_ROOM_VERSION = RoomVersions.V1
-
-
KNOWN_ROOM_VERSIONS = {
v.identifier: v for v in (
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
- RoomVersions.EVENTID_NOSLASH_TEST,
+ RoomVersions.V4,
)
} # type: dict[str, RoomVersion]
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 864f1eac48..a16e037f32 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -81,6 +82,7 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedTransactionStore,
+ SlavedProfileStore,
SlavedClientIpStore,
BaseSlavedStore,
):
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 727fdc54d8..5c4fc8ff21 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -36,20 +37,41 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
+from .stats import StatsConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig,
- RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
- AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
- JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,
- SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
- ConsentConfig,
- ServerNoticesConfig, RoomDirectoryConfig,
- ):
+class HomeServerConfig(
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+):
pass
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 1309bce3ee..693288f938 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -123,6 +123,14 @@ class RegistrationConfig(Config):
# link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
# from the ``email`` section.
#
+ # Once this feature is enabled, Synapse will look for registered users without an
+ # expiration date at startup and will add one to every account it found using the
+ # current settings at that time.
+ # This means that, if a validity period is set, and Synapse is restarted (it will
+ # then derive an expiration date from the current validity period), and some time
+ # after that the validity period changes and Synapse is restarted, the users'
+ # expiration dates won't be updated unless their account is manually renewed.
+ #
#account_validity:
# enabled: True
# period: 6w
diff --git a/synapse/config/server.py b/synapse/config/server.py
index f34aa42afa..e9120d4d75 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -20,6 +20,7 @@ import os.path
from netaddr import IPSet
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@@ -35,6 +36,8 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_ROOM_VERSION = "1"
+
class ServerConfig(Config):
@@ -88,6 +91,22 @@ class ServerConfig(Config):
"restrict_public_rooms_to_local_users", False,
)
+ default_room_version = config.get(
+ "default_room_version", DEFAULT_ROOM_VERSION,
+ )
+
+ # Ensure room version is a str
+ default_room_version = str(default_room_version)
+
+ if default_room_version not in KNOWN_ROOM_VERSIONS:
+ raise ConfigError(
+ "Unknown default_room_version: %s, known room versions: %s" %
+ (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+ )
+
+ # Get the actual room version object rather than just the identifier
+ self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]
+
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
@@ -310,6 +329,10 @@ class ServerConfig(Config):
unsecure_port = 8008
pid_file = os.path.join(data_dir_path, "homeserver.pid")
+
+ # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
+ # default config string
+ default_room_version = DEFAULT_ROOM_VERSION
return """\
## Server ##
@@ -384,6 +407,15 @@ class ServerConfig(Config):
#
#restrict_public_rooms_to_local_users: true
+ # The default room version for newly created rooms.
+ #
+ # Known room versions are listed here:
+ # https://matrix.org/docs/spec/#complete-list-of-room-versions
+ #
+ # For example, for room version 1, default_room_version should be set
+ # to "1".
+ #default_room_version: "%(default_room_version)s"
+
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
new file mode 100644
index 0000000000..80fc1b9dd0
--- /dev/null
+++ b/synapse/config/stats.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import sys
+
+from ._base import Config
+
+
+class StatsConfig(Config):
+ """Stats Configuration
+ Configuration for the behaviour of synapse's stats engine
+ """
+
+ def read_config(self, config):
+ self.stats_enabled = True
+ self.stats_bucket_size = 86400
+ self.stats_retention = sys.maxsize
+ stats_config = config.get("stats", None)
+ if stats_config:
+ self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
+ self.stats_bucket_size = (
+ self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
+ )
+ self.stats_retention = (
+ self.parse_duration(
+ stats_config.get("retention", "%ds" % (sys.maxsize,))
+ )
+ / 1000
+ )
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Local statistics collection. Used in populating the room directory.
+ #
+ # 'bucket_size' controls how large each statistics timeslice is. It can
+ # be defined in a human readable short form -- e.g. "1d", "1y".
+ #
+ # 'retention' controls how long historical statistics will be kept for.
+ # It can be defined in a human readable short form -- e.g. "1d", "1y".
+ #
+ #
+ #stats:
+ # enabled: true
+ # bucket_size: 1d
+ # retention: 1y
+ """
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d8ba870cca..eaf41b983c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -20,7 +20,6 @@ from collections import namedtuple
from six import raise_from
from six.moves import urllib
-import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -43,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@@ -56,9 +56,9 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
-VerifyKeyRequest = namedtuple("VerifyRequest", (
- "server_name", "key_ids", "json_object", "deferred"
-))
+VerifyKeyRequest = namedtuple(
+ "VerifyRequest", ("server_name", "key_ids", "json_object", "deferred")
+)
"""
A request for a verify key to verify a JSON object.
@@ -80,12 +80,13 @@ class KeyLookupError(ValueError):
class Keyring(object):
def __init__(self, hs):
- self.store = hs.get_datastore()
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
- self.config = hs.get_config()
- self.perspective_servers = self.config.perspectives
- self.hs = hs
+
+ self._key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -96,9 +97,7 @@ class Keyring(object):
def verify_json_for_server(self, server_name, json_object):
return logcontext.make_deferred_yieldable(
- self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ self.verify_json_objects_for_server([(server_name, json_object)])[0]
)
def verify_json_objects_for_server(self, server_and_json):
@@ -130,18 +129,15 @@ class Keyring(object):
if not key_ids:
return defer.fail(
SynapseError(
- 400,
- "Not signed by %s" % (server_name,),
- Codes.UNAUTHORIZED,
+ 400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
)
)
- logger.debug("Verifying for %s with key_ids %s",
- server_name, key_ids)
+ logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
- server_name, key_ids, json_object, defer.Deferred(),
+ server_name, key_ids, json_object, defer.Deferred()
)
verify_requests.append(verify_request)
@@ -179,15 +175,13 @@ class Keyring(object):
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
- rq.server_name: defer.Deferred()
- for rq in verify_requests
+ rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(
- [rq.server_name for rq in verify_requests],
- server_to_deferred,
+ [rq.server_name for rq in verify_requests], server_to_deferred
)
# Actually start fetching keys.
@@ -216,9 +210,7 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- verify_request.deferred.addBoth(
- remove_deferreds, verify_request,
- )
+ verify_request.deferred.addBoth(remove_deferreds, verify_request)
except Exception:
logger.exception("Error starting key lookups")
@@ -248,7 +240,8 @@ class Keyring(object):
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
- [w[0] for w in wait_on], loop_count,
+ [w[0] for w in wait_on],
+ loop_count,
)
with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on))
@@ -279,13 +272,6 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
- # These are functions that produce keys given a list of key ids
- key_fetch_fns = (
- self.get_keys_from_store, # First try the local store
- self.get_keys_from_perspectives, # Then try via perspectives
- self.get_keys_from_server, # Then try directly
- )
-
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
@@ -296,8 +282,8 @@ class Keyring(object):
verify_request.key_ids
)
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
+ for f in self._key_fetchers:
+ results = yield f.get_keys(missing_keys.items())
# We now need to figure out which verify requests we have keys
# for and which we don't
@@ -315,11 +301,15 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
- key = result_keys.get(key_id)
- if key:
+ fetch_key_result = result_keys.get(key_id)
+ if fetch_key_result:
with PreserveLoggingContext():
verify_request.deferred.callback(
- (server_name, key_id, key)
+ (
+ server_name,
+ key_id,
+ fetch_key_result.verify_key,
+ )
)
break
else:
@@ -335,13 +325,14 @@ class Keyring(object):
with PreserveLoggingContext():
for verify_request in requests_missing_keys:
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ verify_request.deferred.errback(
+ SynapseError(
+ 401,
+ "No key for %s with id %s"
+ % (verify_request.server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+ )
def on_err(err):
with PreserveLoggingContext():
@@ -351,17 +342,30 @@ class Keyring(object):
run_in_background(do_iterations).addErrback(on_err)
- @defer.inlineCallbacks
- def get_keys_from_store(self, server_name_and_key_ids):
+
+class KeyFetcher(object):
+ def get_keys(self, server_name_and_key_ids):
"""
Args:
- server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
+ server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
- server_name -> key_id -> VerifyKey
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
"""
+ raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+ """KeyFetcher impl which fetches keys from our data store"""
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
@@ -373,8 +377,127 @@ class Keyring(object):
keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys)
+
+class BaseV2KeyFetcher(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.config = hs.get_config()
+
@defer.inlineCallbacks
- def get_keys_from_perspectives(self, server_name_and_key_ids):
+ def process_v2_response(
+ self, from_server, response_json, time_added_ms, requested_ids=[]
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key request
+
+ This is used to parse either the entirety of the response from
+ GET /_matrix/key/v2/server, or a single entry from the list returned by
+ POST /_matrix/key/v2/query.
+
+ Checks that each signature in the response that claims to come from the origin
+ server is valid. (Does not check that there actually is such a signature, for
+ some reason.)
+
+ Stores the json in server_keys_json so that it can be used for future responses
+ to /_matrix/key/v2/query.
+
+ Args:
+ from_server (str): the name of the server producing this result: either
+ the origin server for a /_matrix/key/v2/server request, or the notary
+ for a /_matrix/key/v2/query.
+
+ response_json (dict): the json-decoded Server Keys response object
+
+ time_added_ms (int): the timestamp to record in server_keys_json
+
+ requested_ids (iterable[str]): a list of the key IDs that were requested.
+ We will store the json for these key ids as well as any that are
+ actually in the response
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+ # start by extracting the keys from the response, since they may be required
+ # to validate the signature on the response.
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+ )
+
+ # TODO: improve this signature checking
+ server_name = response_json["server_name"]
+ for key_id in response_json["signatures"].get(server_name, {}):
+ if key_id not in verify_keys:
+ raise KeyLookupError(
+ "Key response must include verification keys for all signatures"
+ )
+
+ verify_signed_json(
+ response_json, server_name, verify_keys[key_id].verify_key
+ )
+
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+ )
+
+ # re-sign the json with our own key, so that it is ready if we are asked to
+ # give it out as a notary server
+ signed_key_json = sign_json(
+ response_json, self.config.server_name, self.config.signing_key[0]
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+ # for reasons I don't quite understand, we store this json for the key ids we
+ # requested, as well as those we got.
+ updated_key_ids = set(requested_ids)
+ updated_key_ids.update(verify_keys)
+
+ yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.store.store_server_keys_json,
+ server_name=server_name,
+ key_id=key_id,
+ from_server=from_server,
+ ts_now_ms=time_added_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+ def __init__(self, hs):
+ super(PerspectivesKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.perspective_servers = self.config.perspectives
+
+ @defer.inlineCallbacks
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
+
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
@@ -383,25 +506,26 @@ class Keyring(object):
)
defer.returnValue(result)
except KeyLookupError as e:
- logger.warning(
- "Key lookup failed from %r: %s", perspective_name, e,
- )
+ logger.warning("Key lookup failed from %r: %s", perspective_name, e)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e),
+ type(e).__name__,
+ str(e),
)
defer.returnValue({})
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(get_key, p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(get_key, p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
union_of_keys = {}
for result in results:
@@ -411,33 +535,21 @@ class Keyring(object):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
- def get_keys_from_server(self, server_name_and_key_ids):
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct,
- server_name,
- key_ids,
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
- merged = {}
- for result in results:
- merged.update(result)
-
- defer.returnValue({
- server_name: keys
- for server_name, keys in merged.items()
- if keys
- })
+ def get_server_verify_key_v2_indirect(
+ self, server_names_and_key_ids, perspective_name, perspective_keys
+ ):
+ """
+ Args:
+ server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
+ list of (server_name, iterable[key_id]) tuples to fetch keys for
+ perspective_name (str): name of the notary server to query for the keys
+ perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+ notary server
- @defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
- perspective_name,
- perspective_keys):
+ Returns:
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+ from server_name -> key_id -> FetchKeyResult
+ """
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
@@ -448,9 +560,7 @@ class Keyring(object):
data={
u"server_keys": {
server_name: {
- key_id: {
- u"minimum_valid_until_ts": 0
- } for key_id in key_ids
+ key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
@@ -458,21 +568,20 @@ class Keyring(object):
long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ raise_from(KeyLookupError("Remote server returned an error"), e)
keys = {}
+ added_keys = []
- responses = query_response["server_keys"]
+ time_now_ms = self.clock.time_msec()
- for response in responses:
- if (u"signatures" not in response
- or perspective_name not in response[u"signatures"]):
+ for response in query_response["server_keys"]:
+ if (
+ u"signatures" not in response
+ or perspective_name not in response[u"signatures"]
+ ):
raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
@@ -482,9 +591,7 @@ class Keyring(object):
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
- response,
- perspective_name,
- perspective_keys[key_id]
+ response, perspective_name, perspective_keys[key_id]
)
verified = True
@@ -494,7 +601,7 @@ class Keyring(object):
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
- list(perspective_keys)
+ list(perspective_keys),
)
raise KeyLookupError(
"Response not signed with a known key for perspective"
@@ -502,196 +609,101 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
- perspective_name, response
+ perspective_name, response, time_added_ms=time_now_ms
)
server_name = response["server_name"]
+ added_keys.extend(
+ (server_name, key_id, key) for key_id, key in processed_response.items()
+ )
keys.setdefault(server_name, {}).update(processed_response)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store_keys,
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=response_keys,
- )
- for server_name, response_keys in keys.items()
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError))
+ yield self.store.store_server_verify_keys(
+ perspective_name, time_now_ms, added_keys
+ )
defer.returnValue(keys)
+
+class ServerKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the origin servers"""
+
+ def __init__(self, hs):
+ super(ServerKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+
+ @defer.inlineCallbacks
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.get_server_verify_key_v2_direct, server_name, key_ids
+ )
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue(
+ {server_name: keys for server_name, keys in merged.items() if keys}
+ )
+
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
- keys = {} # type: dict[str, nacl.signing.VerifyKey]
+ keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
+ time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
- path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
+ path="/_matrix/key/v2/server/"
+ + urllib.parse.quote(requested_key_id),
ignore_backoff=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ raise_from(KeyLookupError("Remote server returned an error"), e)
- if (u"signatures" not in response
- or server_name not in response[u"signatures"]):
+ if (
+ u"signatures" not in response
+ or server_name not in response[u"signatures"]
+ ):
raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name:
- raise KeyLookupError("Expected a response for server %r not %r" % (
- server_name, response["server_name"]
- ))
+ raise KeyLookupError(
+ "Expected a response for server %r not %r"
+ % (server_name, response["server_name"])
+ )
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
response_json=response,
+ time_added_ms=time_now_ms,
+ )
+ yield self.store.store_server_verify_keys(
+ server_name,
+ time_now_ms,
+ ((server_name, key_id, key) for key_id, key in response_keys.items()),
)
-
keys.update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
- )
defer.returnValue({server_name: keys})
- @defer.inlineCallbacks
- def process_v2_response(
- self, from_server, response_json, requested_ids=[],
- ):
- """Parse a 'Server Keys' structure from the result of a /key request
-
- This is used to parse either the entirety of the response from
- GET /_matrix/key/v2/server, or a single entry from the list returned by
- POST /_matrix/key/v2/query.
-
- Checks that each signature in the response that claims to come from the origin
- server is valid. (Does not check that there actually is such a signature, for
- some reason.)
-
- Stores the json in server_keys_json so that it can be used for future responses
- to /_matrix/key/v2/query.
-
- Args:
- from_server (str): the name of the server producing this result: either
- the origin server for a /_matrix/key/v2/server request, or the notary
- for a /_matrix/key/v2/query.
-
- response_json (dict): the json-decoded Server Keys response object
-
- requested_ids (iterable[str]): a list of the key IDs that were requested.
- We will store the json for these key ids as well as any that are
- actually in the response
-
- Returns:
- Deferred[dict[str, nacl.signing.VerifyKey]]:
- map from key_id to key object
- """
- time_now_ms = self.clock.time_msec()
- response_keys = {}
- verify_keys = {}
- for key_id, key_data in response_json["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.time_added = time_now_ms
- verify_keys[key_id] = verify_key
-
- old_verify_keys = {}
- for key_id, key_data in response_json["old_verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.expired = key_data["expired_ts"]
- verify_key.time_added = time_now_ms
- old_verify_keys[key_id] = verify_key
-
- server_name = response_json["server_name"]
- for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in response_json["verify_keys"]:
- raise KeyLookupError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response_json,
- server_name,
- verify_keys[key_id]
- )
-
- signed_key_json = sign_json(
- response_json,
- self.config.server_name,
- self.config.signing_key[0],
- )
-
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
- ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
-
- updated_key_ids = set(requested_ids)
- updated_key_ids.update(verify_keys)
- updated_key_ids.update(old_verify_keys)
-
- response_keys.update(verify_keys)
- response_keys.update(old_verify_keys)
-
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_keys_json,
- server_name=server_name,
- key_id=key_id,
- from_server=from_server,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
- for key_id in updated_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
- defer.returnValue(response_keys)
-
- def store_keys(self, server_name, from_server, verify_keys):
- """Store a collection of verify keys for a given server
- Args:
- server_name(str): The name of the server the keys are for.
- from_server(str): The server the keys were downloaded from.
- verify_keys(dict): A mapping of key_id to VerifyKey.
- Returns:
- A deferred that completes when the keys are stored.
- """
- # TODO(markjh): Store whether the keys have expired.
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_verify_key,
- server_name, server_name, key.time_added, key
- )
- for key_id, key in verify_keys.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
@@ -713,17 +725,19 @@ def _handle_key_deferred(verify_request):
except KeyLookupError as e:
logger.warn(
"Failed to download keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
+ server_name,
+ type(e).__name__,
+ str(e),
)
raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
+ 502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
+ server_name,
+ type(e).__name__,
+ str(e),
)
raise SynapseError(
401,
@@ -733,22 +747,24 @@ def _handle_key_deferred(verify_request):
json_object = verify_request.json_object
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
+ logger.debug(
+ "Got key %s %s:%s for server %s, verifying"
+ % (key_id, verify_key.alg, verify_key.version, server_name)
+ )
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
- server_name, verify_key.alg, verify_key.version,
+ server_name,
+ verify_key.alg,
+ verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
- "Invalid signature for server %s with key %s:%s: %s" % (
- server_name, verify_key.alg, verify_key.version, str(e),
- ),
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0684778882..cf4fad7de0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1916,6 +1916,11 @@ class FederationHandler(BaseHandler):
event.room_id, latest_event_ids=extrem_ids,
)
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id, current_state_ids,
+ )
+
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
@@ -1932,7 +1937,7 @@ class FederationHandler(BaseHandler):
self.auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
logger.warn(
- "Failed current state auth resolution for %r because %s",
+ "Soft-failing %r because %s",
event, e,
)
event.internal_metadata.soft_failed = True
@@ -2008,15 +2013,44 @@ class FederationHandler(BaseHandler):
Args:
origin (str):
- event (synapse.events.FrozenEvent):
+ event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
- auth_events (dict[(str, str)->str]):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+ What we expect the event's auth_events to be, based on the event's
+ position in the dag. I think? maybe??
+
+ Also NB that this function adds entries to it.
+ Returns:
+ defer.Deferred[None]
+ """
+ room_version = yield self.store.get_room_version(event.room_id)
+
+ yield self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ try:
+ self.auth.check(room_version, event, auth_events=auth_events)
+ except AuthError as e:
+ logger.warn("Failed auth resolution for %r because %s", event, e)
+ raise e
+
+ @defer.inlineCallbacks
+ def _update_auth_events_and_context_for_auth(
+ self, origin, event, context, auth_events
+ ):
+ """Helper for do_auth. See there for docs.
+
+ Args:
+ origin (str):
+ event (synapse.events.EventBase):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
defer.Deferred[None]
"""
- # Check if we have all the auth events.
- current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(event.auth_event_ids())
if event.is_state():
@@ -2024,11 +2058,21 @@ class FederationHandler(BaseHandler):
else:
event_key = None
- if event_auth_events - current_state:
+ # if the event's auth_events refers to events which are not in our
+ # calculated auth_events, we need to fetch those events from somewhere.
+ #
+ # we start by fetching them from the store, and then try calling /event_auth/.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if missing_auth:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
- event_auth_events - current_state
+ missing_auth
)
+ logger.debug("Got events %s from store", have_events)
+ missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@@ -2037,13 +2081,12 @@ class FederationHandler(BaseHandler):
for e in auth_events.values()
})
- seen_events = set(have_events.keys())
-
- missing_auth = event_auth_events - seen_events - current_state
-
if missing_auth:
- logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
+ logger.info(
+ "auth_events contains unknown events: %s",
+ missing_auth,
+ )
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
@@ -2084,145 +2127,168 @@ class FederationHandler(BaseHandler):
have_events = yield self.store.get_seen_events_with_rejections(
event.auth_event_ids()
)
- seen_events = set(have_events.keys())
except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
+ if event.internal_metadata.is_outlier():
+ logger.info("Skipping auth_event fetch for outlier")
+ return
+
# FIXME: Assumes we have and stored all the state for all the
# prev_events
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- room_version = yield self.store.get_room_version(event.room_id)
+ if not different_auth:
+ return
- if different_auth and not event.internal_metadata.is_outlier():
- # Do auth conflict res.
- logger.info("Different auth: %s", different_auth)
-
- different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
- ).addErrback(unwrapFirstError)
-
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event
+ room_version = yield self.store.get_room_version(event.room_id)
+
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
+ d,
+ allow_none=True,
+ allow_rejected=False,
)
+ for d in different_auth
+ if d in have_events and not have_events[d]
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+ if different_events:
+ local_view = dict(auth_events)
+ remote_view = dict(auth_events)
+ remote_view.update({
+ (d.type, d.state_key): d for d in different_events if d
+ })
- auth_events.update(new_state)
+ new_state = yield self.state_handler.resolve_events(
+ room_version,
+ [list(local_view.values()), list(remote_view.values())],
+ event
+ )
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ auth_events.update(new_state)
- if different_auth and not event.internal_metadata.is_outlier():
- logger.info("Different auth after resolution: %s", different_auth)
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- # Only do auth resolution if we have something new to say.
- # We can't rove an auth failure.
- do_resolution = False
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
- provable = [
- RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
- ]
+ if not different_auth:
+ # we're done
+ return
- for e_id in different_auth:
- if e_id in have_events:
- if have_events[e_id] in provable:
- do_resolution = True
- break
+ logger.info(
+ "auth_events still refers to events which are not in the calculated auth "
+ "chain after state resolution: %s",
+ different_auth,
+ )
- if do_resolution:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- # 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
- )
+ # Only do auth resolution if we have something new to say.
+ # We can't prove an auth failure.
+ do_resolution = False
- try:
- # 2. Get remote difference.
- result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
- )
+ for e_id in different_auth:
+ if e_id in have_events:
+ if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
+ do_resolution = True
+ break
- seen_remotes = yield self.store.have_seen_events(
- [e.event_id for e in result["auth_chain"]]
- )
+ if not do_resolution:
+ logger.info(
+ "Skipping auth resolution due to lack of provable rejection reasons"
+ )
+ return
- # 3. Process any remote auth chain events we haven't seen.
- for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes:
- continue
+ logger.info("Doing auth resolution")
- if ev.event_id == event.event_id:
- continue
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
- try:
- auth_ids = ev.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
- }
- ev.internal_metadata.outlier = True
+ # 1. Get what we think is the auth chain.
+ auth_ids = yield self.auth.compute_auth_events(
+ event, prev_state_ids
+ )
+ local_auth_chain = yield self.store.get_auth_chain(
+ auth_ids, include_given=True
+ )
- logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
- )
+ try:
+ # 2. Get remote difference.
+ result = yield self.federation_client.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
+ seen_remotes = yield self.store.have_seen_events(
+ [e.event_id for e in result["auth_chain"]]
+ )
- if ev.event_id in event_auth_events:
- auth_events[(ev.type, ev.state_key)] = ev
- except AuthError:
- pass
+ # 3. Process any remote auth chain events we haven't seen.
+ for ev in result["auth_chain"]:
+ if ev.event_id in seen_remotes:
+ continue
- except Exception:
- # FIXME:
- logger.exception("Failed to query auth chain")
+ if ev.event_id == event.event_id:
+ continue
- # 4. Look at rejects and their proofs.
- # TODO.
+ try:
+ auth_ids = ev.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ or event.type == EventTypes.Create
+ }
+ ev.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s different_auth: %s",
+ event.event_id, e.event_id
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
- try:
- self.auth.check(room_version, event, auth_events=auth_events)
- except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
+
+ except Exception:
+ # FIXME:
+ logger.exception("Failed to query auth chain")
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7e40cb6502..0b02469ceb 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -604,6 +604,20 @@ class EventCreationHandler(object):
self.validator.validate_new(event)
+ # If this event is an annotation then we check that that the sender
+ # can't annotate the same way twice (e.g. stops users from liking an
+ # event multiple times).
+ relation = event.content.get("m.relates_to", {})
+ if relation.get("rel_type") == RelationTypes.ANNOTATION:
+ relates_to = relation["event_id"]
+ aggregation_key = relation["key"]
+
+ already_exists = yield self.store.has_user_annotated_event(
+ relates_to, event.type, aggregation_key, event.sender,
+ )
+ if already_exists:
+ raise SynapseError(400, "Can't send same reaction twice")
+
logger.debug(
"Created event %s",
event.event_id,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e37ae96899..4a17911a87 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -70,6 +70,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self.config = hs.config
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@@ -475,7 +476,11 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
- room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
+ room_version = config.get(
+ "room_version",
+ self.config.default_room_version.identifier,
+ )
+
if not isinstance(room_version, string_types):
raise SynapseError(
400,
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
new file mode 100644
index 0000000000..0e92b405ba
--- /dev/null
+++ b/synapse/handlers/stats.py
@@ -0,0 +1,325 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.handlers.state_deltas import StateDeltasHandler
+from synapse.metrics import event_processing_positions
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class StatsHandler(StateDeltasHandler):
+ """Handles keeping the *_stats tables updated with a simple time-series of
+ information about the users, rooms and media on the server, such that admins
+ have some idea of who is consuming their resources.
+
+ Heavily derived from UserDirectoryHandler
+ """
+
+ def __init__(self, hs):
+ super(StatsHandler, self).__init__(hs)
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ # The current position in the current_state_delta stream
+ self.pos = None
+
+ # Guard to ensure we only process deltas one at a time
+ self._is_processing = False
+
+ if hs.config.stats_enabled:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # We kick this off so that we don't have to wait for a change before
+ # we start populating stats
+ self.clock.call_later(0, self.notify_new_event)
+
+ def notify_new_event(self):
+ """Called when there may be more deltas to process
+ """
+ if not self.hs.config.stats_enabled:
+ return
+
+ if self._is_processing:
+ return
+
+ @defer.inlineCallbacks
+ def process():
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
+ self._is_processing = True
+ run_as_background_process("stats.notify_new_event", process)
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # If self.pos is None then means we haven't fetched it from DB
+ if self.pos is None:
+ self.pos = yield self.store.get_stats_stream_pos()
+
+ # If still None then the initial background update hasn't happened yet
+ if self.pos is None:
+ defer.returnValue(None)
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "stats_delta"):
+ deltas = yield self.store.get_current_state_deltas(self.pos)
+ if not deltas:
+ return
+
+ logger.info("Handling %d state deltas", len(deltas))
+ yield self._handle_deltas(deltas)
+
+ self.pos = deltas[-1]["stream_id"]
+ yield self.store.update_stats_stream_pos(self.pos)
+
+ event_processing_positions.labels("stats").set(self.pos)
+
+ @defer.inlineCallbacks
+ def _handle_deltas(self, deltas):
+ """
+ Called with the state deltas to process
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ stream_id = delta["stream_id"]
+ prev_event_id = delta["prev_event_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ token = yield self.store.get_earliest_token_for_room_stats(room_id)
+
+ # If the earliest token to begin from is larger than our current
+ # stream ID, skip processing this delta.
+ if token is not None and token >= stream_id:
+ logger.debug(
+ "Ignoring: %s as earlier than this room's initial ingestion event",
+ event_id,
+ )
+ continue
+
+ if event_id is None and prev_event_id is None:
+ # Errr...
+ continue
+
+ event_content = {}
+
+ if event_id is not None:
+ event_content = (yield self.store.get_event(event_id)).content or {}
+
+ # quantise time to the nearest bucket
+ now = yield self.store.get_received_ts(event_id)
+ now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
+
+ if typ == EventTypes.Member:
+ # we could use _get_key_change here but it's a bit inefficient
+ # given we're not testing for a specific result; might as well
+ # just grab the prev_membership and membership strings and
+ # compare them.
+ prev_event_content = {}
+ if prev_event_id is not None:
+ prev_event_content = (
+ yield self.store.get_event(prev_event_id)
+ ).content
+
+ membership = event_content.get("membership", Membership.LEAVE)
+ prev_membership = prev_event_content.get("membership", Membership.LEAVE)
+
+ if prev_membership == membership:
+ continue
+
+ if prev_membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", -1
+ )
+ elif prev_membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", -1
+ )
+ elif prev_membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", -1
+ )
+ elif prev_membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", -1
+ )
+ else:
+ err = "%s is not a valid prev_membership" % (repr(prev_membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ if membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", +1
+ )
+ elif membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", +1
+ )
+ elif membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", +1
+ )
+ elif membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", +1
+ )
+ else:
+ err = "%s is not a valid membership" % (repr(membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ user_id = state_key
+ if self.is_mine_id(user_id):
+ # update user_stats as it's one of our users
+ public = yield self._is_public_room(room_id)
+
+ if membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ -1,
+ )
+ elif membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ +1,
+ )
+
+ elif typ == EventTypes.Create:
+ # Newly created room. Add it with all blank portions.
+ yield self.store.update_room_state(
+ room_id,
+ {
+ "join_rules": None,
+ "history_visibility": None,
+ "encryption": None,
+ "name": None,
+ "topic": None,
+ "avatar": None,
+ "canonical_alias": None,
+ },
+ )
+
+ elif typ == EventTypes.JoinRules:
+ yield self.store.update_room_state(
+ room_id, {"join_rules": event_content.get("join_rule")}
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "join_rule", JoinRules.PUBLIC
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.RoomHistoryVisibility:
+ yield self.store.update_room_state(
+ room_id,
+ {"history_visibility": event_content.get("history_visibility")},
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "history_visibility", "world_readable"
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.Encryption:
+ yield self.store.update_room_state(
+ room_id, {"encryption": event_content.get("algorithm")}
+ )
+ elif typ == EventTypes.Name:
+ yield self.store.update_room_state(
+ room_id, {"name": event_content.get("name")}
+ )
+ elif typ == EventTypes.Topic:
+ yield self.store.update_room_state(
+ room_id, {"topic": event_content.get("topic")}
+ )
+ elif typ == EventTypes.RoomAvatar:
+ yield self.store.update_room_state(
+ room_id, {"avatar": event_content.get("url")}
+ )
+ elif typ == EventTypes.CanonicalAlias:
+ yield self.store.update_room_state(
+ room_id, {"canonical_alias": event_content.get("alias")}
+ )
+
+ @defer.inlineCallbacks
+ def update_public_room_stats(self, ts, room_id, is_public):
+ """
+ Increment/decrement a user's number of public rooms when a room they are
+ in changes to/from public visibility.
+
+ Args:
+ ts (int): Timestamp in seconds
+ room_id (str)
+ is_public (bool)
+ """
+ # For now, blindly iterate over all local users in the room so that
+ # we can handle the whole problem of copying buckets over as needed
+ user_ids = yield self.store.get_users_in_room(room_id)
+
+ for user_id in user_ids:
+ if self.hs.is_mine(UserID.from_string(user_id)):
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "public_rooms", +1 if is_public else -1
+ )
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "private_rooms", -1 if is_public else +1
+ )
+
+ @defer.inlineCallbacks
+ def _is_public_room(self, room_id):
+ join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules)
+ history_visibility = yield self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility
+ )
+
+ if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or (
+ (
+ history_visibility
+ and history_visibility.content.get("history_visibility")
+ == "world_readable"
+ )
+ ):
+ defer.returnValue(True)
+ else:
+ defer.returnValue(False)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index fdcfb90a7e..f64baa4d58 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -16,7 +16,12 @@
import logging
-from pkg_resources import DistributionNotFound, VersionConflict, get_distribution
+from pkg_resources import (
+ DistributionNotFound,
+ Requirement,
+ VersionConflict,
+ get_provider,
+)
logger = logging.getLogger(__name__)
@@ -69,14 +74,6 @@ REQUIREMENTS = [
"attrs>=17.4.0",
"netaddr>=0.7.18",
-
- # requests is a transitive dep of treq, and urlib3 is a transitive dep
- # of requests, as well as of sentry-sdk.
- #
- # As of requests 2.21, requests does not yet support urllib3 1.25.
- # (If we do not pin it here, pip will give us the latest urllib3
- # due to the dep via sentry-sdk.)
- "urllib3<1.25",
]
CONDITIONAL_REQUIREMENTS = {
@@ -91,7 +88,13 @@ CONDITIONAL_REQUIREMENTS = {
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
- "acme": ["txacme>=0.9.2"],
+ "acme": [
+ "txacme>=0.9.2",
+
+ # txacme depends on eliot. Eliot 1.8.0 is incompatible with
+ # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
+ 'eliot<1.8.0;python_version<"3.5.3"',
+ ],
"saml2": ["pysaml2>=4.5.0"],
"systemd": ["systemd-python>=231"],
@@ -125,10 +128,10 @@ class DependencyException(Exception):
@property
def dependencies(self):
for i in self.args[0]:
- yield '"' + i + '"'
+ yield "'" + i + "'"
-def check_requirements(for_feature=None, _get_distribution=get_distribution):
+def check_requirements(for_feature=None):
deps_needed = []
errors = []
@@ -139,7 +142,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in reqs:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -157,7 +160,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in OPTS:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -175,6 +178,23 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
raise DependencyException(deps_needed)
+def _check_requirement(dependency_string):
+ """Parses a dependency string, and checks if the specified requirement is installed
+
+ Raises:
+ VersionConflict if the requirement is installed, but with the the wrong version
+ DistributionNotFound if nothing is found to provide the requirement
+ """
+ req = Requirement.parse(dependency_string)
+
+ # first check if the markers specify that this requirement needs installing
+ if req.marker is not None and not req.marker.evaluate():
+ # not required for this environment
+ return
+
+ get_provider(req)
+
+
if __name__ == "__main__":
import sys
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index a868d06098..2b4892330c 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
@@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet):
"""
super(CapabilitiesRestServlet, self).__init__()
self.hs = hs
+ self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
- "default": DEFAULT_ROOM_VERSION.identifier,
+ "default": self.config.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index eb8782aa6e..21c3c807b9 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.keyring = hs.get_keyring()
+ self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -217,7 +217,7 @@ class RemoteKey(Resource):
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
- yield self.keyring.get_server_verify_key_v2_direct(
+ yield self.fetcher.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
diff --git a/synapse/server.py b/synapse/server.py
index 80d40b9272..9229a68a8d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -72,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.user_directory import UserDirectoryHandler
@@ -139,6 +140,7 @@ class HomeServer(object):
'acme_handler',
'auth_handler',
'device_handler',
+ 'stats_handler',
'e2e_keys_handler',
'e2e_room_keys_handler',
'event_handler',
@@ -191,6 +193,7 @@ class HomeServer(object):
REQUIRED_ON_MASTER_STARTUP = [
"user_directory_handler",
+ "stats_handler"
]
# This is overridden in derived application classes
@@ -474,6 +477,9 @@ class HomeServer(object):
def build_secrets(self):
return Secrets()
+ def build_stats_handler(self):
+ return StatsHandler(self)
+
def build_spam_checker(self):
return SpamChecker(self)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7522d3fd57..66675d08ae 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -55,6 +55,7 @@ from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
+from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
@@ -100,6 +101,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ StatsStore,
RelationsStore,
):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..fa6839ceca 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -227,6 +229,8 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+ self._account_validity = self.hs.config.account_validity
+
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -243,6 +247,14 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -275,6 +287,52 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ @defer.inlineCallbacks
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL;"
+ )
+ txn.execute(sql, [])
+
+ res = self.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(txn, user["name"])
+
+ yield self.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ },
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 881d6d0126..2ffc27ff41 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -575,10 +575,11 @@ class EventsStore(
def _get_events(txn, batch):
sql = """
- SELECT prev_event_id
+ SELECT prev_event_id, internal_metadata
FROM event_edges
INNER JOIN events USING (event_id)
LEFT JOIN rejections USING (event_id)
+ LEFT JOIN event_json USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
@@ -588,7 +589,11 @@ class EventsStore(
)
txn.execute(sql, batch)
- results.extend(r[0] for r in txn)
+ results.extend(
+ r[0]
+ for r in txn
+ if not json.loads(r[1]).get("soft_failed")
+ )
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index adc6cf26b5..21b353cad3 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -610,4 +610,28 @@ class EventsWorkerStore(SQLBaseStore):
return res
- return self.runInteraction("get_rejection_reasons", f)
+ return self.runInteraction("get_seen_events_with_rejections", f)
+
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn, room_id
+ )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 7036541792..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
import six
+import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
db_binary_type = memoryview
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+ verify_key = attr.ib() # VerifyKey: the key itself
+ valid_until_ts = attr.ib() # int: how long we can use this key for
+
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
- map from (server_name, key_id) -> VerifyKey, or None if the key is
+ Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+ map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
- "SELECT server_name, key_id, verify_key FROM server_signature_keys "
- "WHERE 1=0"
+ "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+ "FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
- server_name, key_id, key_bytes = row
- keys[(server_name, key_id)] = decode_verify_key_bytes(
- key_id, bytes(key_bytes)
+ server_name, key_id, key_bytes, ts_valid_until_ms = row
+ res = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+ valid_until_ts=ts_valid_until_ms,
)
+ keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore):
return self.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_key(
- self, server_name, from_server, time_now_ms, verify_key
- ):
- """Stores a NACL verification key for the given server.
+ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ """Stores NACL verification keys for remote servers.
Args:
- server_name (str): The name of the server.
- from_server (str): Where the verification key was looked up
- time_now_ms (int): The time now in milliseconds
- verify_key (nacl.signing.VerifyKey): The NACL verify key.
+ from_server (str): Where the verification keys were looked up
+ ts_added_ms (int): The time to record that the key was added
+ verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ keys to be stored. Each entry is a triplet of
+ (server_name, key_id, key).
"""
- key_id = "%s:%s" % (verify_key.alg, verify_key.version)
-
- # XXX fix this to not need a lock (#3819)
- def _txn(txn):
- self._simple_upsert_txn(
- txn,
- table="server_signature_keys",
- keyvalues={"server_name": server_name, "key_id": key_id},
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": db_binary_type(verify_key.encode()),
- },
+ key_values = []
+ value_values = []
+ invalidations = []
+ for server_name, key_id, fetch_result in verify_keys:
+ key_values.append((server_name, key_id))
+ value_values.append(
+ (
+ from_server,
+ ts_added_ms,
+ fetch_result.valid_until_ts,
+ db_binary_type(fetch_result.verify_key.encode()),
+ )
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- txn.call_after(
- self._get_server_verify_key.invalidate, ((server_name, key_id),)
- )
-
- return self.runInteraction("store_server_verify_key", _txn)
+ invalidations.append((server_name, key_id))
+
+ def _invalidate(res):
+ f = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ f((i, ))
+ return res
+
+ return self.runInteraction(
+ "store_server_verify_keys",
+ self._simple_upsert_many_txn,
+ table="server_signature_keys",
+ key_names=("server_name", "key_id"),
+ key_values=key_values,
+ value_names=(
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "verify_key",
+ ),
+ value_values=value_values,
+ ).addCallback(_invalidate)
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a06a83d6..4cf159ba81 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -725,17 +727,7 @@ class RegistrationStore(
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- }
- )
+ self.set_expiration_date_for_user_txn(txn, user_id)
if token:
# it's possible for this to get a conflict, but only for a single user
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 493abe405e..4c83800cca 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -280,7 +280,7 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = ""
sql = """
- SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+ SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
@@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
def _get_applicable_edit_txn(txn):
- txn.execute(
- sql, (event_id, RelationTypes.REPLACE,)
- )
+ txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
@@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore):
edit_event = yield self.get_event(edit_id, allow_none=True)
defer.returnValue(edit_event)
+ def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ """Check if a user has already annotated an event with the same key
+ (e.g. already liked an event).
+
+ Args:
+ parent_id (str): The event being annotated
+ event_type (str): The event type of the annotation
+ aggregation_key (str): The aggregation key of the annotation
+ sender (str): The sender of the annotation
+
+ Returns:
+ Deferred[bool]
+ """
+
+ sql = """
+ SELECT 1 FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND type = ?
+ AND sender = ?
+ AND aggregation_key = ?
+ LIMIT 1;
+ """
+
+ def _get_if_user_has_annotated_event(txn):
+ txn.execute(
+ sql,
+ (
+ parent_id,
+ RelationTypes.ANNOTATION,
+ event_type,
+ sender,
+ aggregation_key,
+ ),
+ )
+
+ return bool(txn.fetchone())
+
+ return self.runInteraction(
+ "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+ )
+
class RelationsStore(RelationsWorkerStore):
def _handle_event_relations(self, txn, event):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 57df17bcc2..4bd1669458 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -142,6 +142,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction("get_room_summary", _get_room_summary_txn)
+ def _get_user_count_in_room_txn(self, txn, room_id, membership):
+ """
+ See get_user_count_in_room.
+ """
+ sql = (
+ "SELECT count(*) FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
+ " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
+ )
+
+ txn.execute(sql, (room_id, membership))
+ row = txn.fetchone()
+ return row[0]
+
+ def get_user_count_in_room(self, room_id, membership):
+ """
+ Get the user count in a room with a particular membership.
+
+ Args:
+ room_id (str)
+ membership (Membership)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_users_in_room", self._get_user_count_in_room_txn, room_id, membership
+ )
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
new file mode 100644
index 0000000000..c01aa9d2d9
--- /dev/null
+++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* When we can use this key until, before we have to refresh it. */
+ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
+
+UPDATE server_signature_keys SET ts_valid_until_ms = (
+ SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
+ skj.server_name = server_signature_keys.server_name AND
+ skj.key_id = server_signature_keys.key_id
+);
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql
new file mode 100644
index 0000000000..652e58308e
--- /dev/null
+++ b/synapse/storage/schema/delta/54/stats.sql
@@ -0,0 +1,80 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stats_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO stats_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_stats (
+ user_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ public_rooms INT NOT NULL,
+ private_rooms INT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts);
+
+CREATE TABLE room_stats (
+ room_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ current_state_events INT NOT NULL,
+ joined_members INT NOT NULL,
+ invited_members INT NOT NULL,
+ left_members INT NOT NULL,
+ banned_members INT NOT NULL,
+ state_events INT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts);
+
+-- cache of current room state; useful for the publicRooms list
+CREATE TABLE room_state (
+ room_id TEXT NOT NULL,
+ join_rules TEXT,
+ history_visibility TEXT,
+ encryption TEXT,
+ name TEXT,
+ topic TEXT,
+ avatar TEXT,
+ canonical_alias TEXT
+ -- get aliases straight from the right table
+);
+
+CREATE UNIQUE INDEX room_state_room ON room_state(room_id);
+
+CREATE TABLE room_stats_earliest_token (
+ room_id TEXT NOT NULL,
+ token BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id);
+
+-- Set up staging tables
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_stats_createtables', '{}');
+
+-- Run through each room and update stats
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_rooms', '{}', 'populate_stats_createtables');
+
+-- Clean up staging tables
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms');
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 31a0279b18..5fdb442104 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self._simple_select_one_onecol(
+ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
+ return self._simple_select_one_onecol_txn(
+ txn,
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
- desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self.runInteraction(
+ "get_max_stream_id_in_current_state_deltas",
+ self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
new file mode 100644
index 0000000000..71b80a891d
--- /dev/null
+++ b/synapse/storage/stats.py
@@ -0,0 +1,450 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state_deltas import StateDeltasStore
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
+
+# these fields track absolutes (e.g. total number of rooms on the server)
+ABSOLUTE_STATS_FIELDS = {
+ "room": (
+ "current_state_events",
+ "joined_members",
+ "invited_members",
+ "left_members",
+ "banned_members",
+ "state_events",
+ ),
+ "user": ("public_rooms", "private_rooms"),
+}
+
+TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
+
+TEMP_TABLE = "_temp_populate_stats"
+
+
+class StatsStore(StateDeltasStore):
+ def __init__(self, db_conn, hs):
+ super(StatsStore, self).__init__(db_conn, hs)
+
+ self.server_name = hs.hostname
+ self.clock = self.hs.get_clock()
+ self.stats_enabled = hs.config.stats_enabled
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ self.register_background_update_handler(
+ "populate_stats_createtables", self._populate_stats_createtables
+ )
+ self.register_background_update_handler(
+ "populate_stats_process_rooms", self._populate_stats_process_rooms
+ )
+ self.register_background_update_handler(
+ "populate_stats_cleanup", self._populate_stats_cleanup
+ )
+
+ @defer.inlineCallbacks
+ def _populate_stats_createtables(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ # Get all the rooms that we want to process.
+ def _make_staging_area(txn):
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_position(position TEXT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ # Get rooms we want to process from the database
+ sql = """
+ SELECT room_id, count(*) FROM current_state_events
+ GROUP BY room_id
+ """
+ txn.execute(sql)
+ rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
+ self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ del rooms
+
+ new_pos = yield self.get_max_stream_id_in_current_state_deltas()
+ yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
+ yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ self.get_earliest_token_for_room_stats.invalidate_all()
+
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_cleanup(self, progress, batch_size):
+ """
+ Update the user directory stream position, then clean up the old tables.
+ """
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ position = yield self._simple_select_one_onecol(
+ TEMP_TABLE + "_position", None, "position"
+ )
+ yield self.update_stats_stream_pos(position)
+
+ def _delete_staging_area(txn):
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
+
+ yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
+
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_rooms(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+ # If we don't have progress filed, delete everything.
+ if not progress:
+ yield self.delete_all_stats()
+
+ def _get_next_batch(txn):
+ # Only fetch 250 rooms, so we don't fetch too many at once, even
+ # if those 250 rooms have less than batch_size state events.
+ sql = """
+ SELECT room_id, events FROM %s_rooms
+ ORDER BY events DESC
+ LIMIT 250
+ """ % (
+ TEMP_TABLE,
+ )
+ txn.execute(sql)
+ rooms_to_work_on = txn.fetchall()
+
+ if not rooms_to_work_on:
+ return None
+
+ # Get how many are left to process, so we can give status on how
+ # far we are in processing
+ txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
+ progress["remaining"] = txn.fetchone()[0]
+
+ return rooms_to_work_on
+
+ rooms_to_work_on = yield self.runInteraction(
+ "populate_stats_temp_read", _get_next_batch
+ )
+
+ # No more rooms -- complete the transaction.
+ if not rooms_to_work_on:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+ logger.info(
+ "Processing the next %d rooms of %d remaining",
+ (len(rooms_to_work_on), progress["remaining"]),
+ )
+
+ # Number of state events we've processed by going through each room
+ processed_event_count = 0
+
+ for room_id, event_count in rooms_to_work_on:
+
+ current_state_ids = yield self.get_current_state_ids(room_id)
+
+ join_rules = yield self.get_event(
+ current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True
+ )
+ history_visibility = yield self.get_event(
+ current_state_ids.get((EventTypes.RoomHistoryVisibility, "")),
+ allow_none=True,
+ )
+ encryption = yield self.get_event(
+ current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True
+ )
+ name = yield self.get_event(
+ current_state_ids.get((EventTypes.Name, "")), allow_none=True
+ )
+ topic = yield self.get_event(
+ current_state_ids.get((EventTypes.Topic, "")), allow_none=True
+ )
+ avatar = yield self.get_event(
+ current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True
+ )
+ canonical_alias = yield self.get_event(
+ current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True
+ )
+
+ def _or_none(x, arg):
+ if x:
+ return x.content.get(arg)
+ return None
+
+ yield self.update_room_state(
+ room_id,
+ {
+ "join_rules": _or_none(join_rules, "join_rule"),
+ "history_visibility": _or_none(
+ history_visibility, "history_visibility"
+ ),
+ "encryption": _or_none(encryption, "algorithm"),
+ "name": _or_none(name, "name"),
+ "topic": _or_none(topic, "topic"),
+ "avatar": _or_none(avatar, "url"),
+ "canonical_alias": _or_none(canonical_alias, "alias"),
+ },
+ )
+
+ now = self.hs.get_reactor().seconds()
+
+ # quantise time to the nearest bucket
+ now = (now // self.stats_bucket_size) * self.stats_bucket_size
+
+ def _fetch_data(txn):
+
+ # Get the current token of the room
+ current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+
+ current_state_events = len(current_state_ids)
+ joined_members = self._get_user_count_in_room_txn(
+ txn, room_id, Membership.JOIN
+ )
+ invited_members = self._get_user_count_in_room_txn(
+ txn, room_id, Membership.INVITE
+ )
+ left_members = self._get_user_count_in_room_txn(
+ txn, room_id, Membership.LEAVE
+ )
+ banned_members = self._get_user_count_in_room_txn(
+ txn, room_id, Membership.BAN
+ )
+ total_state_events = self._get_total_state_event_counts_txn(
+ txn, room_id
+ )
+
+ self._update_stats_txn(
+ txn,
+ "room",
+ room_id,
+ now,
+ {
+ "bucket_size": self.stats_bucket_size,
+ "current_state_events": current_state_events,
+ "joined_members": joined_members,
+ "invited_members": invited_members,
+ "left_members": left_members,
+ "banned_members": banned_members,
+ "state_events": total_state_events,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ "room_stats_earliest_token",
+ {"room_id": room_id, "token": current_token},
+ )
+
+ yield self.runInteraction("update_room_stats", _fetch_data)
+
+ # We've finished a room. Delete it from the table.
+ yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ # Update the remaining counter.
+ progress["remaining"] -= 1
+ yield self.runInteraction(
+ "populate_stats",
+ self._background_update_progress_txn,
+ "populate_stats_process_rooms",
+ progress,
+ )
+
+ processed_event_count += event_count
+
+ if processed_event_count > batch_size:
+ # Don't process any more rooms, we've hit our batch size.
+ defer.returnValue(processed_event_count)
+
+ defer.returnValue(processed_event_count)
+
+ def delete_all_stats(self):
+ """
+ Delete all statistics records.
+ """
+
+ def _delete_all_stats_txn(txn):
+ txn.execute("DELETE FROM room_state")
+ txn.execute("DELETE FROM room_stats")
+ txn.execute("DELETE FROM room_stats_earliest_token")
+ txn.execute("DELETE FROM user_stats")
+
+ return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
+
+ def get_stats_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="stats_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="stats_stream_pos",
+ )
+
+ def update_stats_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="stats_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_stats_stream_pos",
+ )
+
+ def update_room_state(self, room_id, fields):
+ """
+ Args:
+ room_id (str)
+ fields (dict[str:Any])
+ """
+ return self._simple_upsert(
+ table="room_state",
+ keyvalues={"room_id": room_id},
+ values=fields,
+ desc="update_room_state",
+ )
+
+ def get_deltas_for_room(self, room_id, start, size=100):
+ """
+ Get statistics deltas for a given room.
+
+ Args:
+ room_id (str)
+ start (int): Pagination start. Number of entries, not timestamp.
+ size (int): How many entries to return.
+
+ Returns:
+ Deferred[list[dict]], where the dict has the keys of
+ ABSOLUTE_STATS_FIELDS["room"] and "ts".
+ """
+ return self._simple_select_list_paginate(
+ "room_stats",
+ {"room_id": room_id},
+ "ts",
+ start,
+ size,
+ retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
+ order_direction="DESC",
+ )
+
+ def get_all_room_state(self):
+ return self._simple_select_list(
+ "room_state", None, retcols=("name", "topic", "canonical_alias")
+ )
+
+ @cached()
+ def get_earliest_token_for_room_stats(self, room_id):
+ """
+ Fetch the "earliest token". This is used by the room stats delta
+ processor to ignore deltas that have been processed between the
+ start of the background task and any particular room's stats
+ being calculated.
+
+ Returns:
+ Deferred[int]
+ """
+ return self._simple_select_one_onecol(
+ "room_stats_earliest_token",
+ {"room_id": room_id},
+ retcol="token",
+ allow_none=True,
+ )
+
+ def update_stats(self, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert(
+ table=table,
+ keyvalues={id_col: stats_id, "ts": ts},
+ values=fields,
+ desc="update_stats",
+ )
+
+ def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert_txn(
+ txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
+ )
+
+ def update_stats_delta(self, ts, stats_type, stats_id, field, value):
+ def _update_stats_delta(txn):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+
+ sql = (
+ "SELECT * FROM %s"
+ " WHERE %s=? and ts=("
+ " SELECT MAX(ts) FROM %s"
+ " WHERE %s=?"
+ ")"
+ ) % (table, id_col, table, id_col)
+ txn.execute(sql, (stats_id, stats_id))
+ rows = self.cursor_to_dict(txn)
+ if len(rows) == 0:
+ # silently skip as we don't have anything to apply a delta to yet.
+ # this tries to minimise any race between the initial sync and
+ # subsequent deltas arriving.
+ return
+
+ current_ts = ts
+ latest_ts = rows[0]["ts"]
+ if current_ts < latest_ts:
+ # This one is in the past, but we're just encountering it now.
+ # Mark it as part of the current bucket.
+ current_ts = latest_ts
+ elif ts != latest_ts:
+ # we have to copy our absolute counters over to the new entry.
+ values = {
+ key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
+ }
+ values[id_col] = stats_id
+ values["ts"] = ts
+ values["bucket_size"] = self.stats_bucket_size
+
+ self._simple_insert_txn(txn, table=table, values=values)
+
+ # actually update the new value
+ if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]:
+ self._simple_update_txn(
+ txn,
+ table=table,
+ keyvalues={id_col: stats_id, "ts": current_ts},
+ updatevalues={field: value},
+ )
+ else:
+ sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
+ table,
+ field,
+ field,
+ id_col,
+ )
+ txn.execute(sql, (value, stats_id, current_ts))
+
+ return self.runInteraction("update_stats_delta", _update_stats_delta)
|