diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f47c33a074..dd373fa4b8 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -69,6 +69,7 @@ class EventTypes(object):
Redaction = "m.room.redaction"
ThirdPartyInvite = "m.room.third_party_invite"
Encryption = "m.room.encryption"
+ RelatedGroups = "m.room.related_groups"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
@@ -102,46 +103,6 @@ class ThirdPartyEntityKind(object):
LOCATION = "location"
-class RoomVersions(object):
- V1 = "1"
- V2 = "2"
- V3 = "3"
- STATE_V2_TEST = "state-v2-test"
-
-
-class RoomDisposition(object):
- STABLE = "stable"
- UNSTABLE = "unstable"
-
-
-# the version we will give rooms which are created on this server
-DEFAULT_ROOM_VERSION = RoomVersions.V1
-
-# vdh-test-version is a placeholder to get room versioning support working and tested
-# until we have a working v2.
-KNOWN_ROOM_VERSIONS = {
- RoomVersions.V1,
- RoomVersions.V2,
- RoomVersions.V3,
- RoomVersions.STATE_V2_TEST,
- RoomVersions.V3,
-}
-
-
-class EventFormatVersions(object):
- """This is an internal enum for tracking the version of the event format,
- independently from the room version.
- """
- V1 = 1
- V2 = 2
-
-
-KNOWN_EVENT_FORMAT_VERSIONS = {
- EventFormatVersions.V1,
- EventFormatVersions.V2,
-}
-
-
ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
new file mode 100644
index 0000000000..e77abe1040
--- /dev/null
+++ b/synapse/api/room_versions.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import attr
+
+
+class EventFormatVersions(object):
+ """This is an internal enum for tracking the version of the event format,
+ independently from the room version.
+ """
+ V1 = 1 # $id:server format
+ V2 = 2 # MSC1659-style $hash format: introduced for room v3
+
+
+KNOWN_EVENT_FORMAT_VERSIONS = {
+ EventFormatVersions.V1,
+ EventFormatVersions.V2,
+}
+
+
+class StateResolutionVersions(object):
+ """Enum to identify the state resolution algorithms"""
+ V1 = 1 # room v1 state res
+ V2 = 2 # MSC1442 state res: room v2 and later
+
+
+class RoomDisposition(object):
+ STABLE = "stable"
+ UNSTABLE = "unstable"
+
+
+@attr.s(slots=True, frozen=True)
+class RoomVersion(object):
+ """An object which describes the unique attributes of a room version."""
+
+ identifier = attr.ib() # str; the identifier for this version
+ disposition = attr.ib() # str; one of the RoomDispositions
+ event_format = attr.ib() # int; one of the EventFormatVersions
+ state_res = attr.ib() # int; one of the StateResolutionVersions
+
+
+class RoomVersions(object):
+ V1 = RoomVersion(
+ "1",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V1,
+ StateResolutionVersions.V1,
+ )
+ STATE_V2_TEST = RoomVersion(
+ "state-v2-test",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V1,
+ StateResolutionVersions.V2,
+ )
+ V2 = RoomVersion(
+ "2",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V1,
+ StateResolutionVersions.V2,
+ )
+ V3 = RoomVersion(
+ "3",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V2,
+ StateResolutionVersions.V2,
+ )
+
+
+# the version we will give rooms which are created on this server
+DEFAULT_ROOM_VERSION = RoomVersions.V1
+
+
+KNOWN_ROOM_VERSIONS = {
+ v.identifier: v for v in (
+ RoomVersions.V1,
+ RoomVersions.V2,
+ RoomVersions.V3,
+ RoomVersions.STATE_V2_TEST,
+ )
+} # type: dict[str, RoomVersion]
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 9711a7147c..1d43f2b075 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -38,7 +38,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams import ReceiptsStream
+from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 869c028d1f..79be977ea6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -518,6 +518,7 @@ def run(hs):
uptime = 0
stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
version = sys.version_info
@@ -558,7 +559,6 @@ def run(hs):
stats["database_engine"] = hs.get_datastore().database_engine_name
stats["database_server_version"] = hs.get_datastore().get_server_version()
-
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 9163b56d86..5388def28a 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -48,6 +48,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.streams.events import EventsStreamEventRow
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
@@ -369,7 +370,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
- event = yield self.store.get_event(row.event_id)
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ event = yield self.store.get_event(row.data.event_id)
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index d1ab9512cd..355f5aa71d 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -36,6 +36,10 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamCurrentStateRow,
+)
from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -73,19 +77,18 @@ class UserDirectorySlaveStore(
prefilled_cache=curr_state_delta_prefill,
)
- self._current_state_delta_pos = events_max
-
def stream_positions(self):
result = super(UserDirectorySlaveStore, self).stream_positions()
- result["current_state_deltas"] = self._current_state_delta_pos
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "current_state_deltas":
- self._current_state_delta_pos = token
+ if stream_name == EventsStream.NAME:
+ self._stream_id_gen.advance(token)
for row in rows:
+ if row.type != EventsStreamCurrentStateRow.TypeId:
+ continue
self._curr_state_delta_stream_cache.entity_has_changed(
- row.room_id, token
+ row.data.room_id, token
)
return super(UserDirectorySlaveStore, self).process_replication_rows(
stream_name, token, rows
@@ -170,7 +173,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
- if stream_name == "current_state_deltas":
+ if stream_name == EventsStream.NAME:
run_in_background(self._notify_directory)
@defer.inlineCallbacks
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 933928885a..eb10259818 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -42,7 +42,8 @@ class KeyConfig(Config):
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
- self.signing_key = self.read_signing_key(config["signing_key_path"])
+ self.signing_key_path = config["signing_key_path"]
+ self.signing_key = self.read_signing_key(self.signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b7a7b4f1cf..dd242b1211 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -44,6 +44,7 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -114,6 +115,10 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # Enable 3PIDs lookup requests to identity servers from this server.
+ #
+ #enable_3pid_lookup: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 08e4e45482..c5e5679d52 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -37,6 +37,7 @@ class ServerConfig(Config):
def read_config(self, config):
self.server_name = config["server_name"]
+ self.server_context = config.get("server_context", None)
try:
parse_and_validate_server_name(self.server_name)
@@ -484,6 +485,9 @@ class ServerConfig(Config):
#mau_limit_reserved_threepids:
# - medium: 'email'
# address: 'reserved_user@example.com'
+
+ # Used by phonehome stats to group together related servers.
+ #server_context: context
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 0207cd989a..834b107705 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -20,6 +20,7 @@ from collections import namedtuple
from six import raise_from
from six.moves import urllib
+import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -274,10 +275,6 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
- # dict[str, dict[str, VerifyKey]]: results so far.
- # map server_name -> key_id -> VerifyKey
- merged_results = {}
-
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
@@ -287,29 +284,29 @@ class Keyring(object):
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
- merged_results.update(results)
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
- server_name = verify_request.server_name
- result_keys = merged_results[server_name]
-
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
+ server_name = verify_request.server_name
+
+ # see if any of the keys we got this time are sufficient to
+ # complete this VerifyKeyRequest.
+ result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
- if key_id in result_keys:
+ key = result_keys.get(key_id)
+ if key:
with PreserveLoggingContext():
- verify_request.deferred.callback((
- server_name,
- key_id,
- result_keys[key_id],
- ))
+ verify_request.deferred.callback(
+ (server_name, key_id, key)
+ )
break
else:
# The else block is only reached if the loop above
@@ -343,27 +340,24 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
"""
-
Args:
- server_name_and_key_ids (list[(str, iterable[str])]):
+ server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+ Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
server_name -> key_id -> VerifyKey
"""
- res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.get_server_verify_keys,
- server_name, key_ids,
- ).addCallback(lambda ks, server: (server, ks), server_name)
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
- defer.returnValue(dict(res))
+ keys_to_fetch = (
+ (server_name, key_id)
+ for server_name, key_ids in server_name_and_key_ids
+ for key_id in key_ids
+ )
+ res = yield self.store.get_server_verify_keys(keys_to_fetch)
+ keys = {}
+ for (server_name, key_id), key in res.items():
+ keys.setdefault(server_name, {})[key_id] = key
+ defer.returnValue(keys)
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@@ -494,11 +488,11 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
- perspective_name, response, only_from_server=False
+ perspective_name, response
)
+ server_name = response["server_name"]
- for server_name, response_keys in processed_response.items():
- keys.setdefault(server_name, {}).update(response_keys)
+ keys.setdefault(server_name, {}).update(processed_response)
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
@@ -517,7 +511,7 @@ class Keyring(object):
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
- keys = {}
+ keys = {} # type: dict[str, nacl.signing.VerifyKey]
for requested_key_id in key_ids:
if requested_key_id in keys:
@@ -542,6 +536,11 @@ class Keyring(object):
or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server")
+ if response["server_name"] != server_name:
+ raise KeyLookupError("Expected a response for server %r not %r" % (
+ server_name, response["server_name"]
+ ))
+
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
@@ -550,24 +549,45 @@ class Keyring(object):
keys.update(response_keys)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store_keys,
- server_name=key_server_name,
- from_server=server_name,
- verify_keys=verify_keys,
- )
- for key_server_name, verify_keys in keys.items()
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError))
-
- defer.returnValue(keys)
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=keys,
+ )
+ defer.returnValue({server_name: keys})
@defer.inlineCallbacks
- def process_v2_response(self, from_server, response_json,
- requested_ids=[], only_from_server=True):
+ def process_v2_response(
+ self, from_server, response_json, requested_ids=[],
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key request
+
+ This is used to parse either the entirety of the response from
+ GET /_matrix/key/v2/server, or a single entry from the list returned by
+ POST /_matrix/key/v2/query.
+
+ Checks that each signature in the response that claims to come from the origin
+ server is valid. (Does not check that there actually is such a signature, for
+ some reason.)
+
+ Stores the json in server_keys_json so that it can be used for future responses
+ to /_matrix/key/v2/query.
+
+ Args:
+ from_server (str): the name of the server producing this result: either
+ the origin server for a /_matrix/key/v2/server request, or the notary
+ for a /_matrix/key/v2/query.
+
+ response_json (dict): the json-decoded Server Keys response object
+
+ requested_ids (iterable[str]): a list of the key IDs that were requested.
+ We will store the json for these key ids as well as any that are
+ actually in the response
+
+ Returns:
+ Deferred[dict[str, nacl.signing.VerifyKey]]:
+ map from key_id to key object
+ """
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -589,15 +609,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
- results = {}
server_name = response_json["server_name"]
- if only_from_server:
- if server_name != from_server:
- raise KeyLookupError(
- "Expected a response for server %r not %r" % (
- from_server, server_name
- )
- )
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise KeyLookupError(
@@ -633,7 +645,7 @@ class Keyring(object):
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
- from_server=server_name,
+ from_server=from_server,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
@@ -643,9 +655,7 @@ class Keyring(object):
consumeErrors=True,
).addErrback(unwrapFirstError))
- results[server_name] = response_keys
-
- defer.returnValue(results)
+ defer.returnValue(response_keys)
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 8f9e330da5..203490fc36 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -20,15 +20,9 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
-from synapse.api.constants import (
- KNOWN_ROOM_VERSIONS,
- EventFormatVersions,
- EventTypes,
- JoinRules,
- Membership,
- RoomVersions,
-)
+from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -452,16 +446,18 @@ def check_redaction(room_version, event, auth_events):
if user_level >= redact_level:
return False
- if room_version in (RoomVersions.V1, RoomVersions.V2,):
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+ if not v:
+ raise RuntimeError("Unrecognized room version %r" % (room_version,))
+
+ if v.event_format == EventFormatVersions.V1:
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
- elif room_version == RoomVersions.V3:
+ else:
event.internal_metadata.recheck_redaction = True
return True
- else:
- raise RuntimeError("Unrecognized room version %r" % (room_version,))
raise AuthError(
403,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index fafa135182..12056d5be2 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -21,7 +21,7 @@ import six
from unpaddedbase64 import encode_base64
-from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -351,18 +351,13 @@ def room_version_to_event_format(room_version):
Returns:
int
"""
- if room_version not in KNOWN_ROOM_VERSIONS:
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+
+ if not v:
# We should have already checked version, so this should not happen
raise RuntimeError("Unrecognized room version %s" % (room_version,))
- if room_version in (
- RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
- ):
- return EventFormatVersions.V1
- elif room_version in (RoomVersions.V3,):
- return EventFormatVersions.V2
- else:
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
+ return v.event_format
def event_type_from_format_version(format_version):
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 06e01be918..fba27177c7 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -17,21 +17,17 @@ import attr
from twisted.internet import defer
-from synapse.api.constants import (
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
- MAX_DEPTH,
EventFormatVersions,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID
from synapse.util.stringutils import random_string
-from . import (
- _EventInternalMetadata,
- event_type_from_format_version,
- room_version_to_event_format,
-)
+from . import _EventInternalMetadata, event_type_from_format_version
@attr.s(slots=True, cmp=False, frozen=True)
@@ -170,21 +166,34 @@ class EventBuilderFactory(object):
def new(self, room_version, key_values):
"""Generate an event builder appropriate for the given room version
+ Deprecated: use for_room_version with a RoomVersion object instead
+
Args:
- room_version (str): Version of the room that we're creating an
- event builder for
+ room_version (str): Version of the room that we're creating an event builder
+ for
key_values (dict): Fields used as the basis of the new event
Returns:
EventBuilder
"""
-
- # There's currently only the one event version defined
- if room_version not in KNOWN_ROOM_VERSIONS:
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+ if not v:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
+ return self.for_room_version(v, key_values)
+ def for_room_version(self, room_version, key_values):
+ """Generate an event builder appropriate for the given room version
+
+ Args:
+ room_version (synapse.api.room_versions.RoomVersion):
+ Version of the room that we're creating an event builder for
+ key_values (dict): Fields used as the basis of the new event
+
+ Returns:
+ EventBuilder
+ """
return EventBuilder(
store=self.store,
state=self.state,
@@ -192,7 +201,7 @@ class EventBuilderFactory(object):
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
- format_version=room_version_to_event_format(room_version),
+ format_version=room_version.event_format,
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
@@ -222,7 +231,6 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
FrozenEvent
"""
- # There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index a072674b02..514273c792 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -15,8 +15,9 @@
from six import string_types
-from synapse.api.constants import EventFormatVersions, EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.room_versions import EventFormatVersions
from synapse.types import EventID, RoomID, UserID
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a7a2ec4523..dfe6b4aa5c 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -20,8 +20,9 @@ import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import event_type_from_format_version
from synapse.events.utils import prune_event
@@ -274,9 +275,12 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
- if room_version in (
- RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
- ):
+ # (ie, the room version uses old-style non-hash event IDs).
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+ if not v:
+ raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
+ if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
@@ -289,10 +293,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
for p, d in zip(pdus_to_check_event_id, more_deferreds):
p.deferreds.append(d)
- elif room_version in (RoomVersions.V3,):
- pass # No further checks needed, as event IDs are hashes here
- else:
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 58e04d81ab..f3fc897a0a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,12 +25,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import (
- KNOWN_ROOM_VERSIONS,
- EventTypes,
- Membership,
- RoomVersions,
-)
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -38,6 +33,11 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersions,
+)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError
@@ -570,7 +570,7 @@ class FederationClient(FederationBase):
Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
`(origin, event, event_format)` where origin is the remote
homeserver which generated the event, and event_format is one of
- `synapse.api.constants.EventFormatVersions`.
+ `synapse.api.room_versions.EventFormatVersions`.
Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code.
@@ -592,7 +592,7 @@ class FederationClient(FederationBase):
# Note: If not supplied, the room version may be either v1 or v2,
# however either way the event format version will be v1.
- room_version = ret.get("room_version", RoomVersions.V1)
+ room_version = ret.get("room_version", RoomVersions.V1.identifier)
event_format = room_version_to_event_format(room_version)
pdu_dict = ret.get("event", None)
@@ -695,7 +695,9 @@ class FederationClient(FederationBase):
room_version = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
- room_version = e.content.get("room_version", RoomVersions.V1)
+ room_version = e.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
break
if room_version is None:
@@ -802,11 +804,10 @@ class FederationClient(FederationBase):
raise err
# Otherwise, we assume that the remote server doesn't understand
- # the v2 invite API.
-
- if room_version in (RoomVersions.V1, RoomVersions.V2):
- pass # We'll fall through
- else:
+ # the v2 invite API. That's ok provided the room uses old-style event
+ # IDs.
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+ if v.event_format != EventFormatVersions.V1:
raise SynapseError(
400,
"User's homeserver does not support this room version",
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 81f3b4b1ff..df60828dba 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -34,6 +34,7 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 04d04a4457..0240b339b0 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object):
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> user_id
+ self.presence_changed = SortedDict() # Stream position -> list[user_id]
+
+ # Stores the destinations we need to explicitly send presence to about a
+ # given user.
+ # Stream position -> (user_id, destinations)
+ self.presence_destinations = SortedDict()
self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
@@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "device_messages", "pos_time",
+ "edus", "device_messages", "pos_time", "presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object):
for user_id in uids
)
+ keys = self.presence_destinations.keys()
+ i = self.presence_destinations.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.presence_destinations[key]
+
+ user_ids.update(
+ user_id for user_id, _ in self.presence_destinations.values()
+ )
+
to_del = [
user_id for user_id in self.presence_map if user_id not in user_ids
]
@@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
+ def send_presence_to_destinations(self, states, destinations):
+ """As per FederationSender
+
+ Args:
+ states (list[UserPresenceState])
+ destinations (list[str])
+ """
+ for state in states:
+ pos = self._next_pos()
+ self.presence_map.update({state.user_id: state for state in states})
+ self.presence_destinations[pos] = (state.user_id, destinations)
+
+ self.notifier.on_new_replication_data()
+
def send_device_messages(self, destination):
"""As per FederationSender"""
pos = self._next_pos()
@@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object):
state=self.presence_map[user_id],
)))
+ # Fetch presence to send to destinations
+ i = self.presence_destinations.bisect_right(from_token)
+ j = self.presence_destinations.bisect_right(to_token) + 1
+
+ for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
+ rows.append((pos, PresenceDestinationsRow(
+ state=self.presence_map[user_id],
+ destinations=list(dests),
+ )))
+
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
j = self.keyed_edu_changed.bisect_right(to_token) + 1
@@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
+class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
+ "state", # UserPresenceState
+ "destinations", # list[str]
+))):
+ TypeId = "pd"
+
+ @staticmethod
+ def from_data(data):
+ return PresenceDestinationsRow(
+ state=UserPresenceState.from_dict(data["state"]),
+ destinations=data["dests"],
+ )
+
+ def to_data(self):
+ return {
+ "state": self.state.as_dict(),
+ "dests": self.destinations,
+ }
+
+ def add_to_buffer(self, buff):
+ buff.presence_destinations.append((self.state, self.destinations))
+
+
class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
"key", # tuple(str) - the edu key passed to send_edu
"edu", # Edu
@@ -428,6 +489,7 @@ TypeToRow = {
Row.TypeId: Row
for Row in (
PresenceRow,
+ PresenceDestinationsRow,
KeyedEduRow,
EduRow,
DeviceRow,
@@ -437,6 +499,7 @@ TypeToRow = {
ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
+ "presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
"device_destinations", # set of destinations
@@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows):
buff = ParsedFederationStreamData(
presence=[],
+ presence_destinations=[],
keyed_edus={},
edus={},
device_destinations=set(),
@@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
+ for state, destinations in buff.presence_destinations:
+ transaction_queue.send_presence_to_destinations(
+ states=[state], destinations=destinations,
+ )
+
for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 1dc041752b..4f0f939102 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -371,7 +371,7 @@ class FederationSender(object):
return
# First we queue up the new presence by user ID, so multiple presence
- # updates in quick successtion are correctly handled
+ # updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
self.pending_presence.update({
@@ -402,6 +402,23 @@ class FederationSender(object):
finally:
self._processing_pending_presence = False
+ def send_presence_to_destinations(self, states, destinations):
+ """Send the given presence states to the given destinations.
+
+ Args:
+ states (list[UserPresenceState])
+ destinations (list[str])
+ """
+
+ if not states or not self.hs.config.use_presence:
+ # No-op if presence is disabled.
+ return
+
+ for destination in destinations:
+ if destination == self.server_name:
+ continue
+ self._get_per_destination_queue(destination).send_presence(states)
+
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index efb6bdca48..452599e1a1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -21,8 +21,8 @@ import re
from twisted.internet import defer
import synapse
-from synapse.api.constants import RoomVersions
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
+from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
@@ -513,7 +513,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
# state resolution algorithm, and we don't use that for processing
# invites
content = yield self.handler.on_invite_request(
- origin, content, room_version=RoomVersions.V1,
+ origin, content, room_version=RoomVersions.V1.identifier,
)
# V1 federation API is defined to return a content of `[200, {...}]`
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index a7eaead56b..817be40360 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
@@ -896,6 +897,78 @@ class GroupsServerHandler(object):
"group_id": group_id,
})
+ @defer.inlineCallbacks
+ def delete_group(self, group_id, requester_user_id):
+ """Deletes a group, kicking out all current members.
+
+ Only group admins or server admins can call this request
+
+ Args:
+ group_id (str)
+ request_user_id (str)
+
+ Returns:
+ Deferred
+ """
+
+ yield self.check_group_is_ours(
+ group_id, requester_user_id,
+ and_exists=True,
+ )
+
+ # Only server admins or group admins can delete groups.
+
+ is_admin = yield self.store.is_user_admin_in_group(
+ group_id, requester_user_id
+ )
+
+ if not is_admin:
+ is_admin = yield self.auth.is_server_admin(
+ UserID.from_string(requester_user_id),
+ )
+
+ if not is_admin:
+ raise SynapseError(403, "User is not an admin")
+
+ # Before deleting the group lets kick everyone out of it
+ users = yield self.store.get_users_in_group(
+ group_id, include_private=True,
+ )
+
+ @defer.inlineCallbacks
+ def _kick_user_from_group(user_id):
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ yield groups_local.user_removed_from_group(group_id, user_id, {})
+ else:
+ yield self.transport_client.remove_user_from_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, {}
+ )
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+
+ # We kick users out in the order of:
+ # 1. Non-admins
+ # 2. Other admins
+ # 3. The requester
+ #
+ # This is so that if the deletion fails for some reason other admins or
+ # the requester still has auth to retry.
+ non_admins = []
+ admins = []
+ for u in users:
+ if u["user_id"] == requester_user_id:
+ continue
+ if u["is_admin"]:
+ admins.append(u["user_id"])
+ else:
+ non_admins.append(u["user_id"])
+
+ yield concurrently_execute(_kick_user_from_group, non_admins, 10)
+ yield concurrently_execute(_kick_user_from_group, admins, 10)
+ yield _kick_user_from_group(requester_user_id)
+
+ yield self.store.delete_group(group_id)
+
def _parse_join_policy_from_contents(content):
"""Given a content for a request, return the specified join policy or None
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4544de821d..aa5d89a9ac 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -912,7 +912,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def delete_threepid(self, user_id, medium, address):
+ def delete_threepid(self, user_id, medium, address, id_server=None):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
@@ -920,6 +920,10 @@ class AuthHandler(BaseHandler):
user_id (str)
medium (str)
address (str)
+ id_server (str|None): Use the given identity server when unbinding
+ any threepids. If None then will attempt to unbind using the
+ identity server specified when binding (if known).
+
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
@@ -937,6 +941,7 @@ class AuthHandler(BaseHandler):
{
'medium': medium,
'address': address,
+ 'id_server': id_server,
},
)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 97d3f31d98..6a91f7698e 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -43,12 +43,15 @@ class DeactivateAccountHandler(BaseHandler):
hs.get_reactor().callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks
- def deactivate_account(self, user_id, erase_data):
+ def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
erase_data (bool): whether to GDPR-erase the user's data
+ id_server (str|None): Use the given identity server when unbinding
+ any threepids. If None then will attempt to unbind using the
+ identity server specified when binding (if known).
Returns:
Deferred[bool]: True if identity server supports removing
@@ -74,6 +77,7 @@ class DeactivateAccountHandler(BaseHandler):
{
'medium': threepid['medium'],
'address': threepid['address'],
+ 'id_server': id_server,
},
)
identity_server_supports_unbinding &= result
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index fe128d9c88..27bd06df5d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -68,7 +68,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.state.get_current_users_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
if not servers:
@@ -268,7 +268,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND
)
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.state.get_current_users_in_room(room_id)
extra_servers = set(get_domain_from_id(u) for u in users)
servers = set(extra_servers) | set(servers)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d883e98381..1b4d8c74ae 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -102,7 +102,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = yield self.state.get_current_user_in_room(event.room_id)
+ users = yield self.state.get_current_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9eaf2d3e18..0684778882 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -29,13 +29,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
-from synapse.api.constants import (
- KNOWN_ROOM_VERSIONS,
- EventTypes,
- Membership,
- RejectedReason,
- RoomVersions,
-)
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -44,6 +38,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
from synapse.events.validator import EventValidator
@@ -1733,7 +1728,9 @@ class FederationHandler(BaseHandler):
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
- room_version = create_event.content.get("room_version", RoomVersions.V1)
+ room_version = create_event.content.get(
+ "room_version", RoomVersions.V1.identifier,
+ )
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 39184f0e22..22469486d7 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -132,6 +132,14 @@ class IdentityHandler(BaseHandler):
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
+
+ # Remember where we bound the threepid
+ yield self.store.add_user_bound_threepid(
+ user_id=mxid,
+ medium=data["medium"],
+ address=data["address"],
+ id_server=id_server,
+ )
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
defer.returnValue(data)
@@ -142,30 +150,61 @@ class IdentityHandler(BaseHandler):
Args:
mxid (str): Matrix user ID of binding to be removed
- threepid (dict): Dict with medium & address of binding to be removed
+ threepid (dict): Dict with medium & address of binding to be
+ removed, and an optional id_server.
Raises:
SynapseError: If we failed to contact the identity server
Returns:
Deferred[bool]: True on success, otherwise False if the identity
- server doesn't support unbinding
+ server doesn't support unbinding (or no identity server found to
+ contact).
"""
- logger.debug("unbinding threepid %r from %s", threepid, mxid)
- if not self.trusted_id_servers:
- logger.warn("Can't unbind threepid: no trusted ID servers set in config")
+ if threepid.get("id_server"):
+ id_servers = [threepid["id_server"]]
+ else:
+ id_servers = yield self.store.get_id_servers_user_bound(
+ user_id=mxid,
+ medium=threepid["medium"],
+ address=threepid["address"],
+ )
+
+ # We don't know where to unbind, so we don't have a choice but to return
+ if not id_servers:
defer.returnValue(False)
- # We don't track what ID server we added 3pids on (perhaps we ought to)
- # but we assume that any of the servers in the trusted list are in the
- # same ID server federation, so we can pick any one of them to send the
- # deletion request to.
- id_server = next(iter(self.trusted_id_servers))
+ changed = True
+ for id_server in id_servers:
+ changed &= yield self.try_unbind_threepid_with_id_server(
+ mxid, threepid, id_server,
+ )
+
+ defer.returnValue(changed)
+
+ @defer.inlineCallbacks
+ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ """Removes a binding from an identity server
+ Args:
+ mxid (str): Matrix user ID of binding to be removed
+ threepid (dict): Dict with medium & address of binding to be removed
+ id_server (str): Identity server to unbind from
+
+ Raises:
+ SynapseError: If we failed to contact the identity server
+
+ Returns:
+ Deferred[bool]: True on success, otherwise False if the identity
+ server doesn't support unbinding
+ """
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
- "threepid": threepid,
+ "threepid": {
+ "medium": threepid["medium"],
+ "address": threepid["address"],
+ },
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -188,16 +227,24 @@ class IdentityHandler(BaseHandler):
content,
headers,
)
+ changed = True
except HttpResponseException as e:
+ changed = False
if e.code in (400, 404, 501,):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
- defer.returnValue(False)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(502, "Failed to contact identity server")
- defer.returnValue(True)
+ yield self.store.remove_user_bound_threepid(
+ user_id=mxid,
+ medium=threepid["medium"],
+ address=threepid["address"],
+ id_server=id_server,
+ )
+
+ defer.returnValue(changed)
@defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9b41c7b205..224d34ef3a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -30,6 +30,7 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
+from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -191,7 +192,7 @@ class MessageHandler(object):
"Getting joined members after leaving is not implemented"
)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.state.get_current_users_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
@@ -603,7 +604,9 @@ class EventCreationHandler(object):
"""
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
- room_version = event.content.get("room_version", RoomVersions.V1)
+ room_version = event.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
else:
room_version = yield self.store.get_room_version(event.room_id)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37e87fc054..bd1285b15c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -31,9 +31,11 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import PresenceState
+import synapse.metrics
+from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -98,6 +100,7 @@ class PresenceHandler(object):
self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
+ self.server_name = hs.hostname
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
@@ -110,30 +113,6 @@ class PresenceHandler(object):
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
- federation_registry.register_edu_handler(
- "m.presence_invite",
- lambda origin, content: self.invite_presence(
- observed_user=UserID.from_string(content["observed_user"]),
- observer_user=UserID.from_string(content["observer_user"]),
- )
- )
- federation_registry.register_edu_handler(
- "m.presence_accept",
- lambda origin, content: self.accept_presence(
- observed_user=UserID.from_string(content["observed_user"]),
- observer_user=UserID.from_string(content["observer_user"]),
- )
- )
- federation_registry.register_edu_handler(
- "m.presence_deny",
- lambda origin, content: self.deny_presence(
- observed_user=UserID.from_string(content["observed_user"]),
- observer_user=UserID.from_string(content["observer_user"]),
- )
- )
-
- distributor = hs.get_distributor()
- distributor.observe("user_joined_room", self.user_joined_room)
active_presence = self.store.take_presence_startup_info()
@@ -220,6 +199,15 @@ class PresenceHandler(object):
LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
lambda: len(self.wheel_timer))
+ # Used to handle sending of presence to newly joined users/servers
+ if hs.config.use_presence:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # Presence is best effort and quickly heals itself, so lets just always
+ # stream from the current state when we restart.
+ self._event_pos = self.store.get_current_events_token()
+ self._event_processing = False
+
@defer.inlineCallbacks
def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
@@ -751,199 +739,173 @@ class PresenceHandler(object):
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
- def user_joined_room(self, user, room_id):
- """Called (via the distributor) when a user joins a room. This funciton
- sends presence updates to servers, either:
- 1. the joining user is a local user and we send their presence to
- all servers in the room.
- 2. the joining user is a remote user and so we send presence for all
- local users in the room.
+ def is_visible(self, observed_user, observer_user):
+ """Returns whether a user can see another user's presence.
"""
- # We only need to send presence to servers that don't have it yet. We
- # don't need to send to local clients here, as that is done as part
- # of the event stream/sync.
- # TODO: Only send to servers not already in the room.
- if self.is_mine(user):
- state = yield self.current_state_for_user(user.to_string())
-
- self._push_to_remotes([state])
- else:
- user_ids = yield self.store.get_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, user_ids))
+ observer_room_ids = yield self.store.get_rooms_for_user(
+ observer_user.to_string()
+ )
+ observed_room_ids = yield self.store.get_rooms_for_user(
+ observed_user.to_string()
+ )
- states = yield self.current_state_for_users(user_ids)
+ if observer_room_ids & observed_room_ids:
+ defer.returnValue(True)
- self._push_to_remotes(list(states.values()))
+ defer.returnValue(False)
@defer.inlineCallbacks
- def get_presence_list(self, observer_user, accepted=None):
- """Returns the presence for all users in their presence list.
+ def get_all_presence_updates(self, last_id, current_id):
"""
- if not self.is_mine(observer_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
-
- presence_list = yield self.store.get_presence_list(
- observer_user.localpart, accepted=accepted
- )
+ Gets a list of presence update rows from between the given stream ids.
+ Each row has:
+ - stream_id(str)
+ - user_id(str)
+ - state(str)
+ - last_active_ts(int)
+ - last_federation_update_ts(int)
+ - last_user_sync_ts(int)
+ - status_msg(int)
+ - currently_active(int)
+ """
+ # TODO(markjh): replicate the unpersisted changes.
+ # This could use the in-memory stores for recent changes.
+ rows = yield self.store.get_all_presence_updates(last_id, current_id)
+ defer.returnValue(rows)
- results = yield self.get_states(
- target_user_ids=[row["observed_user_id"] for row in presence_list],
- as_event=False,
- )
+ def notify_new_event(self):
+ """Called when new events have happened. Handles users and servers
+ joining rooms and require being sent presence.
+ """
- now = self.clock.time_msec()
- results[:] = [format_user_presence_state(r, now) for r in results]
+ if self._event_processing:
+ return
- is_accepted = {
- row["observed_user_id"]: row["accepted"] for row in presence_list
- }
+ @defer.inlineCallbacks
+ def _process_presence():
+ assert not self._event_processing
- for result in results:
- result.update({
- "accepted": is_accepted,
- })
+ self._event_processing = True
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._event_processing = False
- defer.returnValue(results)
+ run_as_background_process("presence.notify_new_event", _process_presence)
@defer.inlineCallbacks
- def send_presence_invite(self, observer_user, observed_user):
- """Sends a presence invite.
- """
- yield self.store.add_presence_list_pending(
- observer_user.localpart, observed_user.to_string()
- )
+ def _unsafe_process(self):
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "presence_delta"):
+ deltas = yield self.store.get_current_state_deltas(self._event_pos)
+ if not deltas:
+ return
- if self.is_mine(observed_user):
- yield self.invite_presence(observed_user, observer_user)
- else:
- yield self.federation.build_and_send_edu(
- destination=observed_user.domain,
- edu_type="m.presence_invite",
- content={
- "observed_user": observed_user.to_string(),
- "observer_user": observer_user.to_string(),
- }
- )
+ yield self._handle_state_delta(deltas)
+
+ self._event_pos = deltas[-1]["stream_id"]
+
+ # Expose current event processing position to prometheus
+ synapse.metrics.event_processing_positions.labels("presence").set(
+ self._event_pos
+ )
@defer.inlineCallbacks
- def invite_presence(self, observed_user, observer_user):
- """Handles new presence invites.
+ def _handle_state_delta(self, deltas):
+ """Process current state deltas to find new joins that need to be
+ handled.
"""
- if not self.is_mine(observed_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ prev_event_id = delta["prev_event_id"]
- # TODO: Don't auto accept
- if self.is_mine(observer_user):
- yield self.accept_presence(observed_user, observer_user)
- else:
- self.federation.build_and_send_edu(
- destination=observer_user.domain,
- edu_type="m.presence_accept",
- content={
- "observed_user": observed_user.to_string(),
- "observer_user": observer_user.to_string(),
- }
- )
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
- state_dict = yield self.get_state(observed_user, as_event=False)
- state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
+ if typ != EventTypes.Member:
+ continue
- self.federation.build_and_send_edu(
- destination=observer_user.domain,
- edu_type="m.presence",
- content={
- "push": [state_dict]
- }
- )
+ event = yield self.store.get_event(event_id)
+ if event.content.get("membership") != Membership.JOIN:
+ # We only care about joins
+ continue
- @defer.inlineCallbacks
- def accept_presence(self, observed_user, observer_user):
- """Handles a m.presence_accept EDU. Mark a presence invite from a
- local or remote user as accepted in a local user's presence list.
- Starts polling for presence updates from the local or remote user.
- Args:
- observed_user(UserID): The user to update in the presence list.
- observer_user(UserID): The owner of the presence list to update.
- """
- yield self.store.set_presence_list_accepted(
- observer_user.localpart, observed_user.to_string()
- )
+ if prev_event_id:
+ prev_event = yield self.store.get_event(prev_event_id)
+ if prev_event.content.get("membership") == Membership.JOIN:
+ # Ignore changes to join events.
+ continue
+
+ yield self._on_user_joined_room(room_id, state_key)
@defer.inlineCallbacks
- def deny_presence(self, observed_user, observer_user):
- """Handle a m.presence_deny EDU. Removes a local or remote user from a
- local user's presence list.
+ def _on_user_joined_room(self, room_id, user_id):
+ """Called when we detect a user joining the room via the current state
+ delta stream.
+
Args:
- observed_user(UserID): The local or remote user to remove from the
- list.
- observer_user(UserID): The local owner of the presence list.
+ room_id (str)
+ user_id (str)
+
Returns:
- A Deferred.
+ Deferred
"""
- yield self.store.del_presence_list(
- observer_user.localpart, observed_user.to_string()
- )
- # TODO(paul): Inform the user somehow?
+ if self.is_mine_id(user_id):
+ # If this is a local user then we need to send their presence
+ # out to hosts in the room (who don't already have it)
- @defer.inlineCallbacks
- def drop(self, observed_user, observer_user):
- """Remove a local or remote user from a local user's presence list and
- unsubscribe the local user from updates that user.
- Args:
- observed_user(UserId): The local or remote user to remove from the
- list.
- observer_user(UserId): The local owner of the presence list.
- Returns:
- A Deferred.
- """
- if not self.is_mine(observer_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ # TODO: We should be able to filter the hosts down to those that
+ # haven't previously seen the user
- yield self.store.del_presence_list(
- observer_user.localpart, observed_user.to_string()
- )
+ state = yield self.current_state_for_user(user_id)
+ hosts = yield self.state.get_current_hosts_in_room(room_id)
- # TODO: Inform the remote that we've dropped the presence list.
+ # Filter out ourselves.
+ hosts = set(host for host in hosts if host != self.server_name)
- @defer.inlineCallbacks
- def is_visible(self, observed_user, observer_user):
- """Returns whether a user can see another user's presence.
- """
- observer_room_ids = yield self.store.get_rooms_for_user(
- observer_user.to_string()
- )
- observed_room_ids = yield self.store.get_rooms_for_user(
- observed_user.to_string()
- )
+ self.federation.send_presence_to_destinations(
+ states=[state],
+ destinations=hosts,
+ )
+ else:
+ # A remote user has joined the room, so we need to:
+ # 1. Check if this is a new server in the room
+ # 2. If so send any presence they don't already have for
+ # local users in the room.
- if observer_room_ids & observed_room_ids:
- defer.returnValue(True)
+ # TODO: We should be able to filter the users down to those that
+ # the server hasn't previously seen
- accepted_observers = yield self.store.get_presence_list_observers_accepted(
- observed_user.to_string()
- )
+ # TODO: Check that this is actually a new server joining the
+ # room.
- defer.returnValue(observer_user.to_string() in accepted_observers)
+ user_ids = yield self.state.get_current_users_in_room(room_id)
+ user_ids = list(filter(self.is_mine_id, user_ids))
- @defer.inlineCallbacks
- def get_all_presence_updates(self, last_id, current_id):
- """
- Gets a list of presence update rows from between the given stream ids.
- Each row has:
- - stream_id(str)
- - user_id(str)
- - state(str)
- - last_active_ts(int)
- - last_federation_update_ts(int)
- - last_user_sync_ts(int)
- - status_msg(int)
- - currently_active(int)
- """
- # TODO(markjh): replicate the unpersisted changes.
- # This could use the in-memory stores for recent changes.
- rows = yield self.store.get_all_presence_updates(last_id, current_id)
- defer.returnValue(rows)
+ states = yield self.current_state_for_users(user_ids)
+
+ # Filter out old presence, i.e. offline presence states where
+ # the user hasn't been active for a week. We can change this
+ # depending on what we want the UX to be, but at the least we
+ # should filter out offline presence where the state is just the
+ # default state.
+ now = self.clock.time_msec()
+ states = [
+ state for state in states.values()
+ if state.state != PresenceState.OFFLINE
+ or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
+ or state.status_msg is not None
+ ]
+
+ if states:
+ self.federation.send_presence_to_destinations(
+ states=states,
+ destinations=[get_domain_from_id(user_id)],
+ )
def should_notify(old_state, new_state):
@@ -1086,10 +1048,7 @@ class PresenceEventSource(object):
updates for
"""
user_id = user.to_string()
- plist = yield self.store.get_presence_list_accepted(
- user.localpart, on_invalidate=cache_context.invalidate,
- )
- users_interested_in = set(row["observed_user_id"] for row in plist)
+ users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
@@ -1294,10 +1253,6 @@ def get_interested_parties(store, states):
for room_id in room_ids:
room_ids_to_states.setdefault(room_id, []).append(state)
- plist = yield store.get_presence_list_observers_accepted(state.user_id)
- for u in plist:
- users_to_states.setdefault(u, []).append(state)
-
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 58940e0320..a51d11a257 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler):
user_type=None,
default_display_name=None,
address=None,
+ bind_emails=[],
):
"""Registers a new client on the server.
@@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler):
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration.
+ bind_emails (List[str]): list of emails to bind to this account.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler):
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
+ # Bind any specified emails to this account
+ current_time = self.hs.get_clock().time_msec()
+ for email in bind_emails:
+ # generate threepid dict
+ threepid_dict = {
+ "medium": "email",
+ "address": email,
+ "validated_at": current_time,
+ }
+
+ # Bind email to new account
+ yield self._register_email_threepid(
+ user_id, threepid_dict, None, False,
+ )
+
defer.returnValue((user_id, token))
@defer.inlineCallbacks
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 67b15697fd..17628e2684 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,14 +25,9 @@ from six import iteritems, string_types
from twisted.internet import defer
-from synapse.api.constants import (
- DEFAULT_ROOM_VERSION,
- KNOWN_ROOM_VERSIONS,
- EventTypes,
- JoinRules,
- RoomCreationPreset,
-)
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -285,6 +280,7 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.RoomAvatar, ""),
(EventTypes.Encryption, ""),
(EventTypes.ServerACL, ""),
+ (EventTypes.RelatedGroups, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
@@ -479,7 +475,7 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
- room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
+ room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
if not isinstance(room_version, string_types):
raise SynapseError(
400,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index d6c9d56007..617d1c9ef8 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -167,7 +167,7 @@ class RoomListHandler(BaseHandler):
if not latest_event_ids:
return
- joined_users = yield self.state_handler.get_current_user_in_room(
+ joined_users = yield self.state_handler.get_current_users_in_room(
room_id, latest_event_ids,
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 71ce5b54e5..024d6db27a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -70,6 +70,7 @@ class RoomMemberHandler(object):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
self._server_notices_mxid = self.config.server_notices_mxid
+ self._enable_lookup = hs.config.enable_3pid_lookup
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -421,6 +422,9 @@ class RoomMemberHandler(object):
room_id, latest_event_ids=latest_event_ids,
)
+ # TODO: Refactor into dictionary of explicitly allowed transitions
+ # between old and new state, with specific error messages for some
+ # transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
@@ -446,6 +450,9 @@ class RoomMemberHandler(object):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
+ if old_membership in ["ban", "leave"] and action == "kick":
+ raise AuthError(403, "The target user is not in the room")
+
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
@@ -459,6 +466,9 @@ class RoomMemberHandler(object):
"You cannot reject this invite",
errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
)
+ else:
+ if action == "kick":
+ raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids)
@@ -729,6 +739,10 @@ class RoomMemberHandler(object):
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
+ if not self._enable_lookup:
+ raise SynapseError(
+ 403, "Looking up third-party identifiers is denied from this server",
+ )
try:
data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 57bb996245..153312e39f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1049,11 +1049,11 @@ class SyncHandler(object):
# TODO: Be more clever than this, i.e. remove users who we already
# share a room with?
for room_id in newly_joined_rooms:
- joined_users = yield self.state.get_current_user_in_room(room_id)
+ joined_users = yield self.state.get_current_users_in_room(room_id)
newly_joined_users.update(joined_users)
for room_id in newly_left_rooms:
- left_users = yield self.state.get_current_user_in_room(room_id)
+ left_users = yield self.state.get_current_users_in_room(room_id)
newly_left_users.update(left_users)
# TODO: Check that these users are actually new, i.e. either they
@@ -1213,7 +1213,7 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_users)
for room_id in newly_joined_rooms:
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
@@ -1855,7 +1855,7 @@ class SyncHandler(object):
extrems = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering,
)
- users_in_room = yield self.state.get_current_user_in_room(
+ users_in_room = yield self.state.get_current_users_in_room(
room_id, extrems,
)
if user_id in users_in_room:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 39df960c31..972662eb48 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -218,7 +218,7 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
try:
- users = yield self.state.get_current_user_in_room(member.room_id)
+ users = yield self.state.get_current_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -261,7 +261,7 @@ class TypingHandler(object):
)
return
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.state.get_current_users_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index b689979b4b..5de9630950 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -276,7 +276,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
for user_id in iterkeys(users_with_profile):
@@ -325,7 +325,7 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id
)
# Now we update users who share rooms with users.
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.state.get_current_users_in_room(room_id)
if is_public:
yield self.store.add_users_in_public_rooms(room_id, (user_id,))
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 235ce8334e..b3abd1b3c6 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -74,14 +74,14 @@ class ModuleApi(object):
return self._auth_handler.check_user_exists(user_id)
@defer.inlineCallbacks
- def register(self, localpart, displayname=None):
+ def register(self, localpart, displayname=None, emails=[]):
"""Registers a new user with given localpart and optional
- displayname.
+ displayname, emails.
Args:
localpart (str): The localpart of the new user.
- displayname (str|None): The displayname of the new user. If None,
- the user's displayname will default to `localpart`.
+ displayname (str|None): The displayname of the new user.
+ emails (List[str]): Emails to bind to the new user.
Returns:
Deferred: a 2-tuple of (user_id, access_token)
@@ -90,6 +90,7 @@ class ModuleApi(object):
reg = self.hs.get_registration_handler()
user_id, access_token = yield reg.register(
localpart=localpart, default_display_name=displayname,
+ bind_emails=emails,
)
defer.returnValue((user_id, access_token))
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 50e1007d84..e8ee67401f 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -72,8 +72,15 @@ class EmailPusher(object):
self._is_processing = False
- def on_started(self):
- if self.mailer is not None:
+ def on_started(self, should_check_for_notifs):
+ """Called when this pusher has been started.
+
+ Args:
+ should_check_for_notifs (bool): Whether we should immediately
+ check for push to send. Set to False only if it's known there
+ is nothing to send
+ """
+ if should_check_for_notifs and self.mailer is not None:
self._start_processing()
def on_stop(self):
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e65f8c63d3..fac05aa44c 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -112,8 +112,16 @@ class HttpPusher(object):
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
- def on_started(self):
- self._start_processing()
+ def on_started(self, should_check_for_notifs):
+ """Called when this pusher has been started.
+
+ Args:
+ should_check_for_notifs (bool): Whether we should immediately
+ check for push to send. Set to False only if it's known there
+ is nothing to send
+ """
+ if should_check_for_notifs:
+ self._start_processing()
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index abf1a1a9c1..40a7709c09 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.pusher import PusherFactory
+from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
@@ -197,7 +198,7 @@ class PusherPool:
p = r
if p:
- self._start_pusher(p)
+ yield self._start_pusher(p)
@defer.inlineCallbacks
def _start_pushers(self):
@@ -208,10 +209,14 @@ class PusherPool:
"""
pushers = yield self.store.get_all_pushers()
logger.info("Starting %d pushers", len(pushers))
- for pusherdict in pushers:
- self._start_pusher(pusherdict)
+
+ # Stagger starting up the pushers so we don't completely drown the
+ # process on start up.
+ yield concurrently_execute(self._start_pusher, pushers, 10)
+
logger.info("Started pushers")
+ @defer.inlineCallbacks
def _start_pusher(self, pusherdict):
"""Start the given pusher
@@ -248,7 +253,22 @@ class PusherPool:
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
- p.on_started()
+
+ # Check if there *may* be push to process. We do this as this check is a
+ # lot cheaper to do than actually fetching the exact rows we need to
+ # push.
+ user_id = pusherdict["user_name"]
+ last_stream_ordering = pusherdict["last_stream_ordering"]
+ if last_stream_ordering:
+ have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
+ user_id, last_stream_ordering,
+ )
+ else:
+ # We always want to default to starting up the pusher rather than
+ # risk missing push.
+ have_notifs = True
+
+ p.on_started(have_notifs)
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index f71e21ff4d..62c1748665 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -74,7 +74,9 @@ REQUIREMENTS = [
CONDITIONAL_REQUIREMENTS = {
"email.enable_notifs": ["Jinja2>=2.9", "bleach>=1.4.2"],
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
- "postgres": ["psycopg2>=2.6"],
+
+ # we use execute_batch, which arrived in psycopg 2.7.
+ "postgres": ["psycopg2>=2.7"],
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 4830c68f35..b457c5563f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,6 +16,10 @@
import logging
from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+)
from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
@@ -79,11 +83,7 @@ class SlavedEventStore(EventFederationWorkerStore,
if stream_name == "events":
self._stream_id_gen.advance(token)
for row in rows:
- self.invalidate_caches_for_event(
- token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts,
- backfilled=False,
- )
+ self._process_event_stream_row(token, row)
elif stream_name == "backfill":
self._backfill_id_gen.advance(-token)
for row in rows:
@@ -96,6 +96,23 @@ class SlavedEventStore(EventFederationWorkerStore,
stream_name, token, rows
)
+ def _process_event_stream_row(self, token, row):
+ data = row.data
+
+ if row.type == EventsStreamEventRow.TypeId:
+ self.invalidate_caches_for_event(
+ token, data.event_id, data.room_id, data.type, data.state_key,
+ data.redacts,
+ backfilled=False,
+ )
+ elif row.type == EventsStreamCurrentStateRow.TypeId:
+ if data.type == EventTypes.Member:
+ self.get_rooms_for_user_with_stream_ordering.invalidate(
+ (data.state_key, ),
+ )
+ else:
+ raise Exception("Unknown events stream row type %s" % (row.type, ))
+
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
etype, state_key, redacts, backfilled):
self._invalidate_get_event_cache(event_id)
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 8032f53fec..cc6f7f009f 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,22 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
-from synapse.storage.keys import KeyStore
+from synapse.storage import KeyStore
-from ._base import BaseSlavedStore, __func__
+# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
+# the races it creates aren't too bad.
-
-class SlavedKeyStore(BaseSlavedStore):
- _get_server_verify_key = KeyStore.__dict__[
- "_get_server_verify_key"
- ]
-
- get_server_verify_keys = __func__(DataStore.get_server_verify_keys)
- store_server_verify_key = __func__(DataStore.store_server_verify_key)
-
- get_server_certificate = __func__(DataStore.get_server_certificate)
- store_server_certificate = __func__(DataStore.store_server_certificate)
-
- get_server_keys_json = __func__(DataStore.get_server_keys_json)
- store_server_keys_json = __func__(DataStore.store_server_keys_json)
+SlavedKeyStore = KeyStore
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 9e530defe0..0ec1db25ce 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -39,16 +39,6 @@ class SlavedPresenceStore(BaseSlavedStore):
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
- # XXX: This is a bit broken because we don't persist the accepted list in a
- # way that can be replicated. This means that we don't have a way to
- # invalidate the cache correctly.
- get_presence_list_accepted = PresenceStore.__dict__[
- "get_presence_list_accepted"
- ]
- get_presence_list_observers_accepted = PresenceStore.__dict__[
- "get_presence_list_observers_accepted"
- ]
-
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e558f90e1a..206dc3b397 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -103,10 +103,19 @@ class ReplicationClientHandler(object):
hs.get_reactor().connectTCP(host, port, self.factory)
def on_rdata(self, stream_name, token, rows):
- """Called when we get new replication data. By default this just pokes
- the slave store.
+ """Called to handle a batch of replication data with a given stream token.
- Can be overriden in subclasses to handle more.
+ By default this just pokes the slave store. Can be overridden in subclasses to
+ handle more.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 02e5bf6cc8..b51590cf8f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -42,8 +42,8 @@ indicate which side is sending, these are *not* included on the wire::
> POSITION backfill 1
> POSITION caches 1
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
- > RDATA events 14 ["$149019767112vOHxz:localhost:8823",
- "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
+ > RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823",
+ "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]]
< PING 1490197675618
> ERROR server stopping
* connection closed by server *
@@ -605,7 +605,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
inbound_rdata_count.labels(stream_name).inc()
try:
- row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 7fc346c7b6..f6a38f5140 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -30,7 +30,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP, FederationStream
+from .streams import STREAMS_MAP
+from .streams.federation import FederationStream
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
"", ["stream_name"])
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
new file mode 100644
index 0000000000..634f636dc9
--- /dev/null
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines all the valid streams that clients can subscribe to, and the format
+of the rows returned by each stream.
+
+Each stream is defined by the following information:
+
+ stream name: The name of the stream
+ row type: The type that is used to serialise/deserialse the row
+ current_token: The function that returns the current token for the stream
+ update_function: The function that returns a list of updates between two tokens
+"""
+
+from . import _base, events, federation
+
+STREAMS_MAP = {
+ stream.NAME: stream
+ for stream in (
+ events.EventsStream,
+ _base.BackfillStream,
+ _base.PresenceStream,
+ _base.TypingStream,
+ _base.ReceiptsStream,
+ _base.PushRulesStream,
+ _base.PushersStream,
+ _base.CachesStream,
+ _base.PublicRoomsStream,
+ _base.DeviceListsStream,
+ _base.ToDeviceStream,
+ federation.FederationStream,
+ _base.TagAccountDataStream,
+ _base.AccountDataStream,
+ _base.GroupServerStream,
+ )
+}
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams/_base.py
index e23084baae..8971a6a22e 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Defines all the valid streams that clients can subscribe to, and the format
-of the rows returned by each stream.
-Each stream is defined by the following information:
-
- stream name: The name of the stream
- row type: The type that is used to serialise/deserialse the row
- current_token: The function that returns the current token for the stream
- update_function: The function that returns a list of updates between two tokens
-"""
import itertools
import logging
from collections import namedtuple
@@ -34,14 +26,6 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 10000
-
-EventStreamRow = namedtuple("EventStreamRow", (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
-))
BackfillStreamRow = namedtuple("BackfillStreamRow", (
"event_id", # str
"room_id", # str
@@ -96,10 +80,6 @@ DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
"entity", # str
))
-FederationStreamRow = namedtuple("FederationStreamRow", (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
-))
TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
"user_id", # str
"room_id", # str
@@ -111,12 +91,6 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
"data_type", # str
"data", # dict
))
-CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
- "room_id", # str
- "type", # str
- "state_key", # str
- "event_id", # str, optional
-))
GroupsStreamRow = namedtuple("GroupsStreamRow", (
"group_id", # str
"user_id", # str
@@ -132,9 +106,24 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called.
"""
NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row
+ ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
_LIMITED = True # Whether the update function takes a limit
+ @classmethod
+ def parse_row(cls, row):
+ """Parse a row received over replication
+
+ By default, assumes that the row data is an array object and passes its contents
+ to the constructor of the ROW_TYPE for this stream.
+
+ Args:
+ row: row data from the incoming RDATA command, after json decoding
+
+ Returns:
+ ROW_TYPE object for this stream
+ """
+ return cls.ROW_TYPE(*row)
+
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
@@ -162,8 +151,10 @@ class Stream(object):
until the `upto_token`
Returns:
- (list(ROW_TYPE), int): list of updates plus the token used as an
- upper bound of the updates (i.e. the "current token")
+ Deferred[Tuple[List[Tuple[int, Any]], int]:
+ Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
+ list of ``(token, row)`` entries. ``row`` will be json-serialised and
+ sent over the replication steam.
"""
updates, current_token = yield self.get_updates_since(self.last_token)
self.last_token = current_token
@@ -176,8 +167,10 @@ class Stream(object):
stream updates
Returns:
- (list(ROW_TYPE), int): list of updates plus the token used as an
- upper bound of the updates (i.e. the "current token")
+ Deferred[Tuple[List[Tuple[int, Any]], int]:
+ Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
+ list of ``(token, row)`` entries. ``row`` will be json-serialised and
+ sent over the replication steam.
"""
if from_token in ("NOW", "now"):
defer.returnValue(([], self.upto_token))
@@ -202,7 +195,7 @@ class Stream(object):
from_token, current_token,
)
- updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows]
+ updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
@@ -232,20 +225,6 @@ class Stream(object):
raise NotImplementedError()
-class EventsStream(Stream):
- """We received a new event, or an event went from being an outlier to not
- """
- NAME = "events"
- ROW_TYPE = EventStreamRow
-
- def __init__(self, hs):
- store = hs.get_datastore()
- self.current_token = store.get_current_events_token
- self.update_function = store.get_all_new_forward_event_rows
-
- super(EventsStream, self).__init__(hs)
-
-
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -400,22 +379,6 @@ class ToDeviceStream(Stream):
super(ToDeviceStream, self).__init__(hs)
-class FederationStream(Stream):
- """Data to be sent over federation. Only available when master has federation
- sending disabled.
- """
- NAME = "federation"
- ROW_TYPE = FederationStreamRow
-
- def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
-
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
-
- super(FederationStream, self).__init__(hs)
-
-
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
@@ -459,21 +422,6 @@ class AccountDataStream(Stream):
defer.returnValue(results)
-class CurrentStateDeltaStream(Stream):
- """Current state for a room was changed
- """
- NAME = "current_state_deltas"
- ROW_TYPE = CurrentStateDeltaStreamRow
-
- def __init__(self, hs):
- store = hs.get_datastore()
-
- self.current_token = store.get_max_current_state_delta_stream_id
- self.update_function = store.get_all_updated_current_state_deltas
-
- super(CurrentStateDeltaStream, self).__init__(hs)
-
-
class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@@ -485,26 +433,3 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
-
-
-STREAMS_MAP = {
- stream.NAME: stream
- for stream in (
- EventsStream,
- BackfillStream,
- PresenceStream,
- TypingStream,
- ReceiptsStream,
- PushRulesStream,
- PushersStream,
- CachesStream,
- PublicRoomsStream,
- DeviceListsStream,
- ToDeviceStream,
- FederationStream,
- TagAccountDataStream,
- AccountDataStream,
- CurrentStateDeltaStream,
- GroupServerStream,
- )
-}
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
new file mode 100644
index 0000000000..e0f6e29248
--- /dev/null
+++ b/synapse/replication/tcp/streams/events.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import heapq
+
+import attr
+
+from twisted.internet import defer
+
+from ._base import Stream
+
+
+"""Handling of the 'events' replication stream
+
+This stream contains rows of various types. Each row therefore contains a 'type'
+identifier before the real data. For example::
+
+ RDATA events batch ["state", ["!room:id", "m.type", "", "$event:id"]]
+ RDATA events 12345 ["ev", ["$event:id", "!room:id", "m.type", null, null]]
+
+An "ev" row is sent for each new event. The fields in the data part are:
+
+ * The new event id
+ * The room id for the event
+ * The type of the new event
+ * The state key of the event, for state events
+ * The event id of an event which is redacted by this event.
+
+A "state" row is sent whenever the "current state" in a room changes. The fields in the
+data part are:
+
+ * The room id for the state change
+ * The event type of the state which has changed
+ * The state_key of the state which has changed
+ * The event id of the new state
+
+"""
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamRow(object):
+ """A parsed row from the events replication stream"""
+ type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
+ data = attr.ib() # BaseEventsStreamRow
+
+
+class BaseEventsStreamRow(object):
+ """Base class for rows to be sent in the events stream.
+
+ Specifies how to identify, serialize and deserialize the different types.
+ """
+
+ TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+
+ @classmethod
+ def from_data(cls, data):
+ """Parse the data from the replication stream into a row.
+
+ By default we just call the constructor with the data list as arguments
+
+ Args:
+ data: The value of the data object from the replication stream
+ """
+ return cls(*data)
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamEventRow(BaseEventsStreamRow):
+ TypeId = "ev"
+
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
+
+
+@attr.s(slots=True, frozen=True)
+class EventsStreamCurrentStateRow(BaseEventsStreamRow):
+ TypeId = "state"
+
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str
+ event_id = attr.ib() # str, optional
+
+
+TypeToRow = {
+ Row.TypeId: Row
+ for Row in (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+ )
+}
+
+
+class EventsStream(Stream):
+ """We received a new event, or an event went from being an outlier to not
+ """
+ NAME = "events"
+
+ def __init__(self, hs):
+ self._store = hs.get_datastore()
+ self.current_token = self._store.get_current_events_token
+
+ super(EventsStream, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def update_function(self, from_token, current_token, limit=None):
+ event_rows = yield self._store.get_all_new_forward_event_rows(
+ from_token, current_token, limit,
+ )
+ event_updates = (
+ (row[0], EventsStreamEventRow.TypeId, row[1:])
+ for row in event_rows
+ )
+
+ state_rows = yield self._store.get_all_updated_current_state_deltas(
+ from_token, current_token, limit
+ )
+ state_updates = (
+ (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
+ for row in state_rows
+ )
+
+ all_updates = heapq.merge(event_updates, state_updates)
+
+ defer.returnValue(all_updates)
+
+ @classmethod
+ def parse_row(cls, row):
+ (typ, data) = row
+ data = TypeToRow[typ].from_data(data)
+ return EventsStreamRow(typ, data)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
new file mode 100644
index 0000000000..9aa43aa8d2
--- /dev/null
+++ b/synapse/replication/tcp/streams/federation.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+
+from ._base import Stream
+
+FederationStreamRow = namedtuple("FederationStreamRow", (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+))
+
+
+class FederationStream(Stream):
+ """Data to be sent over federation. Only available when master has federation
+ sending disabled.
+ """
+ NAME = "federation"
+ ROW_TYPE = FederationStreamRow
+
+ def __init__(self, hs):
+ federation_sender = hs.get_federation_sender()
+
+ self.current_token = federation_sender.get_current_token
+ self.update_function = federation_sender.get_replication_rows
+
+ super(FederationStream, self).__init__(hs)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index e788769639..7d7a75fc30 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -499,7 +499,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
# desirable in case the first attempt at blocking the room failed below.
yield self.store.block_room(room_id, requester_user_id)
- users = yield self.state.get_current_user_in_room(room_id)
+ users = yield self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
for user_id in users:
@@ -647,8 +647,6 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- logger.info("new_password: %r", new_password)
-
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
@@ -786,6 +784,31 @@ class SearchUsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
+class DeleteGroupAdminRestServlet(ClientV1RestServlet):
+ """Allows deleting of local groups
+ """
+ PATTERNS = client_path_patterns("/admin/delete_group/(?P<group_id>[^/]*)")
+
+ def __init__(self, hs):
+ super(DeleteGroupAdminRestServlet, self).__init__(hs)
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Can only delete local groups")
+
+ yield self.group_server.delete_group(group_id, requester.user.to_string())
+ defer.returnValue((200, {}))
+
+
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
@@ -801,3 +824,4 @@ def register_servlets(hs, http_server):
ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
+ DeleteGroupAdminRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index b5a6d6aebf..045d5a20ac 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -93,72 +93,5 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
return (200, {})
-class PresenceListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- super(PresenceListRestServlet, self).__init__(hs)
- self.presence_handler = hs.get_presence_handler()
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
- user = UserID.from_string(user_id)
-
- if not self.hs.is_mine(user):
- raise SynapseError(400, "User not hosted on this Home Server")
-
- if requester.user != user:
- raise SynapseError(400, "Cannot get another user's presence list")
-
- presence = yield self.presence_handler.get_presence_list(
- observer_user=user, accepted=True
- )
-
- defer.returnValue((200, presence))
-
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
- user = UserID.from_string(user_id)
-
- if not self.hs.is_mine(user):
- raise SynapseError(400, "User not hosted on this Home Server")
-
- if requester.user != user:
- raise SynapseError(
- 400, "Cannot modify another user's presence list")
-
- content = parse_json_object_from_request(request)
-
- if "invite" in content:
- for u in content["invite"]:
- if not isinstance(u, string_types):
- raise SynapseError(400, "Bad invite value.")
- if len(u) == 0:
- continue
- invited_user = UserID.from_string(u)
- yield self.presence_handler.send_presence_invite(
- observer_user=user, observed_user=invited_user
- )
-
- if "drop" in content:
- for u in content["drop"]:
- if not isinstance(u, string_types):
- raise SynapseError(400, "Bad drop value.")
- if len(u) == 0:
- continue
- dropped_user = UserID.from_string(u)
- yield self.presence_handler.drop(
- observer_user=user, observed_user=dropped_user
- )
-
- defer.returnValue((200, {}))
-
- def on_OPTIONS(self, request):
- return (200, {})
-
-
def register_servlets(hs, http_server):
PresenceStatusRestServlet(hs).register(http_server)
- PresenceListRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 37b32dd37b..ee069179f0 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -215,6 +215,7 @@ class DeactivateAccountRestServlet(RestServlet):
)
result = yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase,
+ id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@@ -363,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
+ PATTERNS = client_v2_patterns("/account/3pid/delete$")
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
@@ -380,7 +381,7 @@ class ThreepidDeleteRestServlet(RestServlet):
try:
ret = yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address']
+ user_id, body['medium'], body['address'], body.get("id_server"),
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 373f95126e..a868d06098 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
-from synapse.api.constants import DEFAULT_ROOM_VERSION, RoomDisposition, RoomVersions
+from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
@@ -48,12 +48,10 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
- "default": DEFAULT_ROOM_VERSION,
+ "default": DEFAULT_ROOM_VERSION.identifier,
"available": {
- RoomVersions.V1: RoomDisposition.STABLE,
- RoomVersions.V2: RoomDisposition.STABLE,
- RoomVersions.STATE_V2_TEST: RoomDisposition.UNSTABLE,
- RoomVersions.V3: RoomDisposition.STABLE,
+ v.identifier: v.disposition
+ for v in KNOWN_ROOM_VERSIONS.values()
},
},
"m.change_password": {"enabled": change_password},
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index e6356101fd..3db7ff8d1b 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -17,8 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.api.constants import KNOWN_ROOM_VERSIONS
from synapse.api.errors import Codes, SynapseError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 68058f613c..36684ef9f6 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -24,7 +24,8 @@ from frozendict import frozendict
from twisted.internet import defer
-from synapse.api.constants import EventTypes, RoomVersions
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events.snapshot import EventContext
from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer
@@ -160,10 +161,21 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_user_in_room(self, room_id, latest_event_ids=None):
+ def get_current_users_in_room(self, room_id, latest_event_ids=None):
+ """
+ Get the users who are currently in a room.
+
+ Args:
+ room_id (str): The ID of the room.
+ latest_event_ids (List[str]|None): Precomputed list of latest
+ event IDs. Will be computed if None.
+ Returns:
+ Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
+ profileinfo.
+ """
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- logger.debug("calling resolve_state_groups from get_current_user_in_room")
+ logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@@ -603,22 +615,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
- if room_version == RoomVersions.V1:
+ v = KNOWN_ROOM_VERSIONS[room_version]
+ if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store(
state_sets, event_map, state_res_store.get_events,
)
- elif room_version in (
- RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3,
- ):
+ else:
return v2.resolve_events_with_store(
room_version, state_sets, event_map, state_res_store,
)
- else:
- # This should only happen if we added a version but forgot to add it to
- # the list above.
- raise Exception(
- "No state resolution algorithm defined for version %r" % (room_version,)
- )
@attr.s
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6d3afcae7c..29b4e86cfd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -21,8 +21,9 @@ from six import iteritems, iterkeys, itervalues
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import EventTypes, RoomVersions
+from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.room_versions import RoomVersions
logger = logging.getLogger(__name__)
@@ -275,7 +276,9 @@ def _resolve_auth_events(events, auth_events):
try:
# The signatures have already been checked at this point
event_auth.check(
- RoomVersions.V1, event, auth_events,
+ RoomVersions.V1.identifier,
+ event,
+ auth_events,
do_sig_check=False,
do_size_check=False,
)
@@ -291,7 +294,9 @@ def _resolve_normal_events(events, auth_events):
try:
# The signatures have already been checked at this point
event_auth.check(
- RoomVersions.V1, event, auth_events,
+ RoomVersions.V1.identifier,
+ event,
+ auth_events,
do_sig_check=False,
do_size_check=False,
)
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index 3a958749a1..e02663f50e 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -49,7 +49,7 @@ var show_login = function() {
$("#loading").hide();
var this_page = window.location.origin + window.location.pathname;
- $("#sso_redirect_url").val(encodeURIComponent(this_page));
+ $("#sso_redirect_url").val(this_page);
if (matrixLogin.serverAcceptsPassword) {
$("#password_flow").show();
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 42cd3c83ad..c432041b4e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -18,6 +18,8 @@ import calendar
import logging
import time
+from twisted.internet import defer
+
from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
@@ -61,48 +63,60 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
logger = logging.getLogger(__name__)
-class DataStore(RoomMemberStore, RoomStore,
- RegistrationStore, StreamStore, ProfileStore,
- PresenceStore, TransactionStore,
- DirectoryStore, KeyStore, StateStore, SignatureStore,
- ApplicationServiceStore,
- EventsStore,
- EventFederationStore,
- MediaRepositoryStore,
- RejectionsStore,
- FilteringStore,
- PusherStore,
- PushRuleStore,
- ApplicationServiceTransactionStore,
- ReceiptsStore,
- EndToEndKeyStore,
- EndToEndRoomKeyStore,
- SearchStore,
- TagsStore,
- AccountDataStore,
- EventPushActionsStore,
- OpenIdStore,
- ClientIpStore,
- DeviceStore,
- DeviceInboxStore,
- UserDirectoryStore,
- GroupServerStore,
- UserErasureStore,
- MonthlyActiveUsersStore,
- ):
-
+class DataStore(
+ RoomMemberStore,
+ RoomStore,
+ RegistrationStore,
+ StreamStore,
+ ProfileStore,
+ PresenceStore,
+ TransactionStore,
+ DirectoryStore,
+ KeyStore,
+ StateStore,
+ SignatureStore,
+ ApplicationServiceStore,
+ EventsStore,
+ EventFederationStore,
+ MediaRepositoryStore,
+ RejectionsStore,
+ FilteringStore,
+ PusherStore,
+ PushRuleStore,
+ ApplicationServiceTransactionStore,
+ ReceiptsStore,
+ EndToEndKeyStore,
+ EndToEndRoomKeyStore,
+ SearchStore,
+ TagsStore,
+ AccountDataStore,
+ EventPushActionsStore,
+ OpenIdStore,
+ ClientIpStore,
+ DeviceStore,
+ DeviceInboxStore,
+ UserDirectoryStore,
+ GroupServerStore,
+ UserErasureStore,
+ MonthlyActiveUsersStore,
+):
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
- extra_tables=[("local_invites", "stream_id")]
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
)
self._backfill_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering", step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
@@ -114,7 +128,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id",
+ db_conn, "device_lists_stream", "stream_id"
)
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
@@ -125,16 +139,15 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id",
- extra_tables=[("deleted_pushers", "stream_id")],
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self._group_updates_id_gen = StreamIdGenerator(
- db_conn, "local_group_updates", "stream_id",
+ db_conn, "local_group_updates", "stream_id"
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id",
+ db_conn, "cache_invalidation_stream", "stream_id"
)
else:
self._cache_id_gen = None
@@ -142,72 +155,82 @@ class DataStore(RoomMemberStore, RoomStore,
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
- db_conn, "presence_stream",
+ db_conn,
+ "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
- "PresenceStreamChangeCache", min_presence_val,
- prefilled_cache=presence_cache_prefill
+ "PresenceStreamChangeCache",
+ min_presence_val,
+ prefilled_cache=presence_cache_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
- db_conn, "device_inbox",
+ db_conn,
+ "device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
- "DeviceInboxStreamChangeCache", min_device_inbox_id,
+ "DeviceInboxStreamChangeCache",
+ min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
- db_conn, "device_federation_outbox",
+ db_conn,
+ "device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
- "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
+ "DeviceFederationOutboxStreamChangeCache",
+ min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max,
+ "DeviceListStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max,
+ "DeviceListFederationStreamChangeCache", device_list_max
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
- db_conn, "current_state_delta_stream",
+ db_conn,
+ "current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
- db_conn, "local_group_updates",
+ db_conn,
+ "local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache", min_group_updates_id,
+ "_group_updates_stream_cache",
+ min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
@@ -250,6 +273,7 @@ class DataStore(RoomMemberStore, RoomStore,
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
+
def _count_users(txn):
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
@@ -277,6 +301,7 @@ class DataStore(RoomMemberStore, RoomStore,
Returns counts globaly for a given user as well as breaking
by platform
"""
+
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
@@ -313,8 +338,7 @@ class DataStore(RoomMemberStore, RoomStore,
"""
results = {}
- txn.execute(sql, (thirty_days_ago_in_secs,
- thirty_days_ago_in_secs))
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
for row in txn:
if row[0] == 'unknown':
@@ -341,8 +365,7 @@ class DataStore(RoomMemberStore, RoomStore,
) u
"""
- txn.execute(sql, (thirty_days_ago_in_secs,
- thirty_days_ago_in_secs))
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
count, = txn.fetchone()
results['all'] = count
@@ -356,15 +379,14 @@ class DataStore(RoomMemberStore, RoomStore,
Returns millisecond unixtime for start of UTC day.
"""
now = time.gmtime()
- today_start = calendar.timegm((
- now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
- ))
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
def generate_user_daily_visits(self):
"""
Generates daily visit data for use in cohort/ retention analysis
"""
+
def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
@@ -395,25 +417,29 @@ class DataStore(RoomMemberStore, RoomStore,
# often to minimise this case.
if today_start > self._last_user_visit_update:
yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(sql, (
- yesterday_start, yesterday_start,
- self._last_user_visit_update, today_start
- ))
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
self._last_user_visit_update = today_start
- txn.execute(sql, (
- today_start, today_start,
- self._last_user_visit_update,
- now
- ))
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
# Update _last_user_visit_update to now. The reason to do this
# rather just clamping to the beginning of the day is to limit
# the size of the join - meaning that the query can be run more
# frequently
self._last_user_visit_update = now
- return self.runInteraction("generate_user_daily_visits",
- _generate_user_daily_visits)
+ return self.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
def get_users(self):
"""Function to reterive a list of users in users table.
@@ -425,15 +451,11 @@ class DataStore(RoomMemberStore, RoomStore,
return self._simple_select_list(
table="users",
keyvalues={},
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin"
- ],
+ retcols=["name", "password_hash", "is_guest", "admin"],
desc="get_users",
)
+ @defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from
users list. This will return a json object, which contains
@@ -446,27 +468,19 @@ class DataStore(RoomMemberStore, RoomStore,
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
- is_guest = 0
- i_start = (int)(start)
- i_limit = (int)(limit)
- return self.get_user_list_paginate(
+ users = yield self.runInteraction(
+ "get_users_paginate",
+ self._simple_select_list_paginate_txn,
table="users",
- keyvalues={
- "is_guest": is_guest
- },
- pagevalues=[
- order,
- i_limit,
- i_start
- ],
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin"
- ],
- desc="get_users_paginate",
+ keyvalues={"is_guest": False},
+ orderby=order,
+ start=start,
+ limit=limit,
+ retcols=["name", "password_hash", "is_guest", "admin"],
)
+ count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
+ retval = {"users": users, "total": count}
+ defer.returnValue(retval)
def search_users(self, term):
"""Function to search users list for one or more users with
@@ -482,12 +496,7 @@ class DataStore(RoomMemberStore, RoomStore,
table="users",
term=term,
col="name",
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin"
- ],
+ retcols=["name", "password_hash", "is_guest", "admin"],
desc="search_users",
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7e3903859b..983ce026e1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -41,7 +41,7 @@ try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
- MAX_TXN_ID = 2**63 - 1
+ MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -76,12 +76,18 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
+
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
+ "txn",
+ "name",
+ "database_engine",
+ "after_callbacks",
+ "exception_callbacks",
]
- def __init__(self, txn, name, database_engine, after_callbacks,
- exception_callbacks):
+ def __init__(
+ self, txn, name, database_engine, after_callbacks, exception_callbacks
+ ):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
@@ -110,6 +116,7 @@ class LoggingTransaction(object):
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
+
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
@@ -134,10 +141,7 @@ class LoggingTransaction(object):
sql = self.database_engine.convert_param_style(sql)
if args:
try:
- sql_logger.debug(
- "[SQL values] {%s} %r",
- self.name, args[0]
- )
+ sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
except Exception:
# Don't let logging failures stop SQL from working
pass
@@ -145,9 +149,7 @@ class LoggingTransaction(object):
start = time.time()
try:
- return func(
- sql, *args
- )
+ return func(sql, *args)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
@@ -176,11 +178,9 @@ class PerformanceCounters(object):
counters = []
for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
- counters.append((
- (cum_time - prev_time) / interval_duration,
- count - prev_count,
- name
- ))
+ counters.append(
+ ((cum_time - prev_time) / interval_duration, count - prev_count, name)
+ )
self.previous_counters = dict(self.current_counters)
@@ -212,8 +212,9 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = Cache("*getEvent*", keylen=3,
- max_entries=hs.config.event_cache_size)
+ self._get_event_cache = Cache(
+ "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ )
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
@@ -239,7 +240,7 @@ class SQLBaseStore(object):
0.0,
run_as_background_process,
"upsert_safety_check",
- self._check_safe_to_upsert
+ self._check_safe_to_upsert,
)
@defer.inlineCallbacks
@@ -271,7 +272,7 @@ class SQLBaseStore(object):
15.0,
run_as_background_process,
"upsert_safety_check",
- self._check_safe_to_upsert
+ self._check_safe_to_upsert,
)
def start_profiling(self):
@@ -298,13 +299,16 @@ class SQLBaseStore(object):
perf_logger.info(
"Total database time: %.3f%% {%s} {%s}",
- ratio * 100, top_three_counters, top_3_event_counters
+ ratio * 100,
+ top_three_counters,
+ top_3_event_counters,
)
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
- func, *args, **kwargs):
+ def _new_transaction(
+ self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+ ):
start = time.time()
txn_id = self._TXN_ID
@@ -312,7 +316,7 @@ class SQLBaseStore(object):
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
- name = "%s-%x" % (desc, txn_id, )
+ name = "%s-%x" % (desc, txn_id)
transaction_logger.debug("[TXN START] {%s}", name)
@@ -323,7 +327,10 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks,
+ txn,
+ name,
+ self.database_engine,
+ after_callbacks,
exception_callbacks,
)
r = func(txn, *args, **kwargs)
@@ -334,7 +341,10 @@ class SQLBaseStore(object):
# transaction.
logger.warning(
"[TXN OPERROR] {%s} %s %d/%d",
- name, exception_to_unicode(e), i, N
+ name,
+ exception_to_unicode(e),
+ i,
+ N,
)
if i < N:
i += 1
@@ -342,8 +352,7 @@ class SQLBaseStore(object):
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warning(
- "[TXN EROLL] {%s} %s",
- name, exception_to_unicode(e1),
+ "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
)
continue
raise
@@ -357,7 +366,8 @@ class SQLBaseStore(object):
except self.database_engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
- name, exception_to_unicode(e1),
+ name,
+ exception_to_unicode(e1),
)
continue
raise
@@ -396,16 +406,17 @@ class SQLBaseStore(object):
exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn(
- "Starting db txn '%s' from sentinel context",
- desc,
- )
+ logger.warn("Starting db txn '%s' from sentinel context", desc)
try:
result = yield self.runWithConnection(
self._new_transaction,
- desc, after_callbacks, exception_callbacks, func,
- *args, **kwargs
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -434,7 +445,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel:
logger.warn(
- "Starting db connection from sentinel context: metrics will be lost",
+ "Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
@@ -453,9 +464,7 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
defer.returnValue(result)
@@ -469,9 +478,7 @@ class SQLBaseStore(object):
A list of dicts where the key is the column header.
"""
col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(
- dict(zip(col_headers, row)) for row in cursor
- )
+ results = list(dict(zip(col_headers, row)) for row in cursor)
return results
def _execute(self, desc, decoder, query, *args):
@@ -485,6 +492,7 @@ class SQLBaseStore(object):
Returns:
The result of decoder(results)
"""
+
def interaction(txn):
txn.execute(query, args)
if decoder:
@@ -498,8 +506,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False,
- desc="_simple_insert"):
+ def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -511,10 +518,7 @@ class SQLBaseStore(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(
- desc,
- self._simple_insert_txn, table, values,
- )
+ yield self.runInteraction(desc, self._simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -530,15 +534,13 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys),
- ", ".join("?" for _ in keys)
+ ", ".join("?" for _ in keys),
)
txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(
- desc, self._simple_insert_many_txn, table, values
- )
+ return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
@staticmethod
def _simple_insert_many_txn(txn, table, values):
@@ -553,24 +555,18 @@ class SQLBaseStore(object):
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
- keys, vals = zip(*[
- zip(
- *(sorted(i.items(), key=lambda kv: kv[0]))
- )
- for i in values
- if i
- ])
+ keys, vals = zip(
+ *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ )
for k in keys:
if k != keys[0]:
- raise RuntimeError(
- "All items must have the same keys"
- )
+ raise RuntimeError("All items must have the same keys")
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0])
+ ", ".join("?" for _ in keys[0]),
)
txn.executemany(sql, vals)
@@ -583,7 +579,7 @@ class SQLBaseStore(object):
values,
insertion_values={},
desc="_simple_upsert",
- lock=True
+ lock=True,
):
"""
@@ -599,7 +595,7 @@ class SQLBaseStore(object):
Args:
table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
+ keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
@@ -631,17 +627,11 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "%s when upserting into %s; retrying: %s", e.__name__, table, e
+ "IntegrityError when upserting into %s; retrying: %s", table, e
)
def _simple_upsert_txn(
- self,
- txn,
- table,
- keyvalues,
- values,
- insertion_values={},
- lock=True,
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
Pick the UPSERT method which works best on the platform. Either the
@@ -665,11 +655,7 @@ class SQLBaseStore(object):
and table not in self._unsafe_to_upsert_tables
):
return self._simple_upsert_txn_native_upsert(
- txn,
- table,
- keyvalues,
- values,
- insertion_values=insertion_values,
+ txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
return self._simple_upsert_txn_emulated(
@@ -714,7 +700,7 @@ class SQLBaseStore(object):
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
- " AND ".join(_getwhere(k) for k in keyvalues)
+ " AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
@@ -726,7 +712,7 @@ class SQLBaseStore(object):
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues)
+ " AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@@ -773,19 +759,14 @@ class SQLBaseStore(object):
latter = "NOTHING"
else:
allvalues.update(values)
- latter = (
- "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- )
+ latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = (
- "INSERT INTO %s (%s) VALUES (%s) "
- "ON CONFLICT (%s) DO %s"
- ) % (
+ sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
- latter
+ latter,
)
txn.execute(sql, list(allvalues.values()))
@@ -870,8 +851,8 @@ class SQLBaseStore(object):
latter = "NOTHING"
value_values = [() for x in range(len(key_values))]
else:
- latter = (
- "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+ latter = "UPDATE SET " + ", ".join(
+ k + "=EXCLUDED." + k for k in value_names
)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
@@ -889,8 +870,9 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args)
- def _simple_select_one(self, table, keyvalues, retcols,
- allow_none=False, desc="_simple_select_one"):
+ def _simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
+ ):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -903,14 +885,17 @@ class SQLBaseStore(object):
statement returns no rows
"""
return self.runInteraction(
- desc,
- self._simple_select_one_txn,
- table, keyvalues, retcols, allow_none,
+ desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def _simple_select_one_onecol(self, table, keyvalues, retcol,
- allow_none=False,
- desc="_simple_select_one_onecol"):
+ def _simple_select_one_onecol(
+ self,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=False,
+ desc="_simple_select_one_onecol",
+ ):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -922,17 +907,18 @@ class SQLBaseStore(object):
return self.runInteraction(
desc,
self._simple_select_one_onecol_txn,
- table, keyvalues, retcol, allow_none=allow_none,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=allow_none,
)
@classmethod
- def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
- allow_none=False):
+ def _simple_select_one_onecol_txn(
+ cls, txn, table, keyvalues, retcol, allow_none=False
+ ):
ret = cls._simple_select_onecol_txn(
- txn,
- table=table,
- keyvalues=keyvalues,
- retcol=retcol,
+ txn, table=table, keyvalues=keyvalues, retcol=retcol
)
if ret:
@@ -945,12 +931,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- sql = (
- "SELECT %(retcol)s FROM %(table)s"
- ) % {
- "retcol": retcol,
- "table": table,
- }
+ sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
@@ -960,8 +941,9 @@ class SQLBaseStore(object):
return [r[0] for r in txn]
- def _simple_select_onecol(self, table, keyvalues, retcol,
- desc="_simple_select_onecol"):
+ def _simple_select_onecol(
+ self, table, keyvalues, retcol, desc="_simple_select_onecol"
+ ):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -974,13 +956,12 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- desc,
- self._simple_select_onecol_txn,
- table, keyvalues, retcol
+ desc, self._simple_select_onecol_txn, table, keyvalues, retcol
)
- def _simple_select_list(self, table, keyvalues, retcols,
- desc="_simple_select_list"):
+ def _simple_select_list(
+ self, table, keyvalues, retcols, desc="_simple_select_list"
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -994,9 +975,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
- desc,
- self._simple_select_list_txn,
- table, keyvalues, retcols
+ desc, self._simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
@@ -1016,22 +995,26 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
else:
- sql = "SELECT %s FROM %s" % (
- ", ".join(retcols),
- table
- )
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
- def _simple_select_many_batch(self, table, column, iterable, retcols,
- keyvalues={}, desc="_simple_select_many_batch",
- batch_size=100):
+ def _simple_select_many_batch(
+ self,
+ table,
+ column,
+ iterable,
+ retcols,
+ keyvalues={},
+ desc="_simple_select_many_batch",
+ batch_size=100,
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
it_list = list(iterable)
chunks = [
- it_list[i:i + batch_size]
- for i in range(0, len(it_list), batch_size)
+ it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
desc,
self._simple_select_many_txn,
- table, column, chunk, keyvalues, retcols
+ table,
+ column,
+ chunk,
+ keyvalues,
+ retcols,
)
results.extend(rows)
@@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
clauses = []
values = []
- clauses.append(
- "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
- )
+ clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
- sql = "%s WHERE %s" % (
- sql,
- " AND ".join(clauses),
- )
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
- desc,
- self._simple_update_txn,
- table, keyvalues, updatevalues,
+ desc, self._simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
@@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
where,
)
- txn.execute(
- update_sql,
- list(updatevalues.values()) + list(keyvalues.values())
- )
+ txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
return txn.rowcount
- def _simple_update_one(self, table, keyvalues, updatevalues,
- desc="_simple_update_one"):
+ def _simple_update_one(
+ self, table, keyvalues, updatevalues, desc="_simple_update_one"
+ ):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
- desc,
- self._simple_update_one_txn,
- table, keyvalues, updatevalues,
+ desc, self._simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
@@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols,
- allow_none=False):
+ def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(select_sql, list(keyvalues.values()))
@@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction(
- desc, self._simple_delete_one_txn, table, keyvalues
- )
+ return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
@@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
@@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(
- desc, self._simple_delete_txn, table, keyvalues
- )
+ return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
return txn.execute(sql, list(keyvalues.values()))
@@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
clauses = []
values = []
- clauses.append(
- "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
- )
+ clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
- sql = "%s WHERE %s" % (
- sql,
- " AND ".join(clauses),
- )
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
return txn.execute(sql, values)
- def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
- max_value, limit=100000):
+ def _get_cache_dict(
+ self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+ ):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
- cache = {
- row[0]: int(row[1])
- for row in txn
- }
+ cache = {row[0]: int(row[1]) for row in txn}
txn.close()
@@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, keys,
- )
+ self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
@@ -1355,28 +1316,13 @@ class SQLBaseStore(object):
members_changed (iterable[str]): The user_ids of members that have
changed
"""
- for member in members_changed:
- self._attempt_to_invalidate_cache(
- "get_rooms_for_user_with_stream_ordering", (member,),
- )
-
for host in set(get_domain_from_id(u) for u in members_changed):
- self._attempt_to_invalidate_cache(
- "is_host_joined", (room_id, host,),
- )
- self._attempt_to_invalidate_cache(
- "was_host_joined", (room_id, host,),
- )
+ self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
+ self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
- self._attempt_to_invalidate_cache(
- "get_users_in_room", (room_id,),
- )
- self._attempt_to_invalidate_cache(
- "get_room_summary", (room_id,),
- )
- self._attempt_to_invalidate_cache(
- "get_current_state_ids", (room_id,),
- )
+ self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
+ self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(self, cache_name, key):
"""Attempts to invalidate the cache of the given name, ignoring if the
@@ -1424,7 +1370,7 @@ class SQLBaseStore(object):
"cache_func": cache_name,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
- }
+ },
)
def get_all_updated_caches(self, last_id, current_id, limit):
@@ -1440,11 +1386,10 @@ class SQLBaseStore(object):
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_id, limit,))
+ txn.execute(sql, (last_id, limit))
return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
+
+ return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
def get_cache_stream_token(self):
if self._cache_id_gen:
@@ -1452,33 +1397,61 @@ class SQLBaseStore(object):
else:
return 0
- def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
- desc="_simple_select_list_paginate"):
- """Executes a SELECT query on the named table with start and limit,
+ def _simple_select_list_paginate(
+ self,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ desc="_simple_select_list_paginate",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
- keyvalues (dict[str, Any] | None):
+ keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
- table, keyvalues, pagevalues, retcols
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction=order_direction,
)
@classmethod
- def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
- """Executes a SELECT query on the named table with start and limit,
+ def _simple_select_list_paginate_txn(
+ cls,
+ txn,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
@@ -1488,66 +1461,32 @@ class SQLBaseStore(object):
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
- pagevalues ([]):
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
-
"""
+ if order_direction not in ["ASC", "DESC"]:
+ raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- " ? ASC LIMIT ? OFFSET ?"
- )
- txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
+ where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
else:
- sql = "SELECT %s FROM %s ORDER BY %s" % (
- ", ".join(retcols),
- table,
- " ? ASC LIMIT ? OFFSET ?"
- )
- txn.execute(sql, pagevalues)
-
- return cls.cursor_to_dict(txn)
+ where_clause = ""
- @defer.inlineCallbacks
- def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
- desc="get_user_list_paginate"):
- """Get a list of users from start row to a limit number of rows. This will
- return a json object with users and total number of users in users list.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- pagevalues ([]):
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
- """
- users = yield self.runInteraction(
- desc,
- self._simple_select_list_paginate_txn,
- table, keyvalues, pagevalues, retcols
- )
- count = yield self.runInteraction(
- desc,
- self.get_user_count_txn
+ sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+ ", ".join(retcols),
+ table,
+ where_clause,
+ orderby,
+ order_direction,
)
- retval = {
- "users": users,
- "total": count
- }
- defer.returnValue(retval)
+ txn.execute(sql, list(keyvalues.values()) + [limit, start])
+
+ return cls.cursor_to_dict(txn)
def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list.
@@ -1561,8 +1500,9 @@ class SQLBaseStore(object):
txn.execute(sql_count)
return txn.fetchone()[0]
- def _simple_search_list(self, table, term, col, retcols,
- desc="_simple_search_list"):
+ def _simple_search_list(
+ self, table, term, col, retcols, desc="_simple_search_list"
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1577,9 +1517,7 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
- desc,
- self._simple_search_list_txn,
- table, term, col, retcols
+ desc, self._simple_search_list_txn, table, term, col, retcols
)
@classmethod
@@ -1598,11 +1536,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
- ", ".join(retcols),
- table,
- col
- )
+ sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
@@ -1623,6 +1557,7 @@ class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
+
pass
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index bbc3355c73..8394389073 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache", account_max,
+ "AccountDataAndTagsChangeCache", account_max
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn(
- txn, "account_data", {"user_id": user_id},
- ["account_data_type", "content"]
+ txn,
+ "account_data",
+ {"user_id": user_id},
+ ["account_data_type", "content"],
)
global_account_data = {
@@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
}
rows = self._simple_select_list_txn(
- txn, "room_account_data", {"user_id": user_id},
- ["room_id", "account_data_type", "content"]
+ txn,
+ "room_account_data",
+ {"user_id": user_id},
+ ["room_id", "account_data_type", "content"],
)
by_room = {}
@@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
result = yield self._simple_select_one_onecol(
table="account_data",
- keyvalues={
- "user_id": user_id,
- "account_data_type": data_type,
- },
+ keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
desc="get_global_account_data_by_type_for_user",
allow_none=True,
@@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
A deferred dict of the room account_data
"""
+
def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn(
- txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"]
+ txn,
+ "room_account_data",
+ {"user_id": user_id, "room_id": room_id},
+ ["account_data_type", "content"],
)
return {
@@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
A deferred of the room account_data for that type, or None if
there isn't any set.
"""
+
def get_account_data_for_room_and_type_txn(txn):
content_json = self._simple_select_one_onecol_txn(
txn,
@@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
"account_data_type": account_data_type,
},
retcol="content",
- allow_none=True
+ allow_none=True,
)
return json.loads(content_json) if content_json else None
return self.runInteraction(
- "get_account_data_for_room_and_type",
- get_account_data_for_room_and_type_txn,
+ "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
- def get_all_updated_account_data(self, last_global_id, last_room_id,
- current_id, limit):
+ def get_all_updated_account_data(
+ self, last_global_id, last_room_id, current_id, limit
+ ):
"""Get all the client account_data that has changed on the server
Args:
last_global_id(int): The position to fetch from for top level data
@@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return (global_results, room_results)
+
return self.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- global_account_data = {
- row[0]: json.loads(row[1]) for row in txn
- }
+ global_account_data = {row[0]: json.loads(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", ignorer_user_id,
+ "m.ignored_user_list",
+ ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
@@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
"room_id": room_id,
"account_data_type": account_data_type,
},
- values={
- "stream_id": next_id,
- "content": content_json,
- },
+ values={"stream_id": next_id, "content": content_json},
lock=False,
)
@@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
- self.get_account_data_for_room.invalidate((user_id, room_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
- (user_id, room_id, account_data_type,), content,
+ (user_id, room_id, account_data_type), content
)
result = self._account_data_id_gen.get_current_token()
@@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
yield self._simple_upsert(
desc="add_user_account_data",
table="account_data",
- keyvalues={
- "user_id": user_id,
- "account_data_type": account_data_type,
- },
- values={
- "stream_id": next_id,
- "content": content_json,
- },
+ keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+ values={"stream_id": next_id, "content": content_json},
lock=False,
)
@@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
# transaction.
yield self._update_max_stream_id(next_id)
- self._account_data_stream_cache.entity_has_changed(
- user_id, next_id,
- )
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
- (account_data_type, user_id,)
+ (account_data_type, user_id)
)
result = self._account_data_id_gen.get_current_token()
@@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
Args:
next_id(int): The the revision to advance to.
"""
+
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
@@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.runInteraction(
- "update_account_data_max_stream_id",
- _update,
- )
+
+ return self.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 31248d5e06..6092f600ba 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
self.services_cache = load_appservices(
- hs.hostname,
- hs.config.app_service_config_files
+ hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
@@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
pass
-class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
- EventsWorkerStore):
+class ApplicationServiceTransactionWorkerStore(
+ ApplicationServiceWorkerStore, EventsWorkerStore
+):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
may be empty.
"""
results = yield self._simple_select_list(
- "application_services_state",
- dict(state=state),
- ["as_id"]
+ "application_services_state", dict(state=state), ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves when the state was set successfully.
"""
return self._simple_upsert(
- "application_services_state",
- dict(as_id=service.id),
- dict(state=state)
+ "application_services_state", dict(as_id=service.id), dict(state=state)
)
def create_appservice_txn(self, service, events):
@@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
Returns:
AppServiceTransaction: A new transaction.
"""
+
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
@@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,)
+ (service.id,),
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
@@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
- (service.id, new_txn_id, event_ids)
- )
- return AppServiceTransaction(
- service=service, id=new_txn_id, events=events
+ (service.id, new_txn_id, event_ids),
)
+ return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.runInteraction(
- "create_appservice_txn",
- _create_appservice_txn,
- )
+ return self.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s", last_txn_id, txn_id,
- service.id
+ "completing_txn=%s service_id=%s",
+ last_txn_id,
+ txn_id,
+ service.id,
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
- txn, "application_services_state", dict(as_id=service.id),
- dict(last_txn=txn_id)
+ txn,
+ "application_services_state",
+ dict(as_id=service.id),
+ dict(last_txn=txn_id),
)
# Delete txn
self._simple_delete_txn(
- txn, "application_services_txns",
- dict(txn_id=txn_id, as_id=service.id)
+ txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
)
- return self.runInteraction(
- "complete_appservice_txn",
- _complete_appservice_txn,
- )
+ return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves to an AppServiceTransaction or
None.
"""
+
def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
- (service.id,)
+ (service.id,),
)
rows = self.cursor_to_dict(txn)
if not rows:
@@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return entry
entry = yield self.runInteraction(
- "get_oldest_unsent_appservice_txn",
- _get_oldest_unsent_txn,
+ "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
if not entry:
@@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
events = yield self._get_events(event_ids)
- defer.returnValue(AppServiceTransaction(
- service=service, id=entry["txn_id"], events=events
- ))
+ defer.returnValue(
+ AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
+ )
def _get_last_txn(self, txn, service_id):
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service_id,)
+ (service_id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
@@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
+
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
- "get_new_events_for_appservice", get_new_events_for_appservice_txn,
+ "get_new_events_for_appservice", get_new_events_for_appservice_txn
)
events = yield self._get_events(event_ids)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index a2f8c23a65..b8b8273f73 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = False
def start_doing_background_updates(self):
- run_as_background_process(
- "background_updates", self._run_background_updates,
- )
+ run_as_background_process("background_updates", self._run_background_updates)
@defer.inlineCallbacks
def _run_background_updates(self):
logger.info("Starting background schema updates")
while True:
- yield self.hs.get_clock().sleep(
- self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
+ yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
result = yield self.do_next_background_update(
@@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
- logger.info("Starting update batch on background update '%s'",
- update_name)
+ logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = yield self._simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
- retcol="progress_json"
+ retcol="progress_json",
)
progress = json.loads(progress_json)
@@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
- update_name, items_updated, duration_ms,
+ update_name,
+ items_updated,
+ duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
@@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore):
Args:
update_name (str): Name of update
"""
+
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
@@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self.register_background_update_handler(update_name, noop_update)
- def register_background_index_update(self, update_name, index_name,
- table, columns, where_clause=None,
- unique=False,
- psql_only=False):
+ def register_background_index_update(
+ self,
+ update_name,
+ index_name,
+ table,
+ columns,
+ where_clause=None,
+ unique=False,
+ psql_only=False,
+ ):
"""Helper for store classes to do a background index addition
To use:
@@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore):
"name": index_name,
"table": table,
"columns": ", ".join(columns),
- "where_clause": "WHERE " + where_clause if where_clause else ""
+ "where_clause": "WHERE " + where_clause if where_clause else "",
}
logger.debug("[SQL] %s", sql)
c.execute(sql)
@@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore):
return self._simple_insert(
"background_updates",
- {"update_name": update_name, "progress_json": progress_json}
+ {"update_name": update_name, "progress_json": progress_json},
)
def _end_background_update(self, update_name):
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 9c21362226..bda68de5be 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen",
- keylen=4,
- max_entries=50000 * CACHE_SIZE_FACTOR,
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
@@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
self.register_background_update_handler(
- "user_ips_analyze",
- self._analyze_user_ip,
+ "user_ips_analyze", self._analyze_user_ip
)
self.register_background_update_handler(
- "user_ips_remove_dupes",
- self._remove_user_ip_dupes,
+ "user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
@@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Drop the old non-unique index
self.register_background_update_handler(
- "user_ips_drop_nonunique_index",
- self._remove_user_ip_nonunique,
+ "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
@@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
- txn.execute(
- "DROP INDEX IF EXISTS user_ips_user_ip"
- )
+ txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
yield self.runWithConnection(f)
@@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.runInteraction(
- "user_ips_analyze", user_ips_analyze
- )
+ yield self.runInteraction("user_ips_analyze", user_ips_analyze)
yield self._end_background_update("user_ips_analyze")
@@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
LIMIT 1
OFFSET ?
""",
- (begin_last_seen, batch_size)
+ (begin_last_seen, batch_size),
)
row = txn.fetchone()
if row:
@@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
logger.info(
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
- begin_last_seen, end_last_seen,
+ begin_last_seen,
+ end_last_seen,
)
def remove(txn):
@@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip
HAVING count(*) > 1
- """.format(clause),
- args
+ """.format(
+ clause
+ ),
+ args,
)
res = txn.fetchall()
@@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
""",
- (user_id, access_token, ip, last_seen)
+ (user_id, access_token, ip, last_seen),
)
if txn.rowcount == count - 1:
# We deleted all but one of the duplicate rows, i.e. there
@@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
continue
elif txn.rowcount >= count:
raise Exception(
- "We deleted more duplicate rows from 'user_ips' than expected",
+ "We deleted more duplicate rows from 'user_ips' than expected"
)
# The previous step didn't delete enough rows, so we fallback to
@@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ?
""",
- (user_id, access_token, ip)
+ (user_id, access_token, ip),
)
# Add in one to be the last_seen
@@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen)
VALUES (?, ?, ?, ?, ?, ?)
""",
- (user_id, access_token, ip, device_id, user_agent, last_seen)
+ (user_id, access_token, ip, device_id, user_agent, last_seen),
)
self._background_update_progress_txn(
@@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
defer.returnValue(batch_size)
@defer.inlineCallbacks
- def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
- now=None):
+ def insert_client_ip(
+ self, user_id, access_token, ip, user_agent, device_id, now=None
+ ):
if not now:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
@@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn,
- to_update,
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
- return run_as_background_process(
- "update_client_ips", update,
- )
+ return run_as_background_process("update_client_ips", update)
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
@@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
- user_id, device_id,
+ user_id,
+ device_id,
retcols=(
"user_id",
"access_token",
@@ -416,7 +409,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
bindings = []
if device_id is None:
where_clauses.append("user_id = ?")
- bindings.extend((user_id, ))
+ bindings.extend((user_id,))
else:
where_clauses.append("(user_id = ? AND device_id = ?)")
bindings.extend((user_id, device_id))
@@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s "
"GROUP BY user_id, device_id"
- ) % {
- "where": " OR ".join(where_clauses),
- }
+ ) % {"where": " OR ".join(where_clauses)}
sql = (
"SELECT %(retcols)s FROM user_ips "
@@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
rows = yield self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
- retcols=[
- "access_token", "ip", "user_agent", "last_seen"
- ],
+ retcols=["access_token", "ip", "user_agent", "last_seen"],
desc="get_user_ip_and_agents",
)
@@ -472,12 +461,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- defer.returnValue(list(
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
- ))
+ defer.returnValue(
+ list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ )
+ )
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index e6a42a53bb..fed4ea3610 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -57,9 +57,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC"
" LIMIT ?"
)
- txn.execute(sql, (
- user_id, device_id, last_stream_id, current_stream_id, limit
- ))
+ txn.execute(
+ sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+ )
messages = []
for row in txn:
stream_pos = row[0]
@@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return (messages, stream_pos)
return self.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn,
+ "get_new_messages_for_device", get_new_messages_for_device_txn
)
@defer.inlineCallbacks
@@ -146,9 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC"
" LIMIT ?"
)
- txn.execute(sql, (
- destination, last_stream_id, current_stream_id, limit
- ))
+ txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
messages = []
for row in txn:
stream_pos = row[0]
@@ -172,6 +170,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
Returns:
A deferred that resolves when the messages have been deleted.
"""
+
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
@@ -181,8 +180,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
- "delete_device_msgs_for_remote",
- delete_messages_for_remote_destination_txn
+ "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
@@ -200,8 +198,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
)
self.register_background_update_handler(
- self.DEVICE_INBOX_STREAM_ID,
- self._background_drop_index_device_inbox,
+ self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
# Map of (user_id, device_id) to the last stream_id that has been
@@ -214,8 +211,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
- remote_messages_by_destination):
+ def add_messages_to_device_inbox(
+ self, local_messages_by_user_then_device, remote_messages_by_destination
+ ):
"""Used to send messages from this server.
Args:
@@ -252,15 +250,10 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
- "add_messages_to_device_inbox",
- add_messages_txn,
- now_ms,
- stream_id,
+ "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
- self._device_inbox_stream_cache.entity_has_changed(
- user_id, stream_id
- )
+ self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
for destination in remote_messages_by_destination.keys():
self._device_federation_outbox_stream_cache.entity_has_changed(
destination, stream_id
@@ -277,7 +270,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn(
- txn, table="device_federation_inbox",
+ txn,
+ table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
retcols=("message_id",),
allow_none=True,
@@ -288,7 +282,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed
# it.
self._simple_insert_txn(
- txn, table="device_federation_inbox",
+ txn,
+ table="device_federation_inbox",
values={
"origin": origin,
"message_id": message_id,
@@ -311,19 +306,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
- self._device_inbox_stream_cache.entity_has_changed(
- user_id, stream_id
- )
+ self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
defer.returnValue(stream_id)
- def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
- messages_by_user_then_device):
- sql = (
- "UPDATE device_max_stream_id"
- " SET stream_id = ?"
- " WHERE stream_id < ?"
- )
+ def _add_messages_to_local_device_inbox_txn(
+ self, txn, stream_id, messages_by_user_then_device
+ ):
+ sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
@@ -332,10 +322,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = (
- "SELECT device_id FROM devices"
- " WHERE user_id = ?"
- )
+ sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
@@ -428,9 +415,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
- txn.execute(
- "DROP INDEX IF EXISTS device_inbox_stream_id"
- )
+ txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index e716dc1437..fd869b934c 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user"
+ desc="get_devices_by_user",
)
defer.returnValue({d["device_id"]: d for d in devices})
@@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore):
return (now_stream_id, [])
return self.runInteraction(
- "get_devices_by_remote", self._get_devices_by_remote_txn,
- destination, from_stream_id, now_stream_id,
+ "get_devices_by_remote",
+ self._get_devices_by_remote_txn,
+ destination,
+ from_stream_id,
+ now_stream_id,
)
- def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
- now_stream_id):
+ def _get_devices_by_remote_txn(
+ self, txn, destination, from_stream_id, now_stream_id
+ ):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
- txn.execute(
- sql, (destination, from_stream_id, now_stream_id, False)
- )
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
@@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+ txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
)
prev_sent_id_sql = """
@@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
- "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
- destination, stream_id,
+ "mark_as_sent_devices_by_remote",
+ self._mark_as_sent_devices_by_remote_txn,
+ destination,
+ stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
@@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
- txn.execute(sql, (destination, stream_id,))
+ txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
sql = """
@@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore):
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(
- sql, ((row[1], destination, row[0],) for row in rows if row[2])
- )
+ txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ sql, ((destination, row[0], row[1]) for row in rows if not row[2])
)
# Delete all sent outbound pokes
@@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
- txn.execute(sql, (destination, stream_id,))
+ txn.execute(sql, (destination, stream_id))
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
+ keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
desc="_get_cached_user_device",
)
@@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
- defer.returnValue({
- device["device_id"]: db_to_json(device["content"])
- for device in devices
- })
+ defer.returnValue(
+ {device["device_id"]: db_to_json(device["content"]) for device in devices}
+ )
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore):
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn, user_id,
+ self._get_devices_with_keys_by_user_txn,
+ user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
@@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
- result = {
- "device_id": device_id,
- }
+ result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
@@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
- rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+ rows = yield self._execute(
+ "get_user_whose_devices_changed", None, sql, from_key
+ )
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
@@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore):
GROUP BY user_id, destination
"""
return self._execute(
- "get_all_device_list_changes_for_remotes", None,
- sql, from_key, to_key
+ "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@cached(max_entries=10000)
@@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
- list_name="user_ids", inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids",
+ inlineCallbacks=True,
+ )
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
- retcols=("user_id", "stream_id",),
+ retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
- results.update({
- row["user_id"]: row["stream_id"] for row in rows
- })
+ results.update({row["user_id"]: row["stream_id"] for row in rows})
defer.returnValue(results)
@@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
- name="device_id_exists",
- keylen=2,
- max_entries=10000,
+ name="device_id_exists", keylen=2, max_entries=10000
)
- self._clock.looping_call(
- self._prune_old_outbound_device_pokes, 60 * 60 * 1000
- )
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
self.register_background_index_update(
"device_lists_stream_idx",
@@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def store_device(self, user_id, device_id,
- initial_device_display_name):
+ def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
Args:
@@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
values={
"user_id": user_id,
"device_id": device_id,
- "display_name": initial_device_display_name
+ "display_name": initial_device_display_name,
},
desc="store_device",
or_ignore=True,
@@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self.device_id_exists_cache.prefill(key, True)
defer.returnValue(inserted)
except Exception as e:
- logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
- " display_name=%s(%r) failed: %s",
- type(device_id).__name__, device_id,
- type(user_id).__name__, user_id,
- type(initial_device_display_name).__name__,
- initial_device_display_name, e)
+ logger.error(
+ "store_device with device_id=%s(%r) user_id=%s(%r)"
+ " display_name=%s(%r) failed: %s",
+ type(device_id).__name__,
+ device_id,
+ type(user_id).__name__,
+ user_id,
+ type(initial_device_display_name).__name__,
+ initial_device_display_name,
+ e,
+ )
raise StoreError(500, "Problem storing device.")
@defer.inlineCallbacks
@@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
yield self._simple_delete(
table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
)
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
- def update_remote_device_list_cache_entry(self, user_id, device_id, content,
- stream_id):
+ def update_remote_device_list_cache_entry(
+ self, user_id, device_id, content, stream_id
+ ):
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
- user_id, device_id, content, stream_id,
+ user_id,
+ device_id,
+ content,
+ stream_id,
)
- def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
- content, stream_id):
+ def _update_remote_device_list_cache_entry_txn(
+ self, txn, user_id, device_id, content, stream_id
+ ):
if content.get("deleted"):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
+ keyvalues={"user_id": user_id, "device_id": device_id},
)
- txn.call_after(
- self.device_id_exists_cache.invalidate, (user_id, device_id,)
- )
+ txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "content": json.dumps(content),
- },
-
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"content": json.dumps(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
)
- txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
@@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
- values={
- "stream_id": stream_id,
- },
-
+ keyvalues={"user_id": user_id},
+ values={"stream_id": stream_id},
# again, we can assume we are the only thread updating this user's
# extremity.
lock=False,
@@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
- user_id, devices, stream_id,
+ user_id,
+ devices,
+ stream_id,
)
- def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
- stream_id):
+ def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
self._simple_delete_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
+ txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
self._simple_insert_many_txn(
@@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"content": json.dumps(content),
}
for content in devices
- ]
+ ],
)
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
- values={
- "stream_id": stream_id,
- },
-
+ keyvalues={"user_id": user_id},
+ values={"stream_id": stream_id},
# we don't need to lock, because we can assume we are the only thread
# updating this user's extremity.
lock=False,
@@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
- "add_device_change_to_streams", self._add_device_change_txn,
- user_id, device_ids, hosts, stream_id,
+ "add_device_change_to_streams",
+ self._add_device_change_txn,
+ user_id,
+ device_ids,
+ hosts,
+ stream_id,
)
defer.returnValue(stream_id)
@@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
now = self._clock.time_msec()
txn.call_after(
- self._device_list_stream_cache.entity_has_changed,
- user_id, stream_id,
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_id
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
- host, stream_id,
+ host,
+ stream_id,
)
# Delete older entries in the table, as we really only care about
@@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids]
+ [(user_id, device_id, stream_id) for device_id in device_ids],
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
- {
- "stream_id": stream_id,
- "user_id": user_id,
- "device_id": device_id,
- }
+ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
for device_id in device_ids
- ]
+ ],
)
self._simple_insert_many_txn(
@@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
}
for destination in hosts
for device_id in device_ids
- ]
+ ],
)
def _prune_old_outbound_device_pokes(self):
@@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
txn.executemany(
- delete_sql,
- (
- (yesterday, row[0], row[1], row[2])
- for row in rows
- )
+ delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
)
# Since we've deleted unsent deltas, we need to remove the entry
@@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
- txn.execute(
- "DROP INDEX IF EXISTS device_lists_remote_cache_id"
- )
- txn.execute(
- "DROP INDEX IF EXISTS device_lists_remote_extremeties_id"
- )
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 61a029a53c..201bbd430c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -22,10 +22,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
-RoomAliasMapping = namedtuple(
- "RoomAliasMapping",
- ("room_id", "room_alias", "servers",)
-)
+RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
@@ -63,16 +60,12 @@ class DirectoryWorkerStore(SQLBaseStore):
defer.returnValue(None)
return
- defer.returnValue(
- RoomAliasMapping(room_id, room_alias.to_string(), servers)
- )
+ defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
- keyvalues={
- "room_alias": room_alias,
- },
+ keyvalues={"room_alias": room_alias},
retcol="creator",
desc="get_room_alias_creator",
)
@@ -101,6 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
Returns:
Deferred
"""
+
def alias_txn(txn):
self._simple_insert_txn(
txn,
@@ -115,10 +109,10 @@ class DirectoryStore(DirectoryWorkerStore):
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
- values=[{
- "room_alias": room_alias.to_string(),
- "server": server,
- } for server in servers],
+ values=[
+ {"room_alias": room_alias.to_string(), "server": server}
+ for server in servers
+ ],
)
self._invalidate_cache_and_stream(
@@ -126,9 +120,7 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.runInteraction(
- "create_room_alias_association", alias_txn
- )
+ ret = yield self.runInteraction("create_room_alias_association", alias_txn)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
@@ -138,9 +130,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
- "delete_room_alias",
- self._delete_room_alias_txn,
- room_alias,
+ "delete_room_alias", self._delete_room_alias_txn, room_alias
)
defer.returnValue(room_id)
@@ -148,7 +138,7 @@ class DirectoryStore(DirectoryWorkerStore):
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
- (room_alias.to_string(),)
+ (room_alias.to_string(),),
)
res = txn.fetchone()
@@ -158,31 +148,29 @@ class DirectoryStore(DirectoryWorkerStore):
return None
txn.execute(
- "DELETE FROM room_aliases WHERE room_alias = ?",
- (room_alias.to_string(),)
+ "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),)
)
txn.execute(
"DELETE FROM room_alias_servers WHERE room_alias = ?",
- (room_alias.to_string(),)
+ (room_alias.to_string(),),
)
- self._invalidate_cache_and_stream(
- txn, self.get_aliases_for_room, (room_id,)
- )
+ self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))
return room_id
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
- txn.execute(sql, (new_room_id, creator, old_room_id,))
+ txn.execute(sql, (new_room_id, creator, old_room_id))
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (new_room_id,)
)
+
return self.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 9a3aec759e..521936e3b0 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -23,7 +23,6 @@ from ._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
-
@defer.inlineCallbacks
def get_e2e_room_key(self, user_id, version, room_id, session_id):
"""Get the encrypted E2E room key for a given session from a given
@@ -97,9 +96,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_e2e_room_keys(
- self, user_id, version, room_id=None, session_id=None
- ):
+ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@@ -123,10 +120,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError:
defer.returnValue({'rooms': {}})
- keyvalues = {
- "user_id": user_id,
- "version": version,
- }
+ keyvalues = {"user_id": user_id, "version": version}
if room_id:
keyvalues['room_id'] = room_id
if session_id:
@@ -160,9 +154,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
defer.returnValue(sessions)
@defer.inlineCallbacks
- def delete_e2e_room_keys(
- self, user_id, version, room_id=None, session_id=None
- ):
+ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@@ -180,19 +172,14 @@ class EndToEndRoomKeyStore(SQLBaseStore):
A deferred of the deletion transaction
"""
- keyvalues = {
- "user_id": user_id,
- "version": int(version),
- }
+ keyvalues = {"user_id": user_id, "version": int(version)}
if room_id:
keyvalues['room_id'] = room_id
if session_id:
keyvalues['session_id'] = session_id
yield self._simple_delete(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- desc="delete_e2e_room_keys",
+ table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@staticmethod
@@ -200,7 +187,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
- (user_id,)
+ (user_id,),
)
row = txn.fetchone()
if not row:
@@ -238,24 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result = self._simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
- keyvalues={
- "user_id": user_id,
- "version": this_version,
- "deleted": 0,
- },
- retcols=(
- "version",
- "algorithm",
- "auth_data",
- ),
+ keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
+ retcols=("version", "algorithm", "auth_data"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
return result
return self.runInteraction(
- "get_e2e_room_keys_version_info",
- _get_e2e_room_keys_version_info_txn
+ "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
def create_e2e_room_keys_version(self, user_id, info):
@@ -273,7 +251,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _create_e2e_room_keys_version_txn(txn):
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
- (user_id,)
+ (user_id,),
)
current_version = txn.fetchone()[0]
if current_version is None:
@@ -309,14 +287,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update(
table="e2e_room_keys_versions",
- keyvalues={
- "user_id": user_id,
- "version": version,
- },
- updatevalues={
- "auth_data": json.dumps(info["auth_data"]),
- },
- desc="update_e2e_room_keys_version"
+ keyvalues={"user_id": user_id, "version": version},
+ updatevalues={"auth_data": json.dumps(info["auth_data"])},
+ desc="update_e2e_room_keys_version",
)
def delete_e2e_room_keys_version(self, user_id, version=None):
@@ -341,16 +314,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
- keyvalues={
- "user_id": user_id,
- "version": this_version,
- },
- updatevalues={
- "deleted": 1,
- }
+ keyvalues={"user_id": user_id, "version": this_version},
+ updatevalues={"deleted": 1},
)
return self.runInteraction(
- "delete_e2e_room_keys_version",
- _delete_e2e_room_keys_version_txn
+ "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index e381e472a2..2fabb9e2cb 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json
class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
- self, query_list, include_all_devices=False,
- include_deleted_devices=False,
+ self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
Args:
@@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue({})
results = yield self.runInteraction(
- "get_e2e_device_keys", self._get_e2e_device_keys_txn,
- query_list, include_all_devices, include_deleted_devices,
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
)
for user_id, device_keys in iteritems(results):
@@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue(results)
def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False,
- include_deleted_devices=False,
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
):
query_clauses = []
query_params = []
@@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
" WHERE %s"
) % (
"LEFT" if include_all_devices else "INNER",
- " OR ".join("(" + q + ")" for q in query_clauses)
+ " OR ".join("(" + q + ")" for q in query_clauses),
)
txn.execute(sql, query_params)
@@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json",),
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
)
- defer.returnValue({
- (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
- })
+ defer.returnValue(
+ {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ )
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
@@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self._simple_insert_many_txn(
- txn, table="e2e_one_time_keys_json",
+ txn,
+ table="e2e_one_time_keys_json",
values=[
{
"user_id": user_id,
@@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
],
)
self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
"""
+
def _count_e2e_one_time_keys(txn):
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
@@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for algorithm, key_count in txn:
result[algorithm] = key_count
return result
- return self.runInteraction(
- "count_e2e_one_time_keys", _count_e2e_one_time_keys
- )
+
+ return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
@@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
+
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
+ keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
@@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "ts_added_ms": time_now,
- "key_json": new_key_json,
- }
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
)
return True
- return self.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
- )
+ return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
+
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
@@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return result
- return self.runInteraction(
- "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
- )
+
+ return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
@@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+
return self.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index ff5ef97ca8..9d2d519922 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -20,10 +20,7 @@ from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {
- "sqlite3": Sqlite3Engine,
- "psycopg2": PostgresEngine,
-}
+SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
def create_engine(database_config):
@@ -32,15 +29,12 @@ def create_engine(database_config):
if engine_class:
# pypy requires psycopg2cffi rather than psycopg2
- if (name == "psycopg2" and
- platform.python_implementation() == "PyPy"):
+ if name == "psycopg2" and platform.python_implementation() == "PyPy":
name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
- raise RuntimeError(
- "Unsupported database engine '%s'" % (name,)
- )
+ raise RuntimeError("Unsupported database engine '%s'" % (name,))
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index dc3238501c..1b97ee74e3 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -23,7 +23,7 @@ class PostgresEngine(object):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
self.synchronous_commit = database_config.get("synchronous_commit", True)
- self._version = None # unknown as yet
+ self._version = None # unknown as yet
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@@ -31,8 +31,7 @@ class PostgresEngine(object):
if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information."
- % (rows[0][0],)
+ "See docs/postgres.rst for more information." % (rows[0][0],)
)
def convert_param_style(self, sql):
@@ -103,12 +102,6 @@ class PostgresEngine(object):
# https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
if numver >= 100000:
- return "%i.%i" % (
- numver / 10000, numver % 10000,
- )
+ return "%i.%i" % (numver / 10000, numver % 10000)
else:
- return "%i.%i.%i" % (
- numver / 10000,
- (numver % 10000) / 100,
- numver % 100,
- )
+ return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 1bcd5b99a4..933bcf42c2 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -82,9 +82,10 @@ class Sqlite3Engine(object):
# Following functions taken from: https://github.com/coleifer/peewee
+
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
@@ -98,7 +99,7 @@ def _rank(raw_match_info):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
col_idx = phrase_info_idx + (col_num * 3)
- x1, x2 = match_info[col_idx:col_idx + 2]
+ x1, x2 = match_info[col_idx : col_idx + 2]
if x1 > 0:
score += float(x1) / x2
return score
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index a8d90456e3..956f876572 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,8 +32,7 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
- SQLBaseStore):
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@@ -45,7 +44,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of events
"""
return self.get_auth_chain_ids(
- event_ids, include_given=include_given,
+ event_ids, include_given=include_given
).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False):
@@ -59,9 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of event_ids
"""
return self.runInteraction(
- "get_auth_chain_ids",
- self._get_auth_chain_ids_txn,
- event_ids, include_given
+ "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
)
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
@@ -70,23 +67,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
else:
results = set()
- base_sql = (
- "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
- )
+ base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
front = set(event_ids)
while front:
new_front = set()
front_list = list(front)
- chunks = [
- front_list[x:x + 100]
- for x in range(0, len(front), 100)
- ]
+ chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks:
- txn.execute(
- base_sql % (",".join(["?"] * len(chunk)),),
- chunk
- )
+ txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
new_front.update([r[0] for r in txn])
new_front -= results
@@ -98,9 +87,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(
- "get_oldest_events_in_room",
- self._get_oldest_events_in_room_txn,
- room_id,
+ "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
@@ -121,7 +108,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
" GROUP BY b.event_id"
)
- txn.execute(sql, (room_id, False,))
+ txn.execute(sql, (room_id, False))
return dict(txn)
@@ -152,9 +139,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return self._simple_select_onecol_txn(
txn,
table="event_backward_extremities",
- keyvalues={
- "room_id": room_id,
- },
+ keyvalues={"room_id": room_id},
retcol="event_id",
)
@@ -209,9 +194,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
table="event_forward_extremities",
- keyvalues={
- "room_id": room_id,
- },
+ keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
@@ -225,14 +208,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
"WHERE f.room_id = ?"
)
- txn.execute(sql, (room_id, ))
+ txn.execute(sql, (room_id,))
results = []
for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
- k: encode_base64(v) for k, v in hashes.items()
- if k == "sha256"
+ k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
}
results.append((event_id, prev_hashes, depth))
@@ -242,9 +224,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
""" For hte given room, get the minimum depth we have seen for it.
"""
return self.runInteraction(
- "get_min_depth",
- self._get_min_depth_interaction,
- room_id,
+ "get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
@@ -300,7 +280,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
- sql = ("""
+ sql = """
SELECT event_id FROM stream_ordering_to_exterm
INNER JOIN (
SELECT room_id, MAX(stream_ordering) AS stream_ordering
@@ -308,15 +288,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
WHERE stream_ordering <= ? GROUP BY room_id
) AS rms USING (room_id, stream_ordering)
WHERE room_id = ?
- """)
+ """
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.runInteraction(
- "get_forward_extremeties_for_room",
- get_forward_extremeties_for_room_txn
+ "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
def get_backfill_events(self, room_id, event_list, limit):
@@ -329,19 +308,21 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_list (list)
limit (int)
"""
- return self.runInteraction(
- "get_backfill_events",
- self._get_backfill_events, room_id, event_list, limit
- ).addCallback(
- self._get_events
- ).addCallback(
- lambda l: sorted(l, key=lambda e: -e.depth)
+ return (
+ self.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
+ )
+ .addCallback(self._get_events)
+ .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug(
- "_get_backfill_events: %s, %s, %s",
- room_id, repr(event_list), limit
+ "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
)
event_results = set()
@@ -364,10 +345,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
depth = self._simple_select_one_onecol_txn(
txn,
table="events",
- keyvalues={
- "event_id": event_id,
- "room_id": room_id,
- },
+ keyvalues={"event_id": event_id, "room_id": room_id},
retcol="depth",
allow_none=True,
)
@@ -386,10 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.add(event_id)
- txn.execute(
- query,
- (event_id, False, limit - len(event_results))
- )
+ txn.execute(query, (event_id, False, limit - len(event_results)))
for row in txn:
if row[1] not in event_results:
@@ -398,18 +373,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return event_results
@defer.inlineCallbacks
- def get_missing_events(self, room_id, earliest_events, latest_events,
- limit):
+ def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.runInteraction(
"get_missing_events",
self._get_missing_events,
- room_id, earliest_events, latest_events, limit,
+ room_id,
+ earliest_events,
+ latest_events,
+ limit,
)
events = yield self._get_events(ids)
defer.returnValue(events)
- def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
- limit):
+ def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
@@ -425,8 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
new_front = set()
for event_id in front:
txn.execute(
- query,
- (room_id, event_id, False, limit - len(event_results))
+ query, (room_id, event_id, False, limit - len(event_results))
)
new_results = set(t[0] for t in txn) - seen_events
@@ -457,12 +432,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
column="prev_event_id",
iterable=event_ids,
retcols=("event_id",),
- desc="get_successor_events"
+ desc="get_successor_events",
)
- defer.returnValue([
- row["event_id"] for row in rows
- ])
+ defer.returnValue([row["event_id"] for row in rows])
class EventFederationStore(EventFederationWorkerStore):
@@ -481,12 +454,11 @@ class EventFederationStore(EventFederationWorkerStore):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
- self.EVENT_AUTH_STATE_ONLY,
- self._background_delete_non_state_event_auth,
+ self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000,
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
@@ -498,12 +470,8 @@ class EventFederationStore(EventFederationWorkerStore):
self._simple_upsert_txn(
txn,
table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
+ keyvalues={"room_id": room_id},
+ values={"min_depth": depth},
)
def _handle_mult_prev_events(self, txn, events):
@@ -553,11 +521,15 @@ class EventFederationStore(EventFederationWorkerStore):
" )"
)
- txn.executemany(query, [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ])
+ txn.executemany(
+ query,
+ [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ ],
+ )
query = (
"DELETE FROM event_backward_extremities"
@@ -566,16 +538,17 @@ class EventFederationStore(EventFederationWorkerStore):
txn.executemany(
query,
[
- (ev.event_id, ev.room_id) for ev in events
+ (ev.event_id, ev.room_id)
+ for ev in events
if not ev.internal_metadata.is_outlier()
- ]
+ ],
)
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
- sql = ("""
+ sql = """
DELETE FROM stream_ordering_to_exterm
WHERE
room_id IN (
@@ -583,11 +556,11 @@ class EventFederationStore(EventFederationWorkerStore):
FROM stream_ordering_to_exterm
WHERE stream_ordering > ?
) AND stream_ordering < ?
- """)
+ """
txn.execute(
- sql,
- (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
)
+
return run_as_background_process(
"delete_old_forward_extrem_cache",
self.runInteraction,
@@ -597,9 +570,7 @@ class EventFederationStore(EventFederationWorkerStore):
def clean_room_for_join(self, room_id):
return self.runInteraction(
- "clean_room_for_join",
- self._clean_room_for_join_txn,
- room_id,
+ "clean_room_for_join", self._clean_room_for_join_txn, room_id
)
def _clean_room_for_join_txn(self, txn, room_id):
@@ -635,7 +606,7 @@ class EventFederationStore(EventFederationWorkerStore):
)
"""
- txn.execute(sql, (min_stream_id, max_stream_id,))
+ txn.execute(sql, (min_stream_id, max_stream_id))
new_progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 6840320641..a729f3e067 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
DEFAULT_HIGHLIGHT_ACTION = [
- "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
]
@@ -91,25 +93,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
+ self, room_id, user_id, last_read_event_id
):
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
- room_id, user_id, last_read_event_id
+ room_id,
+ user_id,
+ last_read_event_id,
)
defer.returnValue(ret)
- def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
- last_read_event_id):
+ def _get_unread_counts_by_receipt_txn(
+ self, txn, room_id, user_id, last_read_event_id
+ ):
sql = (
"SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
- txn.execute(
- sql, (room_id, last_read_event_id)
- )
+ txn.execute(sql, (room_id, last_read_event_id))
results = txn.fetchall()
if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0}
@@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone()
notify_count = row[0] if row else 0
- txn.execute("""
+ txn.execute(
+ """
SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """, (room_id, user_id, stream_ordering,))
+ """,
+ (room_id, user_id, stream_ordering),
+ )
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
@@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone()
highlight_count = row[0] if row else 0
- return {
- "notify_count": notify_count,
- "highlight_count": highlight_count,
- }
+ return {"notify_count": notify_count, "highlight_count": highlight_count}
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
+
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
@@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
- args = [
- user_id, user_id,
- min_stream_ordering, max_stream_ordering, limit,
- ]
+ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
+
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
- args = [
- user_id, user_id,
- min_stream_ordering, max_stream_ordering, limit,
- ]
+ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
+
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"room_id": row[1],
"stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]),
- } for row in after_read_receipt + no_read_receipt
+ }
+ for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
@@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
- args = [
- user_id, user_id,
- min_stream_ordering, max_stream_ordering, limit,
- ]
+ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
+
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
- args = [
- user_id, user_id,
- min_stream_ordering, max_stream_ordering, limit,
- ]
+ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
+
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]),
"received_ts": row[5],
- } for row in after_read_receipt + no_read_receipt
+ }
+ for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
@@ -386,6 +384,36 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
defer.returnValue(notifs[:limit])
+ def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+ """A fast check to see if there might be something to push for the
+ user since the given stream ordering. May return false positives.
+
+ Useful to know whether to bother starting a pusher on start up or not.
+
+ Args:
+ user_id (str)
+ min_stream_ordering (int)
+
+ Returns:
+ Deferred[bool]: True if there may be push to process, False if
+ there definitely isn't.
+ """
+
+ def _get_if_maybe_push_in_range_for_user_txn(txn):
+ sql = """
+ SELECT 1 FROM event_push_actions
+ WHERE user_id = ? AND stream_ordering > ?
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, min_stream_ordering))
+ return bool(txn.fetchone())
+
+ return self.runInteraction(
+ "get_if_maybe_push_in_range_for_user",
+ _get_if_maybe_push_in_range_for_user_txn,
+ )
+
def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
@@ -424,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?)
"""
- txn.executemany(sql, (
- _gen_entry(user_id, actions)
- for user_id, actions in iteritems(user_id_actions)
- ))
+ txn.executemany(
+ sql,
+ (
+ _gen_entry(user_id, actions)
+ for user_id, actions in iteritems(user_id_actions)
+ ),
+ )
return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
@@ -445,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
try:
res = yield self._simple_delete(
table="event_push_actions_staging",
- keyvalues={
- "event_id": event_id,
- },
+ keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
defer.returnValue(res)
@@ -456,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# another exception here really isn't helpful - there's nothing
# the caller can do about it. Just log the exception and move on.
logger.exception(
- "Error removing push actions after event persistence failure",
+ "Error removing push actions after event persistence failure"
)
def _find_stream_orderings_for_times(self):
@@ -473,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
)
logger.info(
- "Found stream ordering 1 month ago: it's %d",
- self.stream_ordering_month_ago
+ "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago
)
logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
)
logger.info(
- "Found stream ordering 1 day ago: it's %d",
- self.stream_ordering_day_ago
+ "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
def find_first_stream_ordering_after_ts(self, ts):
@@ -601,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
index_name="event_push_actions_highlights_index",
table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1"
+ where_clause="highlight=1",
)
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
- self._start_rotate_notifs, 30 * 60 * 1000,
+ self._start_rotate_notifs, 30 * 60 * 1000
)
- def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
- all_events_and_contexts):
+ def _set_push_actions_for_event_and_users_txn(
+ self, txn, events_and_contexts, all_events_and_contexts
+ ):
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@@ -637,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"""
if events_and_contexts:
- txn.executemany(sql, (
+ txn.executemany(
+ sql,
(
- event.room_id, event.internal_metadata.stream_ordering,
- event.depth, event.event_id,
- )
- for event, _ in events_and_contexts
- ))
+ (
+ event.room_id,
+ event.internal_metadata.stream_ordering,
+ event.depth,
+ event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ),
+ )
for event, _ in events_and_contexts:
user_ids = self._simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
- keyvalues={
- "event_id": event.event_id,
- },
+ keyvalues={"event_id": event.event_id},
retcol="user_id",
)
for uid in user_ids:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid,)
+ (event.room_id, uid),
)
# Now we delete the staging area for *all* events that were being
# persisted.
txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
- (
- (event.event_id,)
- for event, _ in all_events_and_contexts
- )
+ ((event.event_id,) for event, _ in all_events_and_contexts),
)
@defer.inlineCallbacks
- def get_push_actions_for_user(self, user_id, before=None, limit=50,
- only_highlight=False):
+ def get_push_actions_for_user(
+ self, user_id, before=None, limit=50, only_highlight=False
+ ):
def f(txn):
before_clause = ""
if before:
@@ -697,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC"
- " LIMIT ?"
- % (before_clause,)
+ " LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
- push_actions = yield self.runInteraction(
- "get_push_actions_for_user", f
- )
+ push_actions = yield self.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions)
@@ -723,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
+
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
defer.returnValue(result[0] if result else None)
@@ -731,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
+
+ result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
defer.returnValue(result[0] or 0)
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id,)
+ (room_id,),
)
txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
- (room_id, event_id)
+ (room_id, event_id),
)
- def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- stream_ordering):
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
"""
Purges old push actions for a user and room before a given
stream_ordering.
@@ -765,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id, )
+ (room_id, user_id),
)
# We need to join on the events table to get the received_ts for
@@ -781,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
- txn.execute("""
+ txn.execute(
+ """
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """, (room_id, user_id, stream_ordering))
+ """,
+ (room_id, user_id, stream_ordering),
+ )
def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs)
@@ -803,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
logger.info("Rotating notifications")
caught_up = yield self.runInteraction(
- "_rotate_notifs",
- self._rotate_notifs_txn
+ "_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
break
@@ -826,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# We don't to try and rotate millions of rows at once, so we cap the
# maximum stream ordering we'll rotate before.
- txn.execute("""
+ txn.execute(
+ """
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """, (old_rotate_stream_ordering, self._rotate_count))
+ """,
+ (old_rotate_stream_ordering, self._rotate_count),
+ )
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
@@ -874,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+ txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
rows = txn.fetchall()
logger.info("Rotating notifications, handling %d rows", len(rows))
@@ -892,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"notif_count": row[2],
"stream_ordering": row[3],
}
- for row in rows if row[4] is None
- ]
+ for row in rows
+ if row[4] is None
+ ],
)
txn.executemany(
@@ -901,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+ ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
)
txn.execute(
"DELETE FROM event_push_actions"
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
- (old_rotate_stream_ordering, rotate_to_stream_ordering,)
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
)
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
txn.execute(
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
- (rotate_to_stream_ordering,)
+ (rotate_to_stream_ordering,),
)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 428300ea0a..7a7f841c6c 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -30,7 +30,6 @@ from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -51,8 +50,11 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
-event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
- ["type", "origin_type", "origin_entity"])
+event_counter = Counter(
+ "synapse_storage_events_persisted_events_sep",
+ "",
+ ["type", "origin_type", "origin_entity"],
+)
# The number of times we are recalculating the current state
state_delta_counter = Counter("synapse_storage_events_state_delta", "")
@@ -60,13 +62,15 @@ state_delta_counter = Counter("synapse_storage_events_state_delta", "")
# The number of times we are recalculating state when there is only a
# single forward extremity
state_delta_single_event_counter = Counter(
- "synapse_storage_events_state_delta_single_event", "")
+ "synapse_storage_events_state_delta_single_event", ""
+)
# The number of times we are reculating state when we could have resonably
# calculated the delta when we calculated the state for an event we were
# persisting.
state_delta_reuse_delta_counter = Counter(
- "synapse_storage_events_state_delta_reuse_delta", "")
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
def encode_json(json_object):
@@ -75,7 +79,7 @@ def encode_json(json_object):
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
- out = out.decode('utf8')
+ out = out.decode("utf8")
return out
@@ -84,9 +88,9 @@ class _EventPeristenceQueue(object):
concurrent transaction per room.
"""
- _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
- "events_and_contexts", "backfilled", "deferred",
- ))
+ _EventPersistQueueItem = namedtuple(
+ "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+ )
def __init__(self):
self._event_persist_queues = {}
@@ -119,11 +123,13 @@ class _EventPeristenceQueue(object):
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
- queue.append(self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
- deferred=deferred,
- ))
+ queue.append(
+ self._EventPersistQueueItem(
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ )
return deferred.observe()
@@ -191,6 +197,7 @@ def _retry_on_integrity_error(func):
Args:
func: function that returns a Deferred and accepts a `delete_existing` arg
"""
+
@wraps(func)
@defer.inlineCallbacks
def f(self, *args, **kwargs):
@@ -206,8 +213,12 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
- BackgroundUpdateStore):
+class EventsStore(
+ StateGroupWorkerStore,
+ EventFederationStore,
+ EventsWorkerStore,
+ BackgroundUpdateStore,
+):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@@ -265,8 +276,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
deferreds = []
for room_id, evs_ctxs in iteritems(partitioned):
d = self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs,
- backfilled=backfilled,
+ room_id, evs_ctxs, backfilled=backfilled
)
deferreds.append(d)
@@ -296,8 +306,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
and the stream ordering of the latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)],
- backfilled=backfilled,
+ event.room_id, [(event, context)], backfilled=backfilled
)
self._maybe_start_persisting(event.room_id)
@@ -312,16 +321,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
def persisting_queue(item):
with Measure(self._clock, "persist_events"):
yield self._persist_events(
- item.events_and_contexts,
- backfilled=item.backfilled,
+ item.events_and_contexts, backfilled=item.backfilled
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(self, events_and_contexts, backfilled=False,
- delete_existing=False):
+ def _persist_events(
+ self, events_and_contexts, backfilled=False, delete_existing=False
+ ):
"""Persist events to db
Args:
@@ -345,13 +354,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
)
with stream_ordering_manager as stream_orderings:
- for (event, context), stream, in zip(
- events_and_contexts, stream_orderings
- ):
+ for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
chunks = [
- events_and_contexts[x:x + 100]
+ events_and_contexts[x : x + 100]
for x in range(0, len(events_and_contexts), 100)
]
@@ -445,12 +452,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
state_delta_reuse_delta_counter.inc()
break
- logger.info(
- "Calculating state delta for room %s", room_id,
- )
+ logger.info("Calculating state delta for room %s", room_id)
with Measure(
- self._clock,
- "persist_events.get_new_state_after_events",
+ self._clock, "persist_events.get_new_state_after_events"
):
res = yield self._get_new_state_after_events(
room_id,
@@ -470,11 +474,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
state_delta_for_room[room_id] = ([], delta_ids)
elif current_state is not None:
with Measure(
- self._clock,
- "persist_events.calculate_state_delta",
+ self._clock, "persist_events.calculate_state_delta"
):
delta = yield self._calculate_state_delta(
- room_id, current_state,
+ room_id, current_state
)
state_delta_for_room[room_id] = delta
@@ -498,7 +501,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering,
+ chunk[-1][0].internal_metadata.stream_ordering
)
for event, context in chunk:
@@ -515,9 +518,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
event_counter.labels(event.type, origin_type, origin_entity).inc()
for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill(
- (room_id, ), new_state
- )
+ self.get_current_state_ids.prefill((room_id,), new_state)
for room_id, latest_event_ids in iteritems(new_forward_extremeties):
self.get_latest_event_ids_in_room.prefill(
@@ -535,8 +536,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# we're only interested in new events which aren't outliers and which aren't
# being rejected.
new_events = [
- event for event, ctx in event_contexts
- if not event.internal_metadata.is_outlier() and not ctx.rejected
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
and not event.internal_metadata.is_soft_failed()
]
@@ -544,15 +547,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
result = set(latest_event_ids)
# add all the new events to the list
- result.update(
- event.event_id for event in new_events
- )
+ result.update(event.event_id for event in new_events)
# Now remove all events which are prev_events of any of the new events
result.difference_update(
- e_id
- for event in new_events
- for e_id in event.prev_event_ids()
+ e_id for event in new_events for e_id in event.prev_event_ids()
)
# Finally, remove any events which are prev_events of any existing events.
@@ -592,17 +591,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
results.extend(r[0] for r in txn)
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
- "_get_events_which_are_prevs",
- _get_events,
- chunk,
- )
+ yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
defer.returnValue(results)
@defer.inlineCallbacks
- def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
- new_latest_event_ids):
+ def _get_new_state_after_events(
+ self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+ ):
"""Calculate the current state dict after adding some new events to
a room
@@ -642,7 +638,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
if not ev.internal_metadata.is_outlier():
raise Exception(
"Context for new event %s has no state "
- "group" % (ev.event_id, ),
+ "group" % (ev.event_id,)
)
continue
@@ -682,9 +678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
if missing_event_ids:
# Now pull out the state groups for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(
- missing_event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
event_id_to_state_group.update(event_to_groups)
# State groups of old_latest_event_ids
@@ -710,9 +704,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
new_state_group = next(iter(new_state_groups))
old_state_group = next(iter(old_state_groups))
- delta_ids = state_group_deltas.get(
- (old_state_group, new_state_group,), None
- )
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
if delta_ids is not None:
# We have a delta from the existing to new current state,
# so lets just return that. If we happen to already have
@@ -735,9 +727,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Ok, we need to defer to the state handler to resolve our state sets.
- state_groups = {
- sg: state_groups_map[sg] for sg in new_state_groups
- }
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -755,8 +745,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
- room_id, room_version, state_groups, events_map,
- state_res_store=StateResolutionStore(self)
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self),
)
defer.returnValue((res.state, None))
@@ -774,22 +767,26 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"""
existing_state = yield self.get_current_state_ids(room_id)
- to_delete = [
- key for key in existing_state
- if key not in current_state
- ]
+ to_delete = [key for key in existing_state if key not in current_state]
to_insert = {
- key: ev_id for key, ev_id in iteritems(current_state)
+ key: ev_id
+ for key, ev_id in iteritems(current_state)
if ev_id != existing_state.get(key)
}
defer.returnValue((to_delete, to_insert))
@log_function
- def _persist_events_txn(self, txn, events_and_contexts, backfilled,
- delete_existing=False, state_delta_for_room={},
- new_forward_extremeties={}):
+ def _persist_events_txn(
+ self,
+ txn,
+ events_and_contexts,
+ backfilled,
+ delete_existing=False,
+ state_delta_for_room={},
+ new_forward_extremeties={},
+ ):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
@@ -816,9 +813,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"""
all_events_and_contexts = events_and_contexts
+ min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
self._update_forward_extremities_txn(
txn,
@@ -828,20 +826,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Ensure that we don't have the same event twice.
events_and_contexts = self._filter_events_and_contexts_for_duplicates(
- events_and_contexts,
+ events_and_contexts
)
self._update_room_depths_txn(
- txn,
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
+ txn, events_and_contexts=events_and_contexts, backfilled=backfilled
)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
events_and_contexts = self._update_outliers_txn(
- txn,
- events_and_contexts=events_and_contexts,
+ txn, events_and_contexts=events_and_contexts
)
# From this point onwards the events are only events that we haven't
@@ -852,15 +847,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# for these events so we can reinsert them.
# This gets around any problems with some tables already having
# entries.
- self._delete_existing_rows_txn(
- txn,
- events_and_contexts=events_and_contexts,
- )
+ self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
- self._store_event_txn(
- txn,
- events_and_contexts=events_and_contexts,
- )
+ self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
@@ -889,8 +878,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
events_and_contexts = self._store_rejected_events_txn(
- txn,
- events_and_contexts=events_and_contexts,
+ txn, events_and_contexts=events_and_contexts
)
# From this point onwards the events are only ones that weren't
@@ -903,7 +891,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
+ def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
@@ -912,6 +900,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# that we can use it to calculate the `prev_event_id`. (This
# allows us to not have to pull out the existing state
# unnecessarily).
+ #
+ # The stream_id for the update is chosen to be the minimum of the stream_ids
+ # for the batch of the events that we are persisting; that means we do not
+ # end up in a situation where workers see events before the
+ # current_state_delta updates.
+ #
sql = """
INSERT INTO current_state_delta_stream
(stream_id, room_id, type, state_key, event_id, prev_event_id)
@@ -920,22 +914,40 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(sql, (
+ txn.executemany(
+ sql,
(
- max_stream_order, room_id, etype, state_key, None,
- room_id, etype, state_key,
- )
- for etype, state_key in to_delete
- # We sanity check that we're deleting rather than updating
- if (etype, state_key) not in to_insert
- ))
- txn.executemany(sql, (
+ (
+ stream_id,
+ room_id,
+ etype,
+ state_key,
+ None,
+ room_id,
+ etype,
+ state_key,
+ )
+ for etype, state_key in to_delete
+ # We sanity check that we're deleting rather than updating
+ if (etype, state_key) not in to_insert
+ ),
+ )
+ txn.executemany(
+ sql,
(
- max_stream_order, room_id, etype, state_key, ev_id,
- room_id, etype, state_key,
- )
- for (etype, state_key), ev_id in iteritems(to_insert)
- ))
+ (
+ stream_id,
+ room_id,
+ etype,
+ state_key,
+ ev_id,
+ room_id,
+ etype,
+ state_key,
+ )
+ for (etype, state_key), ev_id in iteritems(to_insert)
+ ),
+ )
# Now we actually update the current_state_events table
@@ -964,7 +976,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
- room_id, max_stream_order,
+ room_id,
+ stream_id,
)
# Invalidate the various caches
@@ -980,28 +993,27 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
if ev_type == EventTypes.Member
)
+ for member in members_changed:
+ txn.call_after(
+ self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+ )
+
self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
- def _update_forward_extremities_txn(self, txn, new_forward_extremities,
- max_stream_order):
+ def _update_forward_extremities_txn(
+ self, txn, new_forward_extremities, max_stream_order
+ ):
for room_id, new_extrem in iteritems(new_forward_extremities):
self._simple_delete_txn(
- txn,
- table="event_forward_extremities",
- keyvalues={"room_id": room_id},
- )
- txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, (room_id,)
+ txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
+ txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
self._simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- }
+ {"event_id": ev_id, "room_id": room_id}
for room_id, new_extrem in iteritems(new_forward_extremities)
for ev_id in new_extrem
],
@@ -1021,7 +1033,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
}
for room_id, new_extrem in iteritems(new_forward_extremities)
for event_id in new_extrem
- ]
+ ],
)
@classmethod
@@ -1065,7 +1077,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
- event.room_id, event.internal_metadata.stream_ordering,
+ event.room_id,
+ event.internal_metadata.stream_ordering,
)
if not event.internal_metadata.is_outlier() and not context.rejected:
@@ -1092,16 +1105,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
are already in the events table.
"""
txn.execute(
- "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
- ",".join(["?"] * len(events_and_contexts)),
- ),
- [event.event_id for event, _ in events_and_contexts]
+ "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
+ % (",".join(["?"] * len(events_and_contexts)),),
+ [event.event_id for event, _ in events_and_contexts],
)
- have_persisted = {
- event_id: outlier
- for event_id, outlier in txn
- }
+ have_persisted = {event_id: outlier for event_id, outlier in txn}
to_remove = set()
for event, context in events_and_contexts:
@@ -1128,18 +1137,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
logger.exception("")
raise
- metadata_json = encode_json(
- event.internal_metadata.get_dict()
- )
+ metadata_json = encode_json(event.internal_metadata.get_dict())
sql = (
- "UPDATE event_json SET internal_metadata = ?"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (metadata_json, event.event_id,)
+ "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
)
+ txn.execute(sql, (metadata_json, event.event_id))
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
@@ -1152,25 +1155,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
- }
+ },
)
- sql = (
- "UPDATE events SET outlier = ?"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (False, event.event_id,)
- )
+ sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+ txn.execute(sql, (False, event.event_id))
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_backward_extremeties(txn, [event])
- return [
- ec for ec in events_and_contexts if ec[0] not in to_remove
- ]
+ return [ec for ec in events_and_contexts if ec[0] not in to_remove]
@classmethod
def _delete_existing_rows_txn(cls, txn, events_and_contexts):
@@ -1181,39 +1176,33 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
logger.info("Deleting existing")
for table in (
- "events",
- "event_auth",
- "event_json",
- "event_content_hashes",
- "event_destinations",
- "event_edge_hashes",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "event_signatures",
- "event_to_state_groups",
- "guest_access",
- "history_visibility",
- "local_invites",
- "room_names",
- "state_events",
- "rejections",
- "redactions",
- "room_memberships",
- "topics"
+ "events",
+ "event_auth",
+ "event_json",
+ "event_edges",
+ "event_forward_extremities",
+ "event_reference_hashes",
+ "event_search",
+ "event_to_state_groups",
+ "guest_access",
+ "history_visibility",
+ "local_invites",
+ "room_names",
+ "state_events",
+ "rejections",
+ "redactions",
+ "room_memberships",
+ "topics",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts]
+ [(ev.event_id,) for ev, _ in events_and_contexts],
)
- for table in (
- "event_push_actions",
- ):
+ for table in ("event_push_actions",):
txn.executemany(
"DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
- [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
+ [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
)
def _store_event_txn(self, txn, events_and_contexts):
@@ -1296,17 +1285,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
- self._store_rejections_txn(
- txn, event.event_id, context.rejected
- )
+ self._store_rejections_txn(txn, event.event_id, context.rejected)
to_remove.add(event)
- return [
- ec for ec in events_and_contexts if ec[0] not in to_remove
- ]
+ return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- def _update_metadata_tables_txn(self, txn, events_and_contexts,
- all_events_and_contexts, backfilled):
+ def _update_metadata_tables_txn(
+ self, txn, events_and_contexts, all_events_and_contexts, backfilled
+ ):
"""Update all the miscellaneous tables for new events
Args:
@@ -1342,8 +1328,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
- txn,
- events=[event for event, _ in events_and_contexts],
+ txn, events=[event for event, _ in events_and_contexts]
)
for event, _ in events_and_contexts:
@@ -1401,11 +1386,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
state_values.append(vals)
- self._simple_insert_many_txn(
- txn,
- table="state_events",
- values=state_values,
- )
+ self._simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1416,10 +1397,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
rows = []
N = 200
for i in range(0, len(events_and_contexts), N):
- ev_map = {
- e[0].event_id: e[0]
- for e in events_and_contexts[i:i + N]
- }
+ ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
if not ev_map:
break
@@ -1439,14 +1417,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(
- event=event,
- redacted_event=None,
- ))
+ to_prefill.append(
+ _EventCacheEntry(event=event, redacted_event=None)
+ )
def prefill():
for cache_entry in to_prefill:
self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+
txn.call_after(prefill)
def _store_redaction(self, txn, event):
@@ -1454,7 +1432,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
txn.call_after(self._invalidate_get_event_cache, event.redacts)
txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts)
+ (event.event_id, event.redacts),
)
@defer.inlineCallbacks
@@ -1465,6 +1443,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
+
def _count_messages(txn):
sql = """
SELECT COALESCE(COUNT(*), 0) FROM events
@@ -1492,7 +1471,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
AND stream_ordering > ?
"""
- txn.execute(sql, (like_clause, self.stream_ordering_day_ago,))
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
count, = txn.fetchone()
return count
@@ -1557,18 +1536,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
update_rows.append((sender, contains_url, event_id))
- sql = (
- "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- )
+ sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+ clump = update_rows[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows)
+ "rows_inserted": rows_inserted + len(rows),
}
self._background_update_progress_txn(
@@ -1613,10 +1590,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
rows_to_update = []
- chunks = [
- event_ids[i:i + 100]
- for i in range(0, len(event_ids), 100)
- ]
+ chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
txn,
@@ -1639,18 +1613,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
rows_to_update.append((origin_server_ts, event_id))
- sql = (
- "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- )
+ sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
+ clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows_to_update)
+ "rows_inserted": rows_inserted + len(rows_to_update),
}
self._background_update_progress_txn(
@@ -1714,6 +1686,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
new_event_updates.extend(txn)
return new_event_updates
+
return self.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@@ -1756,13 +1729,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
new_event_updates.extend(txn.fetchall())
return new_event_updates
+
return self.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@cached(num_args=5, max_entries=10)
- def get_all_new_events(self, last_backfill_id, last_forward_id,
- current_backfill_id, current_forward_id, limit):
+ def get_all_new_events(
+ self,
+ last_backfill_id,
+ last_forward_id,
+ current_backfill_id,
+ current_forward_id,
+ limit,
+ ):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
have_backfill_events = last_backfill_id != current_backfill_id
@@ -1837,14 +1817,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
backward_ex_outliers = []
return AllNewEventsResult(
- new_forward_events, new_backfill_events,
- forward_ex_outliers, backward_ex_outliers,
+ new_forward_events,
+ new_backfill_events,
+ forward_ex_outliers,
+ backward_ex_outliers,
)
+
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
- def purge_history(
- self, room_id, token, delete_local_events,
- ):
+ def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
Args:
@@ -1860,28 +1841,24 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
return self.runInteraction(
"purge_history",
- self._purge_history_txn, room_id, token,
+ self._purge_history_txn,
+ room_id,
+ token,
delete_local_events,
)
- def _purge_history_txn(
- self, txn, room_id, token_str, delete_local_events,
- ):
+ def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
token = RoomStreamToken.parse(token_str)
# Tables that should be pruned:
# event_auth
# event_backward_extremities
- # event_content_hashes
- # event_destinations
- # event_edge_hashes
# event_edges
# event_forward_extremities
# event_json
# event_push_actions
# event_reference_hashes
# event_search
- # event_signatures
# event_to_state_groups
# events
# rejections
@@ -1913,7 +1890,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"ON e.event_id = f.event_id "
"AND e.room_id = f.room_id "
"WHERE f.room_id = ?",
- (room_id,)
+ (room_id,),
)
rows = txn.fetchall()
max_depth = max(row[1] for row in rows)
@@ -1934,10 +1911,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
should_delete_expr += " AND event_id NOT LIKE ?"
# We include the parameter twice since we use the expression twice
- should_delete_params += (
- "%:" + self.hs.hostname,
- "%:" + self.hs.hostname,
- )
+ should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
should_delete_params += (room_id, token.topological)
@@ -1948,10 +1922,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
" SELECT event_id, %s"
" FROM events AS e LEFT JOIN state_events USING (event_id)"
" WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
- % (
- should_delete_expr,
- should_delete_expr,
- ),
+ % (should_delete_expr, should_delete_expr),
should_delete_params,
)
@@ -1961,23 +1932,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# the should_delete / shouldn't_delete subsets
txn.execute(
"CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)",
+ " ON events_to_purge(should_delete)"
)
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
- txn.execute(
- "CREATE INDEX events_to_purge_id"
- " ON events_to_purge(event_id)",
- )
+ txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
- txn.execute(
- "SELECT event_id, should_delete FROM events_to_purge"
- )
+ txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
logger.info(
"[purge] found %i events before cutoff, of which %i can be deleted",
- len(event_rows), sum(1 for e in event_rows if e[1]),
+ len(event_rows),
+ sum(1 for e in event_rows if e[1]),
)
logger.info("[purge] Finding new backward extremities")
@@ -1989,24 +1956,21 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"SELECT DISTINCT e.event_id FROM events_to_purge AS e"
" INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
" LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
- " WHERE ep2.event_id IS NULL",
+ " WHERE ep2.event_id IS NULL"
)
new_backwards_extrems = txn.fetchall()
logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute(
- "DELETE FROM event_backward_extremities WHERE room_id = ?",
- (room_id,)
+ "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
)
# Update backward extremeties
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
- [
- (room_id, event_id) for event_id, in new_backwards_extrems
- ]
+ [(room_id, event_id) for event_id, in new_backwards_extrems],
)
logger.info("[purge] finding redundant state groups")
@@ -2014,28 +1978,25 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Get all state groups that are referenced by events that are to be
# deleted. We then go and check if they are referenced by other events
# or state groups, and if not we delete them.
- txn.execute("""
+ txn.execute(
+ """
SELECT DISTINCT state_group FROM events_to_purge
INNER JOIN event_to_state_groups USING (event_id)
- """)
+ """
+ )
referenced_state_groups = set(sg for sg, in txn)
logger.info(
- "[purge] found %i referenced state groups",
- len(referenced_state_groups),
+ "[purge] found %i referenced state groups", len(referenced_state_groups)
)
logger.info("[purge] finding state groups that can be deleted")
- state_groups_to_delete, remaining_state_groups = (
- self._find_unreferenced_groups_during_purge(
- txn, referenced_state_groups,
- )
- )
+ _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
+ state_groups_to_delete, remaining_state_groups = _
logger.info(
- "[purge] found %i state groups to delete",
- len(state_groups_to_delete),
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
)
logger.info(
@@ -2047,25 +2008,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(
- txn, [sg],
- )
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
self._simple_delete_txn(
- txn,
- table="state_groups_state",
- keyvalues={
- "state_group": sg,
- }
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
)
self._simple_delete_txn(
- txn,
- table="state_group_edges",
- keyvalues={
- "state_group": sg,
- }
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
)
self._simple_insert_many_txn(
@@ -2099,23 +2050,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"WHERE event_id IN (SELECT event_id from events_to_purge)"
)
for event_id, _ in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (
- event_id,
- ))
+ txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
# Delete all remote non-state events
for table in (
"events",
"event_json",
"event_auth",
- "event_content_hashes",
- "event_destinations",
- "event_edge_hashes",
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
"event_search",
- "event_signatures",
"rejections",
):
logger.info("[purge] removing events from %s", table)
@@ -2123,21 +2068,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
txn.execute(
"DELETE FROM %s WHERE event_id IN ("
" SELECT event_id FROM events_to_purge WHERE should_delete"
- ")" % (table,),
+ ")" % (table,)
)
# event_push_actions lacks an index on event_id, and has one on
# (room_id, event_id) instead.
- for table in (
- "event_push_actions",
- ):
+ for table in ("event_push_actions",):
logger.info("[purge] removing events from %s", table)
txn.execute(
"DELETE FROM %s WHERE room_id = ? AND event_id IN ("
" SELECT event_id FROM events_to_purge WHERE should_delete"
")" % (table,),
- (room_id, )
+ (room_id,),
)
# Mark all state and own events as outliers
@@ -2162,27 +2105,28 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# extremities. However, the events in event_backward_extremities
# are ones we don't have yet so we need to look at the events that
# point to it via event_edges table.
- txn.execute("""
+ txn.execute(
+ """
SELECT COALESCE(MIN(depth), 0)
FROM event_backward_extremities AS eb
INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
INNER JOIN events AS e ON e.event_id = eg.event_id
WHERE eb.room_id = ?
- """, (room_id,))
+ """,
+ (room_id,),
+ )
min_depth, = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (min_depth, room_id,)
+ (min_depth, room_id),
)
# finally, drop the temp table. this will commit the txn in sqlite,
# so make sure to keep this actually last.
- txn.execute(
- "DROP TABLE events_to_purge"
- )
+ txn.execute("DROP TABLE events_to_purge")
logger.info("[purge] done")
@@ -2226,7 +2170,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id)
WHERE state_group IN (%s) AND ep.event_id IS NULL
- """ % (",".join("?" for _ in current_search),)
+ """ % (
+ ",".join("?" for _ in current_search),
+ )
txn.execute(sql, list(current_search))
referenced = set(sg for sg, in txn)
@@ -2242,7 +2188,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
column="prev_state_group",
iterable=current_search,
keyvalues={},
- retcols=("prev_state_group", "state_group",),
+ retcols=("prev_state_group", "state_group"),
)
prevs = set(row["state_group"] for row in rows)
@@ -2279,16 +2225,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
- allow_none=True
+ allow_none=True,
)
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
-
- def get_max_current_state_delta_stream_id(self):
- return self._stream_id_gen.get_current_token()
+ defer.returnValue(
+ (int(res["topological_ordering"]), int(res["stream_ordering"]))
+ )
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
@@ -2300,13 +2245,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
+
return self.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
-AllNewEventsResult = namedtuple("AllNewEventsResult", [
- "new_forward_events", "new_backfill_events",
- "forward_ex_outliers", "backward_ex_outliers",
-])
+AllNewEventsResult = namedtuple(
+ "AllNewEventsResult",
+ [
+ "new_forward_events",
+ "new_backfill_events",
+ "forward_ex_outliers",
+ "backward_ex_outliers",
+ ],
+)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 1716be529a..663991a9b6 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -21,8 +21,9 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.api.constants import EventFormatVersions, EventTypes
+from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
+from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
# these are only included to make the type annotations work
from synapse.events.snapshot import EventContext # noqa: F401
@@ -70,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore):
"""
return self._simple_select_one_onecol(
table="events",
- keyvalues={
- "event_id": event_id,
- },
+ keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
@defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False, check_room_id=None):
+ def get_event(
+ self,
+ event_id,
+ check_redacted=True,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=False,
+ check_room_id=None,
+ ):
"""Get an event from the database by event_id.
Args:
@@ -117,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(event)
@defer.inlineCallbacks
- def get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
+ def get_events(
+ self,
+ event_ids,
+ check_redacted=True,
+ get_prev_content=False,
+ allow_rejected=False,
+ ):
"""Get events from the database
Args:
@@ -142,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
+ def _get_events(
+ self,
+ event_ids,
+ check_redacted=True,
+ get_prev_content=False,
+ allow_rejected=False,
+ ):
if not event_ids:
defer.returnValue([])
@@ -151,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache(
- event_ids,
- allow_rejected=allow_rejected,
+ event_ids, allow_rejected=allow_rejected
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
@@ -168,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore):
#
# _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
- missing_events_ids,
- allow_rejected=allow_rejected,
+ missing_events_ids, allow_rejected=allow_rejected
)
event_entry_map.update(missing_events)
@@ -213,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore):
)
expected_domain = get_domain_from_id(entry.event.sender)
- if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
+ if (
+ orig_sender
+ and get_domain_from_id(orig_sender) == expected_domain
+ ):
# This redaction event is allowed. Mark as not needing a
# recheck.
entry.event.internal_metadata.recheck_redaction = False
@@ -266,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
ret = self._get_event_cache.get(
- (event_id,), None,
- update_metrics=update_metrics,
+ (event_id,), None, update_metrics=update_metrics
)
if not ret:
continue
@@ -317,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore):
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = list(zip(*event_list))[0]
- event_ids = [
- item for sublist in event_id_lists for item in sublist
- ]
+ event_ids = [item for sublist in event_id_lists for item in sublist]
rows = self._new_transaction(
- conn, "do_fetch", [], [],
- self._fetch_event_rows, event_ids,
+ conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
)
- row_dict = {
- r["event_id"]: r
- for r in rows
- }
+ row_dict = {r["event_id"]: r for r in rows}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
@@ -337,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore):
if not d.called:
try:
with PreserveLoggingContext():
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
+ d.callback([res[i] for i in ids if i in res])
except Exception:
logger.exception("Failed to callback")
+
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
@@ -370,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore):
events_d = defer.Deferred()
with self._event_fetch_lock:
- self._event_fetch_list.append(
- (events, events_d)
- )
+ self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
@@ -384,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
- "fetch_events",
- self.runWithConnection,
- self._do_fetch,
+ "fetch_events", self.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events", len(events))
@@ -397,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
- res = yield make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self._get_event_from_row,
- row["internal_metadata"], row["json"], row["redacts"],
- rejected_reason=row["rejects"],
- format_version=row["format_version"],
- )
- for row in rows
- ],
- consumeErrors=True
- ))
+ res = yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._get_event_from_row,
+ row["internal_metadata"],
+ row["json"],
+ row["redacts"],
+ rejected_reason=row["rejects"],
+ format_version=row["format_version"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True,
+ )
+ )
- defer.returnValue({
- e.event.event_id: e
- for e in res if e
- })
+ defer.returnValue({e.event.event_id: e for e in res if e})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) // N):
- evs = events[i * N:(i + 1) * N]
+ evs = events[i * N : (i + 1) * N]
if not evs:
break
@@ -443,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore):
return rows
@defer.inlineCallbacks
- def _get_event_from_row(self, internal_metadata, js, redacted,
- format_version, rejected_reason=None):
+ def _get_event_from_row(
+ self, internal_metadata, js, redacted, format_version, rejected_reason=None
+ ):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
@@ -483,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore):
# Get the redaction event.
because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
+ redaction_id, check_redacted=False, allow_none=True
)
if because:
@@ -507,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore):
redacted_event = None
cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
+ event=original_ev, redacted_event=redacted_event
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
@@ -544,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore):
results = set()
def have_seen_events_txn(txn, chunk):
- sql = (
- "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
- % (",".join("?" * len(chunk)), )
+ sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
+ ",".join("?" * len(chunk)),
)
txn.execute(sql, chunk)
- for (event_id, ) in txn:
+ for (event_id,) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
- []):
- yield self.runInteraction(
- "have_seen_events",
- have_seen_events_txn,
- chunk,
- )
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 6ddcc909bf..b195dc66a0 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -35,10 +35,7 @@ class FilteringStore(SQLBaseStore):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
- keyvalues={
- "user_id": user_localpart,
- "filter_id": filter_id,
- },
+ keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
@@ -61,10 +58,7 @@ class FilteringStore(SQLBaseStore):
if filter_id_response is not None:
return filter_id_response[0]
- sql = (
- "SELECT MAX(filter_id) FROM user_filters "
- "WHERE user_id = ?"
- )
+ sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 592d1b4c2a..dce6a43ac1 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -38,24 +38,22 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_update_one(
table="groups",
- keyvalues={
- "group_id": group_id,
- },
- updatevalues={
- "join_policy": join_policy,
- },
+ keyvalues={"group_id": group_id},
+ updatevalues={"join_policy": join_policy},
desc="set_group_join_policy",
)
def get_group(self, group_id):
return self._simple_select_one(
table="groups",
- keyvalues={
- "group_id": group_id,
- },
+ keyvalues={"group_id": group_id},
retcols=(
- "name", "short_description", "long_description",
- "avatar_url", "is_public", "join_policy",
+ "name",
+ "short_description",
+ "long_description",
+ "avatar_url",
+ "is_public",
+ "join_policy",
),
allow_none=True,
desc="get_group",
@@ -64,16 +62,14 @@ class GroupServerStore(SQLBaseStore):
def get_users_in_group(self, group_id, include_private=False):
# TODO: Pagination
- keyvalues = {
- "group_id": group_id,
- }
+ keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
return self._simple_select_list(
table="group_users",
keyvalues=keyvalues,
- retcols=("user_id", "is_public", "is_admin",),
+ retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)
@@ -82,9 +78,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_select_onecol(
table="group_invites",
- keyvalues={
- "group_id": group_id,
- },
+ keyvalues={"group_id": group_id},
retcol="user_id",
desc="get_invited_users_in_group",
)
@@ -92,16 +86,14 @@ class GroupServerStore(SQLBaseStore):
def get_rooms_in_group(self, group_id, include_private=False):
# TODO: Pagination
- keyvalues = {
- "group_id": group_id,
- }
+ keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
return self._simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
- retcols=("room_id", "is_public",),
+ retcols=("room_id", "is_public"),
desc="get_rooms_in_group",
)
@@ -110,10 +102,9 @@ class GroupServerStore(SQLBaseStore):
Returns ([rooms], [categories])
"""
+
def _get_rooms_for_summary_txn(txn):
- keyvalues = {
- "group_id": group_id,
- }
+ keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -162,18 +153,23 @@ class GroupServerStore(SQLBaseStore):
}
return rooms, categories
- return self.runInteraction(
- "get_rooms_for_summary", _get_rooms_for_summary_txn
- )
+
+ return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.runInteraction(
- "add_room_to_summary", self._add_room_to_summary_txn,
- group_id, room_id, category_id, order, is_public,
+ "add_room_to_summary",
+ self._add_room_to_summary_txn,
+ group_id,
+ room_id,
+ category_id,
+ order,
+ is_public,
)
- def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
- is_public):
+ def _add_room_to_summary_txn(
+ self, txn, group_id, room_id, category_id, order, is_public
+ ):
"""Add (or update) room's entry in summary.
Args:
@@ -188,10 +184,7 @@ class GroupServerStore(SQLBaseStore):
room_in_group = self._simple_select_one_onecol_txn(
txn,
table="group_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- },
+ keyvalues={"group_id": group_id, "room_id": room_id},
retcol="room_id",
allow_none=True,
)
@@ -204,10 +197,7 @@ class GroupServerStore(SQLBaseStore):
cat_exists = self._simple_select_one_onecol_txn(
txn,
table="group_room_categories",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- },
+ keyvalues={"group_id": group_id, "category_id": category_id},
retcol="group_id",
allow_none=True,
)
@@ -218,22 +208,22 @@ class GroupServerStore(SQLBaseStore):
cat_exists = self._simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- },
+ keyvalues={"group_id": group_id, "category_id": category_id},
retcol="group_id",
allow_none=True,
)
if not cat_exists:
# If not, add it with an order larger than all others
- txn.execute("""
+ txn.execute(
+ """
INSERT INTO group_summary_room_categories
(group_id, category_id, cat_order)
SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
FROM group_summary_room_categories
WHERE group_id = ? AND category_id = ?
- """, (group_id, category_id, group_id, category_id))
+ """,
+ (group_id, category_id, group_id, category_id),
+ )
existing = self._simple_select_one_txn(
txn,
@@ -243,7 +233,7 @@ class GroupServerStore(SQLBaseStore):
"room_id": room_id,
"category_id": category_id,
},
- retcols=("room_order", "is_public",),
+ retcols=("room_order", "is_public"),
allow_none=True,
)
@@ -253,13 +243,13 @@ class GroupServerStore(SQLBaseStore):
UPDATE group_summary_rooms SET room_order = room_order + 1
WHERE group_id = ? AND category_id = ? AND room_order >= ?
"""
- txn.execute(sql, (group_id, category_id, order,))
+ txn.execute(sql, (group_id, category_id, order))
elif not existing:
sql = """
SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
WHERE group_id = ? AND category_id = ?
"""
- txn.execute(sql, (group_id, category_id,))
+ txn.execute(sql, (group_id, category_id))
order, = txn.fetchone()
if existing:
@@ -312,29 +302,26 @@ class GroupServerStore(SQLBaseStore):
def get_group_categories(self, group_id):
rows = yield self._simple_select_list(
table="group_room_categories",
- keyvalues={
- "group_id": group_id,
- },
+ keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
desc="get_group_categories",
)
- defer.returnValue({
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ defer.returnValue(
+ {
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
}
- for row in rows
- })
+ )
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
category = yield self._simple_select_one(
table="group_room_categories",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- },
+ keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
desc="get_group_category",
)
@@ -361,10 +348,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_upsert(
table="group_room_categories",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- },
+ keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
insertion_values=insertion_values,
desc="upsert_group_category",
@@ -373,10 +357,7 @@ class GroupServerStore(SQLBaseStore):
def remove_group_category(self, group_id, category_id):
return self._simple_delete(
table="group_room_categories",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- },
+ keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
)
@@ -384,29 +365,26 @@ class GroupServerStore(SQLBaseStore):
def get_group_roles(self, group_id):
rows = yield self._simple_select_list(
table="group_roles",
- keyvalues={
- "group_id": group_id,
- },
+ keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
desc="get_group_roles",
)
- defer.returnValue({
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ defer.returnValue(
+ {
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
}
- for row in rows
- })
+ )
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
role = yield self._simple_select_one(
table="group_roles",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
desc="get_group_role",
)
@@ -433,10 +411,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_upsert(
table="group_roles",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
insertion_values=insertion_values,
desc="upsert_group_role",
@@ -445,21 +420,24 @@ class GroupServerStore(SQLBaseStore):
def remove_group_role(self, group_id, role_id):
return self._simple_delete(
table="group_roles",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.runInteraction(
- "add_user_to_summary", self._add_user_to_summary_txn,
- group_id, user_id, role_id, order, is_public,
+ "add_user_to_summary",
+ self._add_user_to_summary_txn,
+ group_id,
+ user_id,
+ role_id,
+ order,
+ is_public,
)
- def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
- is_public):
+ def _add_user_to_summary_txn(
+ self, txn, group_id, user_id, role_id, order, is_public
+ ):
"""Add (or update) user's entry in summary.
Args:
@@ -474,10 +452,7 @@ class GroupServerStore(SQLBaseStore):
user_in_group = self._simple_select_one_onecol_txn(
txn,
table="group_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
)
@@ -490,10 +465,7 @@ class GroupServerStore(SQLBaseStore):
role_exists = self._simple_select_one_onecol_txn(
txn,
table="group_roles",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id},
retcol="group_id",
allow_none=True,
)
@@ -504,32 +476,28 @@ class GroupServerStore(SQLBaseStore):
role_exists = self._simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id},
retcol="group_id",
allow_none=True,
)
if not role_exists:
# If not, add it with an order larger than all others
- txn.execute("""
+ txn.execute(
+ """
INSERT INTO group_summary_roles
(group_id, role_id, role_order)
SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
FROM group_summary_roles
WHERE group_id = ? AND role_id = ?
- """, (group_id, role_id, group_id, role_id))
+ """,
+ (group_id, role_id, group_id, role_id),
+ )
existing = self._simple_select_one_txn(
txn,
table="group_summary_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- "role_id": role_id,
- },
- retcols=("user_order", "is_public",),
+ keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
+ retcols=("user_order", "is_public"),
allow_none=True,
)
@@ -539,13 +507,13 @@ class GroupServerStore(SQLBaseStore):
UPDATE group_summary_users SET user_order = user_order + 1
WHERE group_id = ? AND role_id = ? AND user_order >= ?
"""
- txn.execute(sql, (group_id, role_id, order,))
+ txn.execute(sql, (group_id, role_id, order))
elif not existing:
sql = """
SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
WHERE group_id = ? AND role_id = ?
"""
- txn.execute(sql, (group_id, role_id,))
+ txn.execute(sql, (group_id, role_id))
order, = txn.fetchone()
if existing:
@@ -586,11 +554,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_delete(
table="group_summary_users",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
)
@@ -599,10 +563,9 @@ class GroupServerStore(SQLBaseStore):
Returns ([users], [roles])
"""
+
def _get_users_for_summary_txn(txn):
- keyvalues = {
- "group_id": group_id,
- }
+ keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -651,6 +614,7 @@ class GroupServerStore(SQLBaseStore):
}
return users, roles
+
return self.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
@@ -658,10 +622,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_in_group(self, user_id, group_id):
return self._simple_select_one_onecol(
table="group_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
@@ -670,10 +631,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_admin_in_group(self, group_id, user_id):
return self._simple_select_one_onecol(
table="group_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
allow_none=True,
desc="is_user_admin_in_group",
@@ -684,10 +642,7 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_insert(
table="group_invites",
- values={
- "group_id": group_id,
- "user_id": user_id,
- },
+ values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
)
@@ -696,10 +651,7 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_select_one_onecol(
table="group_invites",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
desc="is_user_invited_to_local_group",
allow_none=True,
@@ -718,14 +670,12 @@ class GroupServerStore(SQLBaseStore):
Returns an empty dict if the user is not join/invite/etc
"""
+
def _get_users_membership_in_group_txn(txn):
row = self._simple_select_one_txn(
txn,
table="group_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("is_admin", "is_public"),
allow_none=True,
)
@@ -740,27 +690,29 @@ class GroupServerStore(SQLBaseStore):
row = self._simple_select_one_onecol_txn(
txn,
table="group_invites",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
)
if row:
- return {
- "membership": "invite",
- }
+ return {"membership": "invite"}
return {}
return self.runInteraction(
- "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
- def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
- local_attestation=None, remote_attestation=None):
+ def add_user_to_group(
+ self,
+ group_id,
+ user_id,
+ is_admin=False,
+ is_public=True,
+ local_attestation=None,
+ remote_attestation=None,
+ ):
"""Add a user to the group server.
Args:
@@ -774,6 +726,7 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): The attestation given to GS by remote
server. Optional if the user and group are on the same server
"""
+
def _add_user_to_group_txn(txn):
self._simple_insert_txn(
txn,
@@ -789,10 +742,7 @@ class GroupServerStore(SQLBaseStore):
self._simple_delete_txn(
txn,
table="group_invites",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
@@ -817,75 +767,52 @@ class GroupServerStore(SQLBaseStore):
},
)
- return self.runInteraction(
- "add_user_to_group", _add_user_to_group_txn
- )
+ return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
self._simple_delete_txn(
txn,
table="group_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_delete_txn(
txn,
table="group_invites",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_delete_txn(
txn,
table="group_attestations_renewals",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_delete_txn(
txn,
table="group_attestations_remote",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_delete_txn(
txn,
table="group_summary_users",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+ return self.runInteraction(
+ "remove_user_from_group", _remove_user_from_group_txn
+ )
def add_room_to_group(self, group_id, room_id, is_public):
return self._simple_insert(
table="group_rooms",
- values={
- "group_id": group_id,
- "room_id": room_id,
- "is_public": is_public,
- },
+ values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self._simple_update(
table="group_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- },
- updatevalues={
- "is_public": is_public,
- },
+ keyvalues={"group_id": group_id, "room_id": room_id},
+ updatevalues={"is_public": is_public},
desc="update_room_in_group_visibility",
)
@@ -894,22 +821,17 @@ class GroupServerStore(SQLBaseStore):
self._simple_delete_txn(
txn,
table="group_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- },
+ keyvalues={"group_id": group_id, "room_id": room_id},
)
self._simple_delete_txn(
txn,
table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- },
+ keyvalues={"group_id": group_id, "room_id": room_id},
)
+
return self.runInteraction(
- "remove_room_from_group", _remove_room_from_group_txn,
+ "remove_room_from_group", _remove_room_from_group_txn
)
def get_publicised_groups_for_user(self, user_id):
@@ -917,11 +839,7 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_select_onecol(
table="local_group_membership",
- keyvalues={
- "user_id": user_id,
- "membership": "join",
- "is_publicised": True,
- },
+ keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
desc="get_publicised_groups_for_user",
)
@@ -931,23 +849,23 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_update_one(
table="local_group_membership",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
- updatevalues={
- "is_publicised": publicise,
- },
- desc="update_group_publicity"
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ updatevalues={"is_publicised": publicise},
+ desc="update_group_publicity",
)
@defer.inlineCallbacks
- def register_user_group_membership(self, group_id, user_id, membership,
- is_admin=False, content={},
- local_attestation=None,
- remote_attestation=None,
- is_publicised=False,
- ):
+ def register_user_group_membership(
+ self,
+ group_id,
+ user_id,
+ membership,
+ is_admin=False,
+ content={},
+ local_attestation=None,
+ remote_attestation=None,
+ is_publicised=False,
+ ):
"""Registers that a local user is a member of a (local or remote) group.
Args:
@@ -962,15 +880,13 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): If remote group then store the remote
attestation from the group, else None.
"""
+
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
self._simple_delete_txn(
txn,
table="local_group_membership",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_insert_txn(
txn,
@@ -993,8 +909,10 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id,
"user_id": user_id,
"type": "membership",
- "content": json.dumps({"membership": membership, "content": content}),
- }
+ "content": json.dumps(
+ {"membership": membership, "content": content}
+ ),
+ },
)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
@@ -1009,7 +927,7 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": local_attestation["valid_until_ms"],
- }
+ },
)
if remote_attestation:
self._simple_insert_txn(
@@ -1020,24 +938,18 @@ class GroupServerStore(SQLBaseStore):
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
"attestation_json": json.dumps(remote_attestation),
- }
+ },
)
else:
self._simple_delete_txn(
txn,
table="group_attestations_renewals",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
self._simple_delete_txn(
txn,
table="group_attestations_remote",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
)
return next_id
@@ -1045,13 +957,15 @@ class GroupServerStore(SQLBaseStore):
with self._group_updates_id_gen.get_next() as next_id:
res = yield self.runInteraction(
"register_user_group_membership",
- _register_user_group_membership_txn, next_id,
+ _register_user_group_membership_txn,
+ next_id,
)
defer.returnValue(res)
@defer.inlineCallbacks
- def create_group(self, group_id, user_id, name, avatar_url, short_description,
- long_description,):
+ def create_group(
+ self, group_id, user_id, name, avatar_url, short_description, long_description
+ ):
yield self._simple_insert(
table="groups",
values={
@@ -1066,12 +980,10 @@ class GroupServerStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def update_group_profile(self, group_id, profile,):
+ def update_group_profile(self, group_id, profile):
yield self._simple_update_one(
table="groups",
- keyvalues={
- "group_id": group_id,
- },
+ keyvalues={"group_id": group_id},
updatevalues=profile,
desc="update_group_profile",
)
@@ -1079,6 +991,7 @@ class GroupServerStore(SQLBaseStore):
def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
+
def _get_attestations_need_renewals_txn(txn):
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
@@ -1086,6 +999,7 @@ class GroupServerStore(SQLBaseStore):
"""
txn.execute(sql, (valid_until_ms,))
return self.cursor_to_dict(txn)
+
return self.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
@@ -1095,13 +1009,8 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_update_one(
table="group_attestations_renewals",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
- updatevalues={
- "valid_until_ms": attestation["valid_until_ms"],
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)
@@ -1110,13 +1019,10 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_update_one(
table="group_attestations_remote",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
"valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json.dumps(attestation)
+ "attestation_json": json.dumps(attestation),
},
desc="update_remote_attestion",
)
@@ -1132,10 +1038,7 @@ class GroupServerStore(SQLBaseStore):
"""
return self._simple_delete(
table="group_attestations_renewals",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
)
@@ -1146,10 +1049,7 @@ class GroupServerStore(SQLBaseStore):
"""
row = yield self._simple_select_one(
table="group_attestations_remote",
- keyvalues={
- "group_id": group_id,
- "user_id": user_id,
- },
+ keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
desc="get_remote_attestation",
allow_none=True,
@@ -1164,10 +1064,7 @@ class GroupServerStore(SQLBaseStore):
def get_joined_groups(self, user_id):
return self._simple_select_onecol(
table="local_group_membership",
- keyvalues={
- "user_id": user_id,
- "membership": "join",
- },
+ keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
desc="get_joined_groups",
)
@@ -1181,7 +1078,7 @@ class GroupServerStore(SQLBaseStore):
WHERE user_id = ? AND membership != 'leave'
AND stream_id <= ?
"""
- txn.execute(sql, (user_id, now_token,))
+ txn.execute(sql, (user_id, now_token))
return [
{
"group_id": row[0],
@@ -1191,14 +1088,15 @@ class GroupServerStore(SQLBaseStore):
}
for row in txn
]
+
return self.runInteraction(
- "get_all_groups_for_user", _get_all_groups_for_user_txn,
+ "get_all_groups_for_user", _get_all_groups_for_user_txn
)
def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
- user_id, from_token,
+ user_id, from_token
)
if not has_changed:
return []
@@ -1210,21 +1108,25 @@ class GroupServerStore(SQLBaseStore):
INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
"""
- txn.execute(sql, (user_id, from_token, to_token,))
- return [{
- "group_id": group_id,
- "membership": membership,
- "type": gtype,
- "content": json.loads(content_json),
- } for group_id, membership, gtype, content_json in txn]
+ txn.execute(sql, (user_id, from_token, to_token))
+ return [
+ {
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ }
+ for group_id, membership, gtype, content_json in txn
+ ]
+
return self.runInteraction(
- "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token,
+ from_token
)
if not has_changed:
return []
@@ -1236,17 +1138,52 @@ class GroupServerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit,))
- return [(
- stream_id,
- group_id,
- user_id,
- gtype,
- json.loads(content_json),
- ) for stream_id, group_id, user_id, gtype, content_json in txn]
+ txn.execute(sql, (from_token, to_token, limit))
+ return [
+ (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ]
+
return self.runInteraction(
- "get_all_groups_changes", _get_all_groups_changes_txn,
+ "get_all_groups_changes", _get_all_groups_changes_txn
)
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
+
+ def delete_group(self, group_id):
+ """Deletes a group fully from the database.
+
+ Args:
+ group_id (str)
+
+ Returns:
+ Deferred
+ """
+
+ def _delete_group_txn(txn):
+ tables = [
+ "groups",
+ "group_users",
+ "group_invites",
+ "group_rooms",
+ "group_summary_rooms",
+ "group_summary_room_categories",
+ "group_room_categories",
+ "group_summary_users",
+ "group_summary_roles",
+ "group_roles",
+ "group_attestations_renewals",
+ "group_attestations_remote",
+ ]
+
+ for table in tables:
+ self._simple_delete_txn(
+ txn,
+ table=table,
+ keyvalues={"group_id": group_id},
+ )
+
+ return self.runInteraction(
+ "delete_group", _delete_group_txn
+ )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 8af17921e3..7036541792 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
+import itertools
import logging
import six
from signedjson.key import decode_verify_key_bytes
-import OpenSSL
-from twisted.internet import defer
-
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util import batch_iter
+from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
@@ -38,93 +37,56 @@ else:
class KeyStore(SQLBaseStore):
- """Persistence for signature verification keys and tls X.509 certificates
+ """Persistence for signature verification keys
"""
- @defer.inlineCallbacks
- def get_server_certificate(self, server_name):
- """Retrieve the TLS X.509 certificate for the given server
+ @cached()
+ def _get_server_verify_key(self, server_name_and_key_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+ )
+ def get_server_verify_keys(self, server_name_and_key_ids):
+ """
Args:
- server_name (bytes): The name of the server.
+ server_name_and_key_ids (iterable[Tuple[str, str]]):
+ iterable of (server_name, key-id) tuples to fetch keys for
+
Returns:
- (OpenSSL.crypto.X509): The tls certificate.
+ Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
+ map from (server_name, key_id) -> VerifyKey, or None if the key is
+ unknown
"""
- tls_certificate_bytes, = yield self._simple_select_one(
- table="server_tls_certificates",
- keyvalues={"server_name": server_name},
- retcols=("tls_certificate",),
- desc="get_server_certificate",
- )
- tls_certificate = OpenSSL.crypto.load_certificate(
- OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
- )
- defer.returnValue(tls_certificate)
+ keys = {}
- def store_server_certificate(self, server_name, from_server, time_now_ms,
- tls_certificate):
- """Stores the TLS X.509 certificate for the given server
- Args:
- server_name (str): The name of the server.
- from_server (str): Where the certificate was looked up
- time_now_ms (int): The time now in milliseconds
- tls_certificate (OpenSSL.crypto.X509): The X.509 certificate.
- """
- tls_certificate_bytes = OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
- )
- fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
- return self._simple_upsert(
- table="server_tls_certificates",
- keyvalues={
- "server_name": server_name,
- "fingerprint": fingerprint,
- },
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "tls_certificate": db_binary_type(tls_certificate_bytes),
- },
- desc="store_server_certificate",
- )
+ def _get_keys(txn, batch):
+ """Processes a batch of keys to fetch, and adds the result to `keys`."""
- @cachedInlineCallbacks()
- def _get_server_verify_key(self, server_name, key_id):
- verify_key_bytes = yield self._simple_select_one_onecol(
- table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": key_id,
- },
- retcol="verify_key",
- desc="_get_server_verify_key",
- allow_none=True,
- )
+ # batch_iter always returns tuples so it's safe to do len(batch)
+ sql = (
+ "SELECT server_name, key_id, verify_key FROM server_signature_keys "
+ "WHERE 1=0"
+ ) + " OR (server_name=? AND key_id=?)" * len(batch)
- if verify_key_bytes:
- defer.returnValue(decode_verify_key_bytes(
- key_id, bytes(verify_key_bytes)
- ))
+ txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
- @defer.inlineCallbacks
- def get_server_verify_keys(self, server_name, key_ids):
- """Retrieve the NACL verification key for a given server for the given
- key_ids
- Args:
- server_name (str): The name of the server.
- key_ids (iterable[str]): key_ids to try and look up.
- Returns:
- Deferred: resolves to dict[str, VerifyKey]: map from
- key_id to verification key.
- """
- keys = {}
- for key_id in key_ids:
- key = yield self._get_server_verify_key(server_name, key_id)
- if key:
- keys[key_id] = key
- defer.returnValue(keys)
-
- def store_server_verify_key(self, server_name, from_server, time_now_ms,
- verify_key):
+ for row in txn:
+ server_name, key_id, key_bytes = row
+ keys[(server_name, key_id)] = decode_verify_key_bytes(
+ key_id, bytes(key_bytes)
+ )
+
+ def _txn(txn):
+ for batch in batch_iter(server_name_and_key_ids, 50):
+ _get_keys(txn, batch)
+ return keys
+
+ return self.runInteraction("get_server_verify_keys", _txn)
+
+ def store_server_verify_key(
+ self, server_name, from_server, time_now_ms, verify_key
+ ):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
@@ -139,25 +101,25 @@ class KeyStore(SQLBaseStore):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": key_id,
- },
+ keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": db_binary_type(verify_key.encode()),
},
)
+ # invalidate takes a tuple corresponding to the params of
+ # _get_server_verify_key. _get_server_verify_key only takes one
+ # param, which is itself the 2-tuple (server_name, key_id).
txn.call_after(
- self._get_server_verify_key.invalidate,
- (server_name, key_id)
+ self._get_server_verify_key.invalidate, ((server_name, key_id),)
)
return self.runInteraction("store_server_verify_key", _txn)
- def store_server_keys_json(self, server_name, key_id, from_server,
- ts_now_ms, ts_expires_ms, key_json_bytes):
+ def store_server_keys_json(
+ self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
+ ):
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
@@ -197,9 +159,10 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
- Dict mapping (server_name, key_id, source) triplets to dicts with
- "ts_valid_until_ms" and "key_json" keys.
+ Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
+ Dict mapping (server_name, key_id, source) triplets to lists of dicts
"""
+
def _get_server_keys_json_txn(txn):
results = {}
for server_name, key_id, from_server in server_keys:
@@ -222,6 +185,5 @@ class KeyStore(SQLBaseStore):
)
results[(server_name, key_id, from_server)] = rows
return results
- return self.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
- )
+
+ return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index e6cdbb0545..3ecf47e7a7 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -38,15 +38,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository",
{"media_id": media_id},
(
- "media_type", "media_length", "upload_name", "created_ts",
- "quarantined_by", "url_cache",
+ "media_type",
+ "media_length",
+ "upload_name",
+ "created_ts",
+ "quarantined_by",
+ "url_cache",
),
allow_none=True,
desc="get_local_media",
)
- def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
- media_length, user_id, url_cache=None):
+ def store_local_media(
+ self,
+ media_id,
+ media_type,
+ time_now_ms,
+ upload_name,
+ media_length,
+ user_id,
+ url_cache=None,
+ ):
return self._simple_insert(
"local_media_repository",
{
@@ -66,6 +78,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
Returns:
None if the URL isn't cached.
"""
+
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
@@ -92,16 +105,25 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if not row:
return None
- return dict(zip((
- 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
- ), row))
+ return dict(
+ zip(
+ (
+ 'response_code',
+ 'etag',
+ 'expires_ts',
+ 'og',
+ 'media_id',
+ 'download_ts',
+ ),
+ row,
+ )
+ )
- return self.runInteraction(
- "get_url_cache", get_url_cache_txn
- )
+ return self.runInteraction("get_url_cache", get_url_cache_txn)
- def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
- download_ts):
+ def store_url_cache(
+ self, url, response_code, etag, expires_ts, og, media_id, download_ts
+ ):
return self._simple_insert(
"local_media_repository_url_cache",
{
@@ -121,15 +143,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository_thumbnails",
{"media_id": media_id},
(
- "thumbnail_width", "thumbnail_height", "thumbnail_method",
- "thumbnail_type", "thumbnail_length",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
),
desc="get_local_media_thumbnails",
)
- def store_local_thumbnail(self, media_id, thumbnail_width,
- thumbnail_height, thumbnail_type,
- thumbnail_method, thumbnail_length):
+ def store_local_thumbnail(
+ self,
+ media_id,
+ thumbnail_width,
+ thumbnail_height,
+ thumbnail_type,
+ thumbnail_method,
+ thumbnail_length,
+ ):
return self._simple_insert(
"local_media_repository_thumbnails",
{
@@ -148,16 +179,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
- "media_type", "media_length", "upload_name", "created_ts",
- "filesystem_id", "quarantined_by",
+ "media_type",
+ "media_length",
+ "upload_name",
+ "created_ts",
+ "filesystem_id",
+ "quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
- def store_cached_remote_media(self, origin, media_id, media_type,
- media_length, time_now_ms, upload_name,
- filesystem_id):
+ def store_cached_remote_media(
+ self,
+ origin,
+ media_id,
+ media_type,
+ media_length,
+ time_now_ms,
+ upload_name,
+ filesystem_id,
+ ):
return self._simple_insert(
"remote_media_cache",
{
@@ -181,26 +223,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
+
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(sql, (
- (time_ms, media_origin, media_id)
- for media_origin, media_id in remote_media
- ))
+ txn.executemany(
+ sql,
+ (
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ),
+ )
sql = (
"UPDATE local_media_repository SET last_access_ts = ?"
" WHERE media_id = ?"
)
- txn.executemany(sql, (
- (time_ms, media_id)
- for media_id in local_media
- ))
+ txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -209,16 +252,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
- "thumbnail_width", "thumbnail_height", "thumbnail_method",
- "thumbnail_type", "thumbnail_length", "filesystem_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ "filesystem_id",
),
desc="get_remote_media_thumbnails",
)
- def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
- thumbnail_width, thumbnail_height,
- thumbnail_type, thumbnail_method,
- thumbnail_length):
+ def store_remote_media_thumbnail(
+ self,
+ origin,
+ media_id,
+ filesystem_id,
+ thumbnail_width,
+ thumbnail_height,
+ thumbnail_type,
+ thumbnail_method,
+ thumbnail_length,
+ ):
return self._simple_insert(
"remote_media_cache_thumbnails",
{
@@ -250,17 +304,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
self._simple_delete_txn(
txn,
"remote_media_cache",
- keyvalues={
- "media_origin": media_origin, "media_id": media_id
- },
+ keyvalues={"media_origin": media_origin, "media_id": media_id},
)
self._simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
- keyvalues={
- "media_origin": media_origin, "media_id": media_id
- },
+ keyvalues={"media_origin": media_origin, "media_id": media_id},
)
+
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
@@ -281,10 +332,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if len(media_ids) == 0:
return
- sql = (
- "DELETE FROM local_media_repository_url_cache"
- " WHERE media_id = ?"
- )
+ sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
@@ -304,7 +352,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return [row[0] for row in txn]
return self.runInteraction(
- "get_url_cache_media_before", _get_url_cache_media_before_txn,
+ "get_url_cache_media_before", _get_url_cache_media_before_txn
)
def delete_url_cache_media(self, media_ids):
@@ -312,20 +360,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
- sql = (
- "DELETE FROM local_media_repository"
- " WHERE media_id = ?"
- )
+ sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- sql = (
- "DELETE FROM local_media_repository_thumbnails"
- " WHERE media_id = ?"
- )
+ sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
- "delete_url_cache_media", _delete_url_cache_media_txn,
+ "delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 9e7e09b8c1..8aa8abc470 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -35,9 +35,12 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._new_transaction(
- dbconn, "initialise_mau_threepids", [], [],
+ dbconn,
+ "initialise_mau_threepids",
+ [],
+ [],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
+ hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -51,10 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
reserved_user_list = []
for tp in threepids:
- user_id = self.get_user_id_by_threepid_txn(
- txn,
- tp["medium"], tp["address"]
- )
+ user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
@@ -62,9 +62,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
- logger.warning(
- "mau limit reserved threepid %s not found in db" % tp
- )
+ logger.warning("mau limit reserved threepid %s not found in db" % tp)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
@@ -75,12 +73,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Returns:
Deferred[]
"""
+
def _reap_users(txn):
# Purge stale users
- thirty_days_ago = (
- int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- )
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
@@ -158,6 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn.execute(sql)
count, = txn.fetchone()
return count
+
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
@@ -198,14 +196,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return
yield self.runInteraction(
- "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
- user_id
+ "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
user_in_mau = self.user_last_seen_monthly_active.cache.get(
- (user_id,),
- None,
- update_metrics=False
+ (user_id,), None, update_metrics=False
)
if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
@@ -247,12 +242,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_insert = self._simple_upsert_txn(
txn,
table="monthly_active_users",
- keyvalues={
- "user_id": user_id,
- },
- values={
- "timestamp": int(self._clock.time_msec()),
- },
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
)
return is_insert
@@ -268,15 +259,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"""
- return(self._simple_select_one_onecol(
+ return self._simple_select_one_onecol(
table="monthly_active_users",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
retcol="timestamp",
allow_none=True,
desc="user_last_seen_monthly_active",
- ))
+ )
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
diff --git a/synapse/storage/openid.py b/synapse/storage/openid.py
index 5dabb607bd..b3318045ee 100644
--- a/synapse/storage/openid.py
+++ b/synapse/storage/openid.py
@@ -10,7 +10,7 @@ class OpenIdStore(SQLBaseStore):
"ts_valid_until_ms": ts_valid_until_ms,
"user_id": user_id,
},
- desc="insert_open_id_token"
+ desc="insert_open_id_token",
)
def get_user_id_for_open_id_token(self, token, ts_now_ms):
@@ -27,6 +27,5 @@ class OpenIdStore(SQLBaseStore):
return None
else:
return rows[0][0]
- return self.runInteraction(
- "get_user_id_for_token", get_user_id_for_token_txn
- )
+
+ return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e042221774..c1711bc8bd 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -143,10 +143,9 @@ def _setup_new_database(cur, database_engine):
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)"
- " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
),
- (max_current_ver, False,)
+ (max_current_ver, False),
)
_upgrade_existing_database(
@@ -160,8 +159,15 @@ def _setup_new_database(cur, database_engine):
)
-def _upgrade_existing_database(cur, current_version, applied_delta_files,
- upgraded, database_engine, config, is_empty=False):
+def _upgrade_existing_database(
+ cur,
+ current_version,
+ applied_delta_files,
+ upgraded,
+ database_engine,
+ config,
+ is_empty=False,
+):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -209,8 +215,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if current_version > SCHEMA_VERSION:
raise ValueError(
- "Cannot use this database as it is too " +
- "new for the server to understand"
+ "Cannot use this database as it is too "
+ + "new for the server to understand"
)
start_ver = current_version
@@ -239,20 +245,14 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if relative_path in applied_delta_files:
continue
- absolute_path = os.path.join(
- dir_path, "schema", "delta", relative_path,
- )
+ absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
- module_name = "synapse.storage.v%d_%s" % (
- v, root_name
- )
+ module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(
- module_name, absolute_path, python_file
- )
+ module = imp.load_source(module_name, absolute_path, python_file)
logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine)
if not is_empty:
@@ -269,8 +269,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
else:
# Not a valid delta file.
logger.warn(
- "Found directory entry that did not end in .py or"
- " .sql: %s",
+ "Found directory entry that did not end in .py or" " .sql: %s",
relative_path,
)
continue
@@ -278,19 +277,17 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file)"
- " VALUES (?,?)",
+ "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
),
- (v, relative_path)
+ (v, relative_path),
)
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
+ "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
),
- (v, True)
+ (v, True),
)
@@ -308,7 +305,7 @@ def _apply_module_schemas(txn, database_engine, config):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
- txn, database_engine, modname, mod.get_db_schema_files(),
+ txn, database_engine, modname, mod.get_db_schema_files()
)
@@ -326,7 +323,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
database_engine.convert_param_style(
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
),
- (modname,)
+ (modname,),
)
applied_deltas = set(d for d, in cur)
for (name, stream) in names_and_streams:
@@ -336,7 +333,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
root_name, ext = os.path.splitext(name)
if ext != '.sql':
raise PrepareDatabaseException(
- "only .sql files are currently supported for module schemas",
+ "only .sql files are currently supported for module schemas"
)
logger.info("applying schema %s for %s", name, modname)
@@ -346,10 +343,9 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)"
- " VALUES (?,?)",
+ "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
),
- (modname, name)
+ (modname, name),
)
@@ -386,10 +382,7 @@ def get_statements(f):
statements = line.split(";")
# We must prepend statement_buffer to the first statement
- first_statement = "%s %s" % (
- statement_buffer.strip(),
- statements[0].strip()
- )
+ first_statement = "%s %s" % (statement_buffer.strip(), statements[0].strip())
statements[0] = first_statement
# Every entry, except the last, is a full statement
@@ -409,9 +402,7 @@ def executescript(txn, schema_path):
def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables.
- schema_path = os.path.join(
- dir_path, "schema", "schema_version.sql",
- )
+ schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
@@ -424,7 +415,7 @@ def _get_or_create_schema_state(txn, database_engine):
database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
- (current_version,)
+ (current_version,),
)
applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index a0c7a0dc87..42ec8c6bb8 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -19,15 +19,25 @@ from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
-class UserPresenceState(namedtuple("UserPresenceState",
- ("user_id", "state", "last_active_ts",
- "last_federation_update_ts", "last_user_sync_ts",
- "status_msg", "currently_active"))):
+class UserPresenceState(
+ namedtuple(
+ "UserPresenceState",
+ (
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ )
+):
"""Represents the current presence state of the user.
user_id (str)
@@ -75,22 +85,21 @@ class PresenceStore(SQLBaseStore):
with stream_ordering_manager as stream_orderings:
yield self.runInteraction(
"update_presence",
- self._update_presence_txn, stream_orderings, presence_states,
+ self._update_presence_txn,
+ stream_orderings,
+ presence_states,
)
- defer.returnValue((
- stream_orderings[-1], self._presence_id_gen.get_current_token()
- ))
+ defer.returnValue(
+ (stream_orderings[-1], self._presence_id_gen.get_current_token())
+ )
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
- self.presence_stream_cache.entity_has_changed,
- state.user_id, stream_id,
- )
- txn.call_after(
- self._get_presence_for_user.invalidate, (state.user_id,)
+ self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
)
+ txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
self._simple_insert_many_txn(
@@ -113,18 +122,13 @@ class PresenceStore(SQLBaseStore):
# Delete old rows to stop database from getting really big
sql = (
- "DELETE FROM presence_stream WHERE"
- " stream_id < ?"
- " AND user_id IN (%s)"
+ "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
)
for states in batch_iter(presence_states, 50):
args = [stream_id]
args.extend(s.user_id for s in states)
- txn.execute(
- sql % (",".join("?" for _ in states),),
- args
- )
+ txn.execute(sql % (",".join("?" for _ in states),), args)
def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id:
@@ -149,8 +153,12 @@ class PresenceStore(SQLBaseStore):
def _get_presence_for_user(self, user_id):
raise NotImplementedError()
- @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
- num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="_get_presence_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
table="presence_stream",
@@ -180,8 +188,10 @@ class PresenceStore(SQLBaseStore):
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
table="presence_allow_inbound",
- values={"observed_user_id": observed_localpart,
- "observer_user_id": observer_userid},
+ values={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
desc="allow_presence_visible",
or_ignore=True,
)
@@ -189,89 +199,9 @@ class PresenceStore(SQLBaseStore):
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_delete_one(
table="presence_allow_inbound",
- keyvalues={"observed_user_id": observed_localpart,
- "observer_user_id": observer_userid},
+ keyvalues={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
desc="disallow_presence_visible",
)
-
- def add_presence_list_pending(self, observer_localpart, observed_userid):
- return self._simple_insert(
- table="presence_list",
- values={"user_id": observer_localpart,
- "observed_user_id": observed_userid,
- "accepted": False},
- desc="add_presence_list_pending",
- )
-
- def set_presence_list_accepted(self, observer_localpart, observed_userid):
- def update_presence_list_txn(txn):
- result = self._simple_update_one_txn(
- txn,
- table="presence_list",
- keyvalues={
- "user_id": observer_localpart,
- "observed_user_id": observed_userid
- },
- updatevalues={"accepted": True},
- )
-
- self._invalidate_cache_and_stream(
- txn, self.get_presence_list_accepted, (observer_localpart,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_presence_list_observers_accepted, (observed_userid,)
- )
-
- return result
-
- return self.runInteraction(
- "set_presence_list_accepted", update_presence_list_txn,
- )
-
- def get_presence_list(self, observer_localpart, accepted=None):
- if accepted:
- return self.get_presence_list_accepted(observer_localpart)
- else:
- keyvalues = {"user_id": observer_localpart}
- if accepted is not None:
- keyvalues["accepted"] = accepted
-
- return self._simple_select_list(
- table="presence_list",
- keyvalues=keyvalues,
- retcols=["observed_user_id", "accepted"],
- desc="get_presence_list",
- )
-
- @cached()
- def get_presence_list_accepted(self, observer_localpart):
- return self._simple_select_list(
- table="presence_list",
- keyvalues={"user_id": observer_localpart, "accepted": True},
- retcols=["observed_user_id", "accepted"],
- desc="get_presence_list_accepted",
- )
-
- @cachedInlineCallbacks()
- def get_presence_list_observers_accepted(self, observed_userid):
- user_localparts = yield self._simple_select_onecol(
- table="presence_list",
- keyvalues={"observed_user_id": observed_userid, "accepted": True},
- retcol="user_id",
- desc="get_presence_list_accepted",
- )
-
- defer.returnValue([
- "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
- ])
-
- @defer.inlineCallbacks
- def del_presence_list(self, observer_localpart, observed_userid):
- yield self._simple_delete_one(
- table="presence_list",
- keyvalues={"user_id": observer_localpart,
- "observed_user_id": observed_userid},
- desc="del_presence_list",
- )
- self.get_presence_list_accepted.invalidate((observer_localpart,))
- self.get_presence_list_observers_accepted.invalidate((observed_userid,))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 88b50f33b5..aeec2f57c4 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -41,8 +41,7 @@ class ProfileWorkerStore(SQLBaseStore):
defer.returnValue(
ProfileInfo(
- avatar_url=profile['avatar_url'],
- display_name=profile['displayname'],
+ avatar_url=profile['avatar_url'], display_name=profile['displayname']
)
)
@@ -66,16 +65,14 @@ class ProfileWorkerStore(SQLBaseStore):
return self._simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- retcols=("displayname", "avatar_url",),
+ retcols=("displayname", "avatar_url"),
allow_none=True,
desc="get_from_remote_profile_cache",
)
def create_profile(self, user_localpart):
return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
+ table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
@@ -141,6 +138,7 @@ class ProfileStore(ProfileWorkerStore):
def get_remote_profile_cache_entries_that_expire(self, last_checked):
"""Get all users who haven't been checked since `last_checked`
"""
+
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
SELECT user_id, displayname, avatar_url
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 4b8438c3e9..9e406baafa 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -57,11 +57,13 @@ def _load_rules(rawrules, enabled_map):
return rules
-class PushRulesWorkerStore(ApplicationServiceWorkerStore,
- ReceiptsWorkerStore,
- PusherWorkerStore,
- RoomMemberWorkerStore,
- SQLBaseStore):
+class PushRulesWorkerStore(
+ ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore,
+):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
@@ -74,14 +76,16 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn, "push_rules_stream",
+ db_conn,
+ "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache", push_rules_id,
+ "PushRulesStreamChangeCache",
+ push_rules_id,
prefilled_cache=push_rules_prefill,
)
@@ -98,19 +102,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
- keyvalues={
- "user_name": user_id,
- },
+ keyvalues={"user_name": user_id},
retcols=(
- "user_name", "rule_id", "priority_class", "priority",
- "conditions", "actions",
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
),
desc="get_push_rules_enabled_for_user",
)
- rows.sort(
- key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
- )
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
@@ -122,22 +126,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
- keyvalues={
- 'user_name': user_id
- },
- retcols=(
- "user_name", "rule_id", "enabled",
- ),
+ keyvalues={'user_name': user_id},
+ retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
- defer.returnValue({
- r['rule_id']: False if r['enabled'] == 0 else True for r in results
- })
+ defer.returnValue(
+ {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
+ )
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
+
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
@@ -146,20 +147,22 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
+
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
- @cachedList(cached_method_name="get_push_rules_for_user",
- list_name="user_ids", num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="get_push_rules_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def bulk_get_push_rules(self, user_ids):
if not user_ids:
defer.returnValue({})
- results = {
- user_id: []
- for user_id in user_ids
- }
+ results = {user_id: [] for user_id in user_ids}
rows = yield self._simple_select_many_batch(
table="push_rules",
@@ -169,9 +172,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
desc="bulk_get_push_rules",
)
- rows.sort(
- key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
- )
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
results.setdefault(row['user_name'], []).append(row)
@@ -179,16 +180,12 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {})
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
defer.returnValue(results)
@defer.inlineCallbacks
- def move_push_rule_from_room_to_room(
- self, new_room_id, user_id, rule,
- ):
+ def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Move a single push rule from one room to another for a specific user.
Args:
@@ -219,7 +216,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
@defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id,
+ self, old_room_id, new_room_id, user_id
):
"""Move all of the push rules from one room to another for a specific
user.
@@ -236,11 +233,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# delete them from the old room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
- if any((c.get("key") == "room_id" and
- c.get("pattern") == old_room_id) for c in conditions):
- self.move_push_rule_from_room_to_room(
- new_room_id, user_id, rule,
- )
+ if any(
+ (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+ for c in conditions
+ ):
+ self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
@@ -259,8 +256,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
- cache_context, event=None):
+ def _bulk_get_push_rules_for_room(
+ self, room_id, state_group, current_state_ids, cache_context, event=None
+ ):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
@@ -273,7 +271,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context(
- room_id, state_group, current_state_ids,
+ room_id,
+ state_group,
+ current_state_ids,
on_invalidate=cache_context.invalidate,
event=event,
)
@@ -282,7 +282,8 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = set(
- u for u in users_in_room
+ u
+ for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
)
@@ -290,15 +291,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room,
- on_invalidate=cache_context.invalidate,
+ local_users_in_room, on_invalidate=cache_context.invalidate
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
# any users with pushers must be ours: they have pushers
@@ -307,29 +307,30 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
- user_ids, on_invalidate=cache_context.invalidate,
+ user_ids, on_invalidate=cache_context.invalidate
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
- @cachedList(cached_method_name="get_push_rules_enabled_for_user",
- list_name="user_ids", num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="get_push_rules_enabled_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
defer.returnValue({})
- results = {
- user_id: {}
- for user_id in user_ids
- }
+ results = {user_id: {} for user_id in user_ids}
rows = yield self._simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled",),
+ retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
)
for row in rows:
@@ -341,8 +342,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
- self, user_id, rule_id, priority_class, conditions, actions,
- before=None, after=None
+ self,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions,
+ actions,
+ before=None,
+ after=None,
):
conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions)
@@ -352,20 +359,41 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
- stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- conditions_json, actions_json, before, after,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
)
else:
yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
- stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- conditions_json, actions_json,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
)
def _add_push_rule_relative_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- conditions_json, actions_json, before, after
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
@@ -376,10 +404,7 @@ class PushRuleStore(PushRulesWorkerStore):
res = self._simple_select_one_txn(
txn,
table="push_rules",
- keyvalues={
- "user_name": user_id,
- "rule_id": relative_to_rule,
- },
+ keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
retcols=["priority_class", "priority"],
allow_none=True,
)
@@ -416,13 +441,27 @@ class PushRuleStore(PushRulesWorkerStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- new_rule_priority, conditions_json, actions_json,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_rule_priority,
+ conditions_json,
+ actions_json,
)
def _add_push_rule_highest_priority_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- conditions_json, actions_json
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
@@ -443,13 +482,28 @@ class PushRuleStore(PushRulesWorkerStore):
self._upsert_push_rule_txn(
txn,
- stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
- conditions_json, actions_json,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_prio,
+ conditions_json,
+ actions_json,
)
def _upsert_push_rule_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
- priority, conditions_json, actions_json, update_stream=True
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ conditions_json,
+ actions_json,
+ update_stream=True,
):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
@@ -461,10 +515,10 @@ class PushRuleStore(PushRulesWorkerStore):
" WHERE user_name = ? AND rule_id = ?"
)
- txn.execute(sql, (
- priority_class, priority, conditions_json, actions_json,
- user_id, rule_id,
- ))
+ txn.execute(
+ sql,
+ (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
+ )
if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one
@@ -486,14 +540,18 @@ class PushRuleStore(PushRulesWorkerStore):
if update_stream:
self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
op="ADD",
data={
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
- }
+ },
)
@defer.inlineCallbacks
@@ -507,22 +565,23 @@ class PushRuleStore(PushRulesWorkerStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
+
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self._simple_delete_one_txn(
- txn,
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
+ txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
)
self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id,
- op="DELETE"
+ txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
- "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
+ "delete_push_rule",
+ delete_push_rule_txn,
+ stream_id,
+ event_stream_ordering,
)
@defer.inlineCallbacks
@@ -532,7 +591,11 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
- stream_id, event_stream_ordering, user_id, rule_id, enabled
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
)
def _set_push_rule_enabled_txn(
@@ -548,8 +611,12 @@ class PushRuleStore(PushRulesWorkerStore):
)
self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id,
- op="ENABLE" if enabled else "DISABLE"
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ENABLE" if enabled else "DISABLE",
)
@defer.inlineCallbacks
@@ -563,9 +630,16 @@ class PushRuleStore(PushRulesWorkerStore):
priority_class = -1
priority = 1
self._upsert_push_rule_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id,
- priority_class, priority, "[]", actions_json,
- update_stream=False
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ "[]",
+ actions_json,
+ update_stream=False,
)
else:
self._simple_update_one_txn(
@@ -576,15 +650,22 @@ class PushRuleStore(PushRulesWorkerStore):
)
self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id,
- op="ACTIONS", data={"actions": actions_json}
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ACTIONS",
+ data={"actions": actions_json},
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
- "set_push_rule_actions", set_push_rule_actions_txn,
- stream_id, event_stream_ordering
+ "set_push_rule_actions",
+ set_push_rule_actions_txn,
+ stream_id,
+ event_stream_ordering,
)
def _insert_push_rules_update_txn(
@@ -602,12 +683,8 @@ class PushRuleStore(PushRulesWorkerStore):
self._simple_insert_txn(txn, "push_rules_stream", values=values)
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
- )
+ txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -627,6 +704,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
+
return self.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 134297e284..1567e1df48 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -47,7 +47,9 @@ class PusherWorkerStore(SQLBaseStore):
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.args[0],
+ r['id'],
+ dataJson,
+ e.args[0],
)
pass
@@ -64,20 +66,16 @@ class PusherWorkerStore(SQLBaseStore):
defer.returnValue(ret is not None)
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({
- "app_id": app_id,
- "pushkey": pushkey,
- })
+ return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({
- "user_name": user_id,
- })
+ return self.get_pushers_by({"user_name": user_id})
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self._simple_select_list(
- "pushers", keyvalues,
+ "pushers",
+ keyvalues,
[
"id",
"user_name",
@@ -94,7 +92,8 @@ class PusherWorkerStore(SQLBaseStore):
"last_stream_ordering",
"last_success",
"failing_since",
- ], desc="get_pushers_by"
+ ],
+ desc="get_pushers_by",
)
defer.returnValue(self._decode_pushers_rows(ret))
@@ -135,6 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
deleted = txn.fetchall()
return (updated, deleted)
+
return self.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@@ -177,6 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
results.sort() # Sort so that they're ordered by stream id
return results
+
return self.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@@ -186,15 +187,19 @@ class PusherWorkerStore(SQLBaseStore):
# This only exists for the cachedList decorator
raise NotImplementedError()
- @cachedList(cached_method_name="get_if_user_has_pusher",
- list_name="user_ids", num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="get_if_user_has_pusher",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch(
table='pushers',
column='user_name',
iterable=user_ids,
retcols=['user_name'],
- desc='get_if_users_have_pushers'
+ desc='get_if_users_have_pushers',
)
result = {user_id: False for user_id in user_ids}
@@ -208,20 +213,27 @@ class PusherStore(PusherWorkerStore):
return self._pushers_id_gen.get_current_token()
@defer.inlineCallbacks
- def add_pusher(self, user_id, access_token, kind, app_id,
- app_display_name, device_display_name,
- pushkey, pushkey_ts, lang, data, last_stream_ordering,
- profile_tag=""):
+ def add_pusher(
+ self,
+ user_id,
+ access_token,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ pushkey_ts,
+ lang,
+ data,
+ last_stream_ordering,
+ profile_tag="",
+ ):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
yield self._simple_upsert(
table="pushers",
- keyvalues={
- "app_id": app_id,
- "pushkey": pushkey,
- "user_name": user_id,
- },
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
"access_token": access_token,
"kind": kind,
@@ -247,7 +259,8 @@ class PusherStore(PusherWorkerStore):
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
- self.get_if_user_has_pusher, (user_id,)
+ self.get_if_user_has_pusher,
+ (user_id,),
)
@defer.inlineCallbacks
@@ -260,7 +273,7 @@ class PusherStore(PusherWorkerStore):
self._simple_delete_one_txn(
txn,
"pushers",
- {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
)
# it's possible for us to end up with duplicate rows for
@@ -278,13 +291,12 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.runInteraction(
- "delete_pusher", delete_pusher_txn, stream_id
- )
+ yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
@defer.inlineCallbacks
- def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id,
- last_stream_ordering):
+ def update_pusher_last_stream_ordering(
+ self, app_id, pushkey, user_id, last_stream_ordering
+ ):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@@ -293,23 +305,21 @@ class PusherStore(PusherWorkerStore):
)
@defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey,
- user_id,
- last_stream_ordering,
- last_success):
+ def update_pusher_last_stream_ordering_and_success(
+ self, app_id, pushkey, user_id, last_stream_ordering, last_success
+ ):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{
'last_stream_ordering': last_stream_ordering,
- 'last_success': last_success
+ 'last_success': last_success,
},
desc="update_pusher_last_stream_ordering_and_success",
)
@defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id,
- failing_since):
+ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@@ -323,14 +333,14 @@ class PusherStore(PusherWorkerStore):
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room"
+ desc="get_throttle_params_by_room",
)
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = {
"last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"]
+ "throttle_ms": row["throttle_ms"],
}
defer.returnValue(params_by_room)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 89a1f7e3d7..a1647e50a1 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
table="receipts_linearized",
- keyvalues={
- "room_id": room_id,
- "receipt_type": receipt_type,
- },
+ keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
desc="get_receipts_for_room",
)
@@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
- "user_id": user_id
+ "user_id": user_id,
},
retcol="event_id",
desc="get_own_receipt_for_user",
@@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self._simple_select_list(
table="receipts_linearized",
- keyvalues={
- "user_id": user_id,
- "receipt_type": receipt_type,
- },
+ keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
desc="get_receipts_for_user",
)
@@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.runInteraction(
- "get_receipts_for_user_with_orderings", f
+
+ rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+ defer.returnValue(
+ {
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
+ }
+ for row in rows
+ }
)
- defer.returnValue({
- row[0]: {
- "event_id": row[1],
- "topological_ordering": row[2],
- "stream_ordering": row[3],
- } for row in rows
- })
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""See get_linearized_receipts_for_room
"""
+
def f(txn):
if from_key:
sql = (
@@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore):
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
- txn.execute(
- sql,
- (room_id, from_key, to_key)
- )
+ txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
- txn.execute(
- sql,
- (room_id, to_key)
- )
+ txn.execute(sql, (room_id, to_key))
rows = self.cursor_to_dict(txn)
return rows
- rows = yield self.runInteraction(
- "get_linearized_receipts_for_room", f
- )
+ rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
defer.returnValue([])
content = {}
for row in rows:
- content.setdefault(
- row["event_id"], {}
- ).setdefault(
- row["receipt_type"], {}
- )[row["user_id"]] = json.loads(row["data"])
-
- defer.returnValue([{
- "type": "m.receipt",
- "room_id": room_id,
- "content": content,
- }])
-
- @cachedList(cached_method_name="_get_linearized_receipts_for_room",
- list_name="room_ids", num_args=3, inlineCallbacks=True)
+ content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
+ row["user_id"]
+ ] = json.loads(row["data"])
+
+ defer.returnValue(
+ [{"type": "m.receipt", "room_id": room_id, "content": content}]
+ )
+
+ @cachedList(
+ cached_method_name="_get_linearized_receipts_for_room",
+ list_name="room_ids",
+ num_args=3,
+ inlineCallbacks=True,
+ )
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
@@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
- ) % (
- ",".join(["?"] * len(room_ids))
- )
+ ) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.extend([from_key, to_key])
@@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
- ) % (
- ",".join(["?"] * len(room_ids))
- )
+ ) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.append(to_key)
@@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.cursor_to_dict(txn)
- txn_results = yield self.runInteraction(
- "_get_linearized_receipts_for_rooms", f
- )
+ txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
- room_event = results.setdefault(row["room_id"], {
- "type": "m.receipt",
- "room_id": row["room_id"],
- "content": {},
- })
+ room_event = results.setdefault(
+ row["room_id"],
+ {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ )
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
@@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return (
- r[0:5] + (json.loads(r[5]), ) for r in txn
- )
+ return (r[0:5] + (json.loads(r[5]),) for r in txn)
+
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
+ def _invalidate_get_users_with_receipts_in_room(
+ self, room_id, receipt_type, user_id
+ ):
if receipt_type != "m.read":
return
# Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get(
- room_id, None, update_metrics=False,
+ room_id, None, update_metrics=False
)
# first handle the Deferred case
@@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
- def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
- user_id, event_id, data, stream_id):
+ def insert_linearized_receipt_txn(
+ self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
+ ):
"""Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None
@@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
table="events",
retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id},
- allow_none=True
+ allow_none=True,
)
stream_ordering = int(res["stream_ordering"]) if res else None
@@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore):
logger.debug(
"Ignoring new receipt for %s in favour of existing "
"one for later event %s",
- event_id, eid,
+ event_id,
+ eid,
)
return None
- txn.call_after(
- self.get_receipts_for_room.invalidate, (room_id, receipt_type)
- )
+ txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
- room_id, receipt_type, user_id,
+ room_id,
+ receipt_type,
+ user_id,
)
+ txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+ # FIXME: This shouldn't invalidate the whole cache
txn.call_after(
- self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- # FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
- self._receipts_stream_cache.entity_has_changed,
- room_id, stream_id
+ self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
- (user_id, room_id, receipt_type)
+ (user_id, room_id, receipt_type),
)
self._simple_delete_txn(
@@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
- }
+ },
)
self._simple_insert_txn(
@@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
- }
+ },
)
if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
- txn,
- room_id=room_id,
- user_id=user_id,
- stream_ordering=stream_ordering,
+ txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
return rx_ts
@@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
event_ts = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
- room_id, receipt_type, user_id, linearized_event_id,
+ room_id,
+ receipt_type,
+ user_id,
+ linearized_event_id,
data,
stream_id=stream_id,
)
@@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore):
now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
- linearized_event_id, room_id, now - event_ts,
+ linearized_event_id,
+ room_id,
+ now - event_ts,
)
- yield self.insert_graph_receipt(
- room_id, receipt_type, user_id, event_ids, data
- )
+ yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))
- def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
- data):
+ def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
- room_id, receipt_type, user_id, event_ids, data
+ room_id,
+ receipt_type,
+ user_id,
+ event_ids,
+ data,
)
- def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
- user_id, event_ids, data):
- txn.call_after(
- self.get_receipts_for_room.invalidate, (room_id, receipt_type)
- )
+ def insert_graph_receipt_txn(
+ self, txn, room_id, receipt_type, user_id, event_ids, data
+ ):
+ txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
- room_id, receipt_type, user_id,
+ room_id,
+ receipt_type,
+ user_id,
)
+ txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+ # FIXME: This shouldn't invalidate the whole cache
txn.call_after(
- self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+ self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- # FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn(
txn,
@@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
- }
+ },
)
self._simple_insert_txn(
txn,
@@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
"event_ids": json.dumps(event_ids),
"data": json.dumps(data),
- }
+ },
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index eede8ae4d2..643f7a3808 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -37,13 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore):
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
- keyvalues={
- "name": user_id,
- },
+ keyvalues={"name": user_id},
retcols=[
- "name", "password_hash", "is_guest",
- "consent_version", "consent_server_notice_sent",
- "appservice_id", "creation_ts",
+ "name",
+ "password_hash",
+ "is_guest",
+ "consent_version",
+ "consent_server_notice_sent",
+ "appservice_id",
+ "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
@@ -81,9 +83,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
+ "get_user_by_access_token", self._query_for_auth, token
)
@cachedInlineCallbacks()
@@ -163,10 +163,10 @@ class RegistrationWorkerStore(SQLBaseStore):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
+
def f(txn):
sql = (
- "SELECT name, password_hash FROM users"
- " WHERE lower(name) = lower(?)"
+ "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
@@ -176,6 +176,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
+
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
@@ -193,6 +194,7 @@ class RegistrationWorkerStore(SQLBaseStore):
3) bridged users
who registered on the homeserver in the past 24 hours
"""
+
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
@@ -213,15 +215,18 @@ class RegistrationWorkerStore(SQLBaseStore):
for row in txn:
results[row[0]] = row[1]
return results
+
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
- txn.execute("""
+ txn.execute(
+ """
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
- """)
+ """
+ )
count, = txn.fetchone()
return count
@@ -240,6 +245,7 @@ class RegistrationWorkerStore(SQLBaseStore):
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
+
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
@@ -247,7 +253,7 @@ class RegistrationWorkerStore(SQLBaseStore):
found = set()
- for user_id, in txn:
+ for (user_id,) in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
@@ -255,20 +261,22 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
- defer.returnValue((yield self.runInteraction(
- "find_next_generated_user_id",
- _find_next_generated_user_id
- )))
+ defer.returnValue(
+ (
+ yield self.runInteraction(
+ "find_next_generated_user_id", _find_next_generated_user_id
+ )
+ )
+ )
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
- {
- "medium": medium,
- "address": address
- },
- ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ {"medium": medium, "address": address},
+ ["guest_access_token"],
+ True,
+ 'get_3pid_guest_access_token',
)
if ret:
defer.returnValue(ret["guest_access_token"])
@@ -286,8 +294,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
- "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
- medium, address
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
defer.returnValue(user_id)
@@ -305,11 +312,9 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = self._simple_select_one_txn(
txn,
"user_threepids",
- {
- "medium": medium,
- "address": address
- },
- ['user_id'], True
+ {"medium": medium, "address": address},
+ ['user_id'],
+ True,
)
if ret:
return ret['user_id']
@@ -317,41 +322,110 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert("user_threepids", {
- "medium": medium,
- "address": address,
- }, {
- "user_id": user_id,
- "validated_at": validated_at,
- "added_at": added_at,
- })
+ yield self._simple_upsert(
+ "user_threepids",
+ {"medium": medium, "address": address},
+ {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
+ )
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
- "user_threepids", {
- "user_id": user_id
- },
+ "user_threepids",
+ {"user_id": user_id},
['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids'
+ 'user_get_threepids',
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepids",
+ )
+
+ def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied a bind request to the given identity server on
+ behalf of the given user. We need to remember this in case the user
+ asks us to unbind the threepid.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ # We need to use an upsert, in case they user had already bound the
+ # threepid
+ return self._simple_upsert(
+ table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
+ "id_server": id_server,
},
- desc="user_delete_threepids",
+ values={},
+ insertion_values={},
+ desc="add_user_bound_threepid",
)
+ def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied an unbind request to the given identity server on
+ behalf of the given user, so we remove the mapping of threepid to
+ identity server.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ return self._simple_delete(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ "id_server": id_server,
+ },
+ desc="remove_user_bound_threepid",
+ )
+
+ def get_id_servers_user_bound(self, user_id, medium, address):
+ """Get the list of identity servers that the server proxied bind
+ requests to for given user and threepid
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+
+ Returns:
+ Deferred[list[str]]: Resolves to a list of identity servers
+ """
+ return self._simple_select_onecol(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ },
+ retcol="id_server",
+ desc="get_id_servers_user_bound",
+ )
-class RegistrationStore(RegistrationWorkerStore,
- background_updates.BackgroundUpdateStore):
+class RegistrationStore(
+ RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
@@ -378,6 +452,10 @@ class RegistrationStore(RegistrationWorkerStore,
# clear the background update.
self.register_noop_background_update("refresh_tokens_device_index")
+ self.register_background_update_handler(
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ )
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -394,18 +472,22 @@ class RegistrationStore(RegistrationWorkerStore,
yield self._simple_insert(
"access_tokens",
- {
- "id": next_id,
- "user_id": user_id,
- "token": token,
- "device_id": device_id,
- },
+ {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
desc="add_access_token_to_user",
)
- def register(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False, user_type=None):
+ def register(
+ self,
+ user_id,
+ token=None,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ ):
"""Attempts to register an account.
Args:
@@ -439,7 +521,7 @@ class RegistrationStore(RegistrationWorkerStore,
appservice_id,
create_profile_with_displayname,
admin,
- user_type
+ user_type,
)
def _register(
@@ -469,10 +551,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_select_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
retcols=("name",),
allow_none=False,
)
@@ -480,10 +559,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_update_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
@@ -491,7 +567,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
else:
self._simple_insert_txn(
@@ -505,7 +581,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
if self._account_validity.enabled:
@@ -520,17 +596,14 @@ class RegistrationStore(RegistrationWorkerStore,
}
)
except self.database_engine.module.IntegrityError:
- raise StoreError(
- 400, "User ID already taken.", errcode=Codes.USER_IN_USE
- )
+ raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)"
- " VALUES (?,?,?)",
- (next_id, user_id, token,)
+ "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
+ (next_id, user_id, token),
)
if create_profile_with_displayname:
@@ -541,12 +614,10 @@ class RegistrationStore(RegistrationWorkerStore,
# while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
- (user_id_obj.localpart, create_profile_with_displayname)
+ (user_id_obj.localpart, create_profile_with_displayname),
)
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def user_set_password_hash(self, user_id, password_hash):
@@ -555,22 +626,14 @@ class RegistrationStore(RegistrationWorkerStore,
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
+
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
- txn,
- 'users', {
- 'name': user_id
- },
- {
- 'password_hash': password_hash
- }
+ txn, 'users', {'name': user_id}, {'password_hash': password_hash}
)
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
- )
- return self.runInteraction(
- "user_set_password_hash", user_set_password_hash_txn
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -583,16 +646,16 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_version': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_version': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
@@ -607,20 +670,19 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_server_notice_sent': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_server_notice_sent': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None,
- device_id=None):
+ def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
Invalidate access tokens belonging to a user
@@ -635,10 +697,9 @@ class RegistrationStore(RegistrationWorkerStore,
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
"""
+
def f(txn):
- keyvalues = {
- "user_id": user_id,
- }
+ keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -650,8 +711,9 @@ class RegistrationStore(RegistrationWorkerStore,
values.append(except_token_id)
txn.execute(
- "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
- values
+ "SELECT token, id, device_id FROM access_tokens WHERE %s"
+ % where_clause,
+ values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
@@ -660,25 +722,16 @@ class RegistrationStore(RegistrationWorkerStore,
txn, self.get_user_by_access_token, (token,)
)
- txn.execute(
- "DELETE FROM access_tokens WHERE %s" % where_clause,
- values
- )
+ txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
return tokens_and_devices
- return self.runInteraction(
- "user_delete_access_tokens", f,
- )
+ return self.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self._simple_delete_one_txn(
- txn,
- table="access_tokens",
- keyvalues={
- "token": access_token
- },
+ txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
@@ -701,7 +754,7 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
+ self, medium, address, access_token, inviter_user_id
):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
@@ -717,12 +770,13 @@ class RegistrationStore(RegistrationWorkerStore,
deferred str: Whichever access token is persisted at the end
of this function call.
"""
+
def insert(txn):
txn.execute(
"INSERT INTO threepid_guest_access_tokens "
"(medium, address, guest_access_token, first_inviter) "
"VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id)
+ (medium, address, access_token, inviter_user_id),
)
try:
@@ -739,9 +793,7 @@ class RegistrationStore(RegistrationWorkerStore,
"""
return self._simple_insert(
"users_pending_deactivation",
- values={
- "user_id": user_id,
- },
+ values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
@@ -754,9 +806,7 @@ class RegistrationStore(RegistrationWorkerStore,
# the table, so somehow duplicate entries have ended up in it.
return self._simple_delete(
"users_pending_deactivation",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
@@ -772,3 +822,34 @@ class RegistrationStore(RegistrationWorkerStore,
allow_none=True,
desc="get_users_pending_deactivation",
)
+
+ @defer.inlineCallbacks
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+ )
+
+ yield self._end_background_update("user_threepids_grandfather")
+
+ defer.returnValue(1)
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 880f047adb..f4c1c2a457 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -36,9 +36,7 @@ class RejectionsStore(SQLBaseStore):
return self._simple_select_one_onecol(
table="rejections",
retcol="reason",
- keyvalues={
- "event_id": event_id,
- },
+ keyvalues={"event_id": event_id},
allow_none=True,
desc="get_rejection_reason",
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index a979d4860a..fe9d79d792 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
- "OpsLevel",
- ("ban_level", "kick_level", "redact_level",)
+ "OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
- "RatelimitOverride",
- ("messages_per_second", "burst_count",)
+ "RatelimitOverride", ("messages_per_second", "burst_count")
)
@@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
- keyvalues={
- "is_public": True,
- },
+ keyvalues={"is_public": True},
retcol="room_id",
desc="get_public_room_ids",
)
@@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
- stream_id, network_tuple=network_tuple
+ stream_id,
+ network_tuple=network_tuple,
)
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
- network_tuple):
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
@@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
if network_tuple:
# We want to get from a particular list. No aggregation required.
- sql = ("""
+ sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
@@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
- """)
+ """
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id),
)
else:
- txn.execute(
- sql % ("AND appservice_id IS NULL",),
- (stream_id,)
- )
+ txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
- sql = ("""
+ sql = """
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
@@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
- """)
+ """
- txn.execute(
- sql,
- (stream_id,)
- )
+ txn.execute(sql, (stream_id,))
results = {}
# A room is visible if its visible on any list.
@@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
return results
- def get_public_room_changes(self, prev_stream_id, new_stream_id,
- network_tuple):
+ def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
@@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, new_stream_id, network_tuple
)
- now_rooms_visible = set(
- rm for rm, vis in now_rooms_dict.items() if vis
- )
+ now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
@@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
def is_room_blocked(self, room_id):
return self._simple_select_one_onecol(
table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
+ keyvalues={"room_id": room_id},
retcol="1",
allow_none=True,
desc="is_room_blocked",
@@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
)
if row:
- defer.returnValue(RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- ))
+ defer.returnValue(
+ RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ )
+ )
else:
defer.returnValue(None)
class RoomStore(RoomWorkerStore, SearchStore):
-
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
StoreError if the room could not be stored.
"""
try:
+
def store_room_txn(txn, next_id):
self._simple_insert_txn(
txn,
@@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
- }
+ },
)
+
with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "store_room_txn",
- store_room_txn, next_id,
- )
+ yield self.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": None,
"network_id": None,
- }
+ },
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
- "set_room_is_public",
- set_room_is_public_txn, next_id,
+ "set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
- def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
- is_public):
+ def set_room_is_public_appservice(
+ self, room_id, appservice_id, network_id, is_public
+ ):
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
is_public (bool): Whether to publish or unpublish the room from the
list.
"""
+
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
@@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
values={
"appservice_id": appservice_id,
"network_id": network_id,
- "room_id": room_id
+ "room_id": room_id,
},
)
except self.database_engine.module.IntegrityError:
@@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
keyvalues={
"appservice_id": appservice_id,
"network_id": network_id,
- "room_id": room_id
+ "room_id": room_id,
},
)
@@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": appservice_id,
"network_id": network_id,
- }
+ },
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public_appservice",
- set_room_is_public_appservice_txn, next_id,
+ set_room_is_public_appservice_txn,
+ next_id,
)
self.hs.get_notifier().on_new_replication_data()
@@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.runInteraction(
- "get_rooms", f
- )
+ return self.runInteraction("get_rooms", f)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
@@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
)
self.store_event_search_txn(
- txn, event, "content.topic", event.content["topic"],
+ txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event):
@@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
"event_id": event.event_id,
"room_id": event.room_id,
"name": event.content["name"],
- }
+ },
)
self.store_event_search_txn(
- txn, event, "content.name", event.content["name"],
+ txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self.store_event_search_txn(
- txn, event, "content.body", event.content["body"],
+ txn, event, "content.body", event.content["body"]
)
def _store_history_visibility_txn(self, txn, event):
@@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
" (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key}
)
- txn.execute(sql, (
- event.event_id,
- event.room_id,
- event.content[key]
- ))
-
- def add_event_report(self, room_id, event_id, user_id, reason, content,
- received_ts):
+ txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
+
+ def add_event_report(
+ self, room_id, event_id, user_id, reason, content, received_ts
+ ):
next_id = self._event_reports_id_gen.get_next()
return self._simple_insert(
table="event_reports",
@@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
"reason": reason,
"content": json.dumps(content),
},
- desc="add_event_report"
+ desc="add_event_report",
)
def get_current_public_room_stream_id(self):
@@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
- sql = ("""
+ sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
- """)
+ """
- txn.execute(sql, (prev_id, current_id, limit,))
+ txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
- return self.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
+ return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
@@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
"""
yield self._simple_upsert(
table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
+ keyvalues={"room_id": room_id},
values={},
- insertion_values={
- "user_id": user_id,
- },
+ insertion_values={"user_id": user_id},
desc="block_room",
)
yield self.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
- self.is_room_blocked, (room_id,),
+ self.is_room_blocked,
+ (room_id,),
)
def get_media_mxcs_in_room(self, room_id):
@@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
+
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
@@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
+
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
+
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
- txn.executemany("""
+ txn.executemany(
+ """
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_mxcs))
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
txn.executemany(
"""
@@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
- )
+ ),
)
total_media_quarantined += len(local_mxcs)
@@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
return total_media_quarantined
return self.runInteraction(
- "quarantine_media_in_room",
- _quarantine_media_in_room_txn,
+ "quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 592c1bcd33..57df17bcc2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,28 +35,22 @@ logger = logging.getLogger(__name__)
RoomsForUser = namedtuple(
- "RoomsForUser",
- ("room_id", "sender", "membership", "event_id", "stream_ordering")
+ "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering",
- ("room_id", "stream_ordering",)
+ "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
)
# We store this using a namedtuple so that we save about 3x space over using a
# dict.
-ProfileInfo = namedtuple(
- "ProfileInfo", ("avatar_url", "display_name")
-)
+ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
# "members" points to a truncated list of (user_id, event_id) tuples for users of
# a given membership type, suitable for use in calculating heroes for a room.
# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple(
- "MemberSummary", ("members", "count")
-)
+MemberSummary = namedtuple("MemberSummary", ("members", "count"))
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
@@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of all hosts currently in the room
"""
user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
@@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
- txn.execute(sql, (room_id, Membership.JOIN,))
+ txn.execute(sql, (room_id, Membership.JOIN))
return [to_ascii(r[0]) for r in txn]
+
return self.runInteraction("get_users_in_room", f)
@cached(max_entries=100000)
@@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A deferred list of RoomsForUser.
"""
- return self.get_rooms_for_user_where_membership_is(
- user_id, [Membership.INVITE]
- )
+ return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
@@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
- user_id, membership_list
+ user_id,
+ membership_list,
)
- def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
- membership_list):
+ def _get_rooms_for_user_where_membership_is_txn(
+ self, txn, user_id, membership_list
+ ):
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
@@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) % (where_clause,)
txn.execute(sql, args)
- results = [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite:
sql = (
@@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id,))
- results.extend(RoomsForUser(
- room_id=r["room_id"],
- sender=r["inviter"],
- event_id=r["event_id"],
- stream_ordering=r["stream_ordering"],
- membership=Membership.INVITE,
- ) for r in self.cursor_to_dict(txn))
+ results.extend(
+ RoomsForUser(
+ room_id=r["room_id"],
+ sender=r["inviter"],
+ event_id=r["event_id"],
+ stream_ordering=r["stream_ordering"],
+ membership=Membership.INVITE,
+ )
+ for r in self.cursor_to_dict(txn)
+ )
return results
@@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
- user_id, membership_list=[Membership.JOIN],
+ user_id, membership_list=[Membership.JOIN]
+ )
+ defer.returnValue(
+ frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
+ )
)
- defer.returnValue(frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
- ))
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
- user_id, on_invalidate=on_invalidate,
+ user_id, on_invalidate=on_invalidate
)
defer.returnValue(frozenset(r.room_id for r in rooms))
@@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`
"""
room_ids = yield self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate,
+ user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
@@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
current_state_ids = yield context.get_current_state_ids(self)
result = yield self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids,
- event=event,
- context=context,
+ event.room_id, state_group, current_state_ids, event=event, context=context
)
defer.returnValue(result)
@@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
return self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry,
+ room_id, state_group, state_entry.state, context=state_entry
)
- @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
- max_entries=100000)
- def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
- cache_context, event=None, context=None):
+ @cachedInlineCallbacks(
+ num_args=2, cache_context=True, iterable=True, max_entries=100000
+ )
+ def _get_joined_users_from_context(
+ self,
+ room_id,
+ state_group,
+ current_state_ids,
+ cache_context,
+ event=None,
+ context=None,
+ ):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
@@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
event_map = self._get_events_from_cache(
- member_event_ids,
- allow_rejected=False,
- update_metrics=False,
+ member_event_ids, allow_rejected=False, update_metrics=False
)
missing_member_event_ids = []
@@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
table="room_memberships",
column="event_id",
iterable=missing_member_event_ids,
- retcols=('user_id', 'display_name', 'avatar_url',),
- keyvalues={
- "membership": Membership.JOIN,
- },
+ retcols=('user_id', 'display_name', 'avatar_url'),
+ keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
)
- users_in_room.update({
- to_ascii(row["user_id"]): ProfileInfo(
- avatar_url=to_ascii(row["avatar_url"]),
- display_name=to_ascii(row["display_name"]),
- )
- for row in rows
- })
+ users_in_room.update(
+ {
+ to_ascii(row["user_id"]): ProfileInfo(
+ avatar_url=to_ascii(row["avatar_url"]),
+ display_name=to_ascii(row["display_name"]),
+ )
+ for row in rows
+ }
+ )
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
return self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry,
+ room_id, state_group, state_entry.state, state_entry=state_entry
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
@@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
+
def f(txn):
sql = (
"SELECT"
@@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id, room_id))
rows = txn.fetchall()
return rows[0][0]
+
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
- ]
+ ],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
- event.state_key, event.internal_metadata.stream_ordering
+ event.state_key,
+ event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
@@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
- }
+ },
)
else:
sql = (
@@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore):
" AND replaced_by is NULL"
)
- txn.execute(sql, (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ))
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
@@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
def f(txn, stream_ordering):
- txn.execute(sql, (
- stream_ordering,
- True,
- room_id,
- user_id,
- ))
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
+
def f(txn):
sql = (
"UPDATE"
@@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
txn.execute(sql, (user_id, room_id))
- self._invalidate_cache_and_stream(
- txn, self.did_forget, (user_id, room_id,),
- )
+ self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks
@@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn):
- sql = ("""
+ sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
INNER JOIN event_json USING (event_id)
@@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
AND type = 'm.room.member'
ORDER BY stream_ordering DESC
LIMIT ?
- """)
+ """
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
@@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
avatar_url = content.get("avatar_url", None)
if display_name or avatar_url:
- to_update.append((
- display_name, avatar_url, event_id, room_id
- ))
+ to_update.append((display_name, avatar_url, event_id, room_id))
- to_update_sql = ("""
+ to_update_sql = """
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
- """)
+ """
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index:index + INSERT_CLUMP_SIZE]
+ clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
progress = {
@@ -789,7 +790,7 @@ class _JoinedHostsCache(object):
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
- self.room_id, state_entry,
+ self.room_id, state_entry
)
self.hosts_to_joined_users = {}
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/schema/delta/13/v13.sql
index 5eb93b38b2..f8649e5d99 100644
--- a/synapse/storage/schema/delta/13/v13.sql
+++ b/synapse/storage/schema/delta/13/v13.sql
@@ -13,19 +13,7 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS application_services(
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- url TEXT,
- token TEXT,
- hs_token TEXT,
- sender TEXT,
- UNIQUE(token)
-);
-
-CREATE TABLE IF NOT EXISTS application_services_regex(
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- as_id BIGINT UNSIGNED NOT NULL,
- namespace INTEGER, /* enum[room_id|room_alias|user_id] */
- regex TEXT,
- FOREIGN KEY(as_id) REFERENCES application_services(id)
-);
+/* We used to create a tables called application_services and
+ * application_services_regex, but these are no longer used and are removed in
+ * delta 54.
+ */
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
deleted file mode 100644
index 4d725b92fe..0000000000
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-import simplejson as json
-
-logger = logging.getLogger(__name__)
-
-
-def run_create(cur, *args, **kwargs):
- cur.execute("SELECT id, regex FROM application_services_regex")
- for row in cur.fetchall():
- try:
- logger.debug("Checking %s..." % row[0])
- json.loads(row[1])
- except ValueError:
- # row isn't in json, make it so.
- string_regex = row[1]
- new_regex = json.dumps({
- "regex": string_regex,
- "exclusive": True
- })
- cur.execute(
- "UPDATE application_services_regex SET regex=? WHERE id=?",
- (new_regex, row[0])
- )
-
-
-def run_upgrade(*args, **kwargs):
- pass
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/schema/delta/16/unique_constraints.sql
index fecf11118c..5b8de52c33 100644
--- a/synapse/storage/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/schema/delta/16/unique_constraints.sql
@@ -18,14 +18,6 @@ DROP INDEX IF EXISTS room_memberships_event_id;
CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id);
--
-DELETE FROM feedback WHERE rowid not in (
- SELECT MIN(rowid) FROM feedback GROUP BY event_id
-);
-
-DROP INDEX IF EXISTS feedback_event_id;
-CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id);
-
---
DELETE FROM topics WHERE rowid not in (
SELECT MIN(rowid) FROM topics GROUP BY event_id
);
diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql
index 5f508af7a9..acea7483bd 100644
--- a/synapse/storage/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/schema/delta/24/stats_reporting.sql
@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,10 +13,6 @@
* limitations under the License.
*/
--- Should only ever contain one row
-CREATE TABLE IF NOT EXISTS stats_reporting(
- -- The stream ordering token which was most recently reported as stats
- reported_stream_token INTEGER,
- -- The time (seconds since epoch) stats were most recently reported
- reported_time BIGINT
-);
+ /* We used to create a table called stats_reporting, but this is no longer
+ * used and is removed in delta 54.
+ */
\ No newline at end of file
diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql
index 706fe1dcf4..e85699e82e 100644
--- a/synapse/storage/schema/delta/30/state_stream.sql
+++ b/synapse/storage/schema/delta/30/state_stream.sql
@@ -14,15 +14,10 @@
*/
-/**
- * The positions in the event stream_ordering when the current_state was
- * replaced by the state at the event.
+/* We used to create a table called current_state_resets, but this is no
+ * longer used and is removed in delta 54.
*/
-CREATE TABLE IF NOT EXISTS current_state_resets(
- event_stream_ordering BIGINT PRIMARY KEY NOT NULL
-);
-
/* The outlier events that have aquired a state group typically through
* backfill. This is tracked separately to the events table, as assigning a
* state group change the position of the existing event in the stream
diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/schema/delta/32/remove_indices.sql
index f859be46a6..4219cdd06a 100644
--- a/synapse/storage/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/schema/delta/32/remove_indices.sql
@@ -24,13 +24,9 @@ DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
-DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT
-DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT
-DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT
-DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT
-- The following indices were unused
DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id;
diff --git a/synapse/storage/schema/full_schemas/11/room_aliases.sql b/synapse/storage/schema/delta/53/user_threepid_id.sql
index 71a91f8ec9..80c2c573b6 100644
--- a/synapse/storage/schema/full_schemas/11/room_aliases.sql
+++ b/synapse/storage/schema/delta/53/user_threepid_id.sql
@@ -1,4 +1,4 @@
-/* Copyright 2014-2016 OpenMarket Ltd
+/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,12 +13,17 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS room_aliases(
- room_alias TEXT NOT NULL,
- room_id TEXT NOT NULL
+-- Tracks which identity server a user bound their threepid via.
+CREATE TABLE user_threepid_id_server (
+ user_id TEXT NOT NULL,
+ medium TEXT NOT NULL,
+ address TEXT NOT NULL,
+ id_server TEXT NOT NULL
);
-CREATE TABLE IF NOT EXISTS room_alias_servers(
- room_alias TEXT NOT NULL,
- server TEXT NOT NULL
+CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server(
+ user_id, medium, address, id_server
);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('user_threepids_grandfather', '{}');
diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/schema/delta/54/drop_legacy_tables.sql
new file mode 100644
index 0000000000..ecca005d9b
--- /dev/null
+++ b/synapse/storage/schema/delta/54/drop_legacy_tables.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP TABLE IF EXISTS application_services;
+DROP TABLE IF EXISTS application_services_regex;
+DROP TABLE IF EXISTS transaction_id_to_pdu;
+DROP TABLE IF EXISTS stats_reporting;
+DROP TABLE IF EXISTS current_state_resets;
+DROP TABLE IF EXISTS event_content_hashes;
+DROP TABLE IF EXISTS event_destinations;
+DROP TABLE IF EXISTS event_edge_hashes;
+DROP TABLE IF EXISTS event_signatures;
+DROP TABLE IF EXISTS feedback;
+DROP TABLE IF EXISTS room_hosts;
+DROP TABLE IF EXISTS server_tls_certificates;
+DROP TABLE IF EXISTS state_forward_extremities;
diff --git a/synapse/storage/schema/full_schemas/11/profiles.sql b/synapse/storage/schema/delta/54/drop_presence_list.sql
index b314e6df75..e6ee70c623 100644
--- a/synapse/storage/schema/full_schemas/11/profiles.sql
+++ b/synapse/storage/schema/delta/54/drop_presence_list.sql
@@ -1,4 +1,4 @@
-/* Copyright 2014-2016 OpenMarket Ltd
+/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -12,8 +12,5 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS profiles(
- user_id TEXT NOT NULL,
- displayname TEXT,
- avatar_url TEXT
-);
+
+DROP TABLE IF EXISTS presence_list;
diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql
deleted file mode 100644
index bccd1c6f74..0000000000
--- a/synapse/storage/schema/full_schemas/11/event_edges.sql
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS event_forward_extremities(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- UNIQUE (event_id, room_id)
-);
-
-CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
-CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_backward_extremities(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- UNIQUE (event_id, room_id)
-);
-
-CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
-CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_edges(
- event_id TEXT NOT NULL,
- prev_event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- -- We no longer insert prev_state into this table, so all new rows will have
- -- is_state as false.
- is_state BOOL NOT NULL,
- UNIQUE (event_id, prev_event_id, room_id, is_state)
-);
-
-CREATE INDEX ev_edges_id ON event_edges(event_id);
-CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
-
-
-CREATE TABLE IF NOT EXISTS room_depth(
- room_id TEXT NOT NULL,
- min_depth INTEGER NOT NULL,
- UNIQUE (room_id)
-);
-
-CREATE INDEX room_depth_room ON room_depth(room_id);
-
-
-create TABLE IF NOT EXISTS event_destinations(
- event_id TEXT NOT NULL,
- destination TEXT NOT NULL,
- delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
- UNIQUE (event_id, destination)
-);
-
-CREATE INDEX event_destinations_id ON event_destinations(event_id);
-
-
-CREATE TABLE IF NOT EXISTS state_forward_extremities(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- UNIQUE (event_id, room_id)
-);
-
-CREATE INDEX st_extrem_keys ON state_forward_extremities(
- room_id, type, state_key
-);
-CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_auth(
- event_id TEXT NOT NULL,
- auth_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- UNIQUE (event_id, auth_id, room_id)
-);
-
-CREATE INDEX evauth_edges_id ON event_auth(event_id);
-CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);
diff --git a/synapse/storage/schema/full_schemas/11/event_signatures.sql b/synapse/storage/schema/full_schemas/11/event_signatures.sql
deleted file mode 100644
index 00ce85980e..0000000000
--- a/synapse/storage/schema/full_schemas/11/event_signatures.sql
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS event_content_hashes (
- event_id TEXT,
- algorithm TEXT,
- hash bytea,
- UNIQUE (event_id, algorithm)
-);
-
-CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_reference_hashes (
- event_id TEXT,
- algorithm TEXT,
- hash bytea,
- UNIQUE (event_id, algorithm)
-);
-
-CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_signatures (
- event_id TEXT,
- signature_name TEXT,
- key_id TEXT,
- signature bytea,
- UNIQUE (event_id, signature_name, key_id)
-);
-
-CREATE INDEX event_signatures_id ON event_signatures(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_edge_hashes(
- event_id TEXT,
- prev_event_id TEXT,
- algorithm TEXT,
- hash bytea,
- UNIQUE (event_id, prev_event_id, algorithm)
-);
-
-CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);
diff --git a/synapse/storage/schema/full_schemas/11/im.sql b/synapse/storage/schema/full_schemas/11/im.sql
deleted file mode 100644
index dfbbf9fd54..0000000000
--- a/synapse/storage/schema/full_schemas/11/im.sql
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS events(
- stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT,
- topological_ordering BIGINT NOT NULL,
- event_id TEXT NOT NULL,
- type TEXT NOT NULL,
- room_id TEXT NOT NULL,
- content TEXT NOT NULL,
- unrecognized_keys TEXT,
- processed BOOL NOT NULL,
- outlier BOOL NOT NULL,
- depth BIGINT DEFAULT 0 NOT NULL,
- UNIQUE (event_id)
-);
-
-CREATE INDEX events_stream_ordering ON events (stream_ordering);
-CREATE INDEX events_topological_ordering ON events (topological_ordering);
-CREATE INDEX events_room_id ON events (room_id);
-
-
-CREATE TABLE IF NOT EXISTS event_json(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- internal_metadata TEXT NOT NULL,
- json TEXT NOT NULL,
- UNIQUE (event_id)
-);
-
-CREATE INDEX event_json_room_id ON event_json(room_id);
-
-
-CREATE TABLE IF NOT EXISTS state_events(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- prev_state TEXT,
- UNIQUE (event_id)
-);
-
-CREATE INDEX state_events_room_id ON state_events (room_id);
-CREATE INDEX state_events_type ON state_events (type);
-CREATE INDEX state_events_state_key ON state_events (state_key);
-
-
-CREATE TABLE IF NOT EXISTS current_state_events(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- UNIQUE (room_id, type, state_key)
-);
-
-CREATE INDEX curr_events_event_id ON current_state_events (event_id);
-CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
-CREATE INDEX current_state_events_type ON current_state_events (type);
-CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
-
-CREATE TABLE IF NOT EXISTS room_memberships(
- event_id TEXT NOT NULL,
- user_id TEXT NOT NULL,
- sender TEXT NOT NULL,
- room_id TEXT NOT NULL,
- membership TEXT NOT NULL
-);
-
-CREATE INDEX room_memberships_event_id ON room_memberships (event_id);
-CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
-CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
-
-CREATE TABLE IF NOT EXISTS feedback(
- event_id TEXT NOT NULL,
- feedback_type TEXT,
- target_event_id TEXT,
- sender TEXT,
- room_id TEXT
-);
-
-CREATE TABLE IF NOT EXISTS topics(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- topic TEXT NOT NULL
-);
-
-CREATE INDEX topics_event_id ON topics(event_id);
-CREATE INDEX topics_room_id ON topics(room_id);
-
-CREATE TABLE IF NOT EXISTS room_names(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- name TEXT NOT NULL
-);
-
-CREATE INDEX room_names_event_id ON room_names(event_id);
-CREATE INDEX room_names_room_id ON room_names(room_id);
-
-CREATE TABLE IF NOT EXISTS rooms(
- room_id TEXT PRIMARY KEY NOT NULL,
- is_public BOOL,
- creator TEXT
-);
-
-CREATE TABLE IF NOT EXISTS room_hosts(
- room_id TEXT NOT NULL,
- host TEXT NOT NULL,
- UNIQUE (room_id, host)
-);
-
-CREATE INDEX room_hosts_room_id ON room_hosts (room_id);
diff --git a/synapse/storage/schema/full_schemas/11/keys.sql b/synapse/storage/schema/full_schemas/11/keys.sql
deleted file mode 100644
index ca0ca1b694..0000000000
--- a/synapse/storage/schema/full_schemas/11/keys.sql
+++ /dev/null
@@ -1,31 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS server_tls_certificates(
- server_name TEXT, -- Server name.
- fingerprint TEXT, -- Certificate fingerprint.
- from_server TEXT, -- Which key server the certificate was fetched from.
- ts_added_ms BIGINT, -- When the certifcate was added.
- tls_certificate bytea, -- DER encoded x509 certificate.
- UNIQUE (server_name, fingerprint)
-);
-
-CREATE TABLE IF NOT EXISTS server_signature_keys(
- server_name TEXT, -- Server name.
- key_id TEXT, -- Key version.
- from_server TEXT, -- Which key server the key was fetched form.
- ts_added_ms BIGINT, -- When the key was added.
- verify_key bytea, -- NACL verification key.
- UNIQUE (server_name, key_id)
-);
diff --git a/synapse/storage/schema/full_schemas/11/media_repository.sql b/synapse/storage/schema/full_schemas/11/media_repository.sql
deleted file mode 100644
index 9c264d6ece..0000000000
--- a/synapse/storage/schema/full_schemas/11/media_repository.sql
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS local_media_repository (
- media_id TEXT, -- The id used to refer to the media.
- media_type TEXT, -- The MIME-type of the media.
- media_length INTEGER, -- Length of the media in bytes.
- created_ts BIGINT, -- When the content was uploaded in ms.
- upload_name TEXT, -- The name the media was uploaded with.
- user_id TEXT, -- The user who uploaded the file.
- UNIQUE (media_id)
-);
-
-CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
- media_id TEXT, -- The id used to refer to the media.
- thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
- thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
- thumbnail_type TEXT, -- The MIME-type of the thumbnail.
- thumbnail_method TEXT, -- The method used to make the thumbnail.
- thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
- UNIQUE (
- media_id, thumbnail_width, thumbnail_height, thumbnail_type
- )
-);
-
-CREATE INDEX local_media_repository_thumbnails_media_id
- ON local_media_repository_thumbnails (media_id);
-
-CREATE TABLE IF NOT EXISTS remote_media_cache (
- media_origin TEXT, -- The remote HS the media came from.
- media_id TEXT, -- The id used to refer to the media on that server.
- media_type TEXT, -- The MIME-type of the media.
- created_ts BIGINT, -- When the content was uploaded in ms.
- upload_name TEXT, -- The name the media was uploaded with.
- media_length INTEGER, -- Length of the media in bytes.
- filesystem_id TEXT, -- The name used to store the media on disk.
- UNIQUE (media_origin, media_id)
-);
-
-CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
- media_origin TEXT, -- The remote HS the media came from.
- media_id TEXT, -- The id used to refer to the media.
- thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
- thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
- thumbnail_method TEXT, -- The method used to make the thumbnail
- thumbnail_type TEXT, -- The MIME-type of the thumbnail.
- thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
- filesystem_id TEXT, -- The name used to store the media on disk.
- UNIQUE (
- media_origin, media_id, thumbnail_width, thumbnail_height,
- thumbnail_type
- )
-);
diff --git a/synapse/storage/schema/full_schemas/11/presence.sql b/synapse/storage/schema/full_schemas/11/presence.sql
deleted file mode 100644
index 492725994c..0000000000
--- a/synapse/storage/schema/full_schemas/11/presence.sql
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS presence(
- user_id TEXT NOT NULL,
- state VARCHAR(20),
- status_msg TEXT,
- mtime BIGINT -- miliseconds since last state change
-);
-
--- For each of /my/ users which possibly-remote users are allowed to see their
--- presence state
-CREATE TABLE IF NOT EXISTS presence_allow_inbound(
- observed_user_id TEXT NOT NULL,
- observer_user_id TEXT NOT NULL -- a UserID,
-);
-
--- For each of /my/ users (watcher), which possibly-remote users are they
--- watching?
-CREATE TABLE IF NOT EXISTS presence_list(
- user_id TEXT NOT NULL,
- observed_user_id TEXT NOT NULL, -- a UserID,
- accepted BOOLEAN NOT NULL
-);
diff --git a/synapse/storage/schema/full_schemas/11/redactions.sql b/synapse/storage/schema/full_schemas/11/redactions.sql
deleted file mode 100644
index 318f0d9aa5..0000000000
--- a/synapse/storage/schema/full_schemas/11/redactions.sql
+++ /dev/null
@@ -1,22 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS redactions (
- event_id TEXT NOT NULL,
- redacts TEXT NOT NULL,
- UNIQUE (event_id)
-);
-
-CREATE INDEX redactions_event_id ON redactions (event_id);
-CREATE INDEX redactions_redacts ON redactions (redacts);
diff --git a/synapse/storage/schema/full_schemas/11/state.sql b/synapse/storage/schema/full_schemas/11/state.sql
deleted file mode 100644
index b901e0f017..0000000000
--- a/synapse/storage/schema/full_schemas/11/state.sql
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS state_groups(
- id INTEGER PRIMARY KEY,
- room_id TEXT NOT NULL,
- event_id TEXT NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS state_groups_state(
- state_group INTEGER NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- event_id TEXT NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS event_to_state_groups(
- event_id TEXT NOT NULL,
- state_group INTEGER NOT NULL,
- UNIQUE (event_id)
-);
-
-CREATE INDEX state_groups_id ON state_groups(id);
-
-CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
-CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
-CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);
diff --git a/synapse/storage/schema/full_schemas/11/transactions.sql b/synapse/storage/schema/full_schemas/11/transactions.sql
deleted file mode 100644
index f6a058832e..0000000000
--- a/synapse/storage/schema/full_schemas/11/transactions.sql
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
--- Stores what transaction ids we have received and what our response was
-CREATE TABLE IF NOT EXISTS received_transactions(
- transaction_id TEXT,
- origin TEXT,
- ts BIGINT,
- response_code INTEGER,
- response_json bytea,
- has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx
- UNIQUE (transaction_id, origin)
-);
-
-CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-
--- For sent transactions only.
-CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
- transaction_id INTEGER,
- destination TEXT,
- pdu_id TEXT,
- pdu_origin TEXT
-);
-
-CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
-CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-
--- To track destination health
-CREATE TABLE IF NOT EXISTS destinations(
- destination TEXT PRIMARY KEY,
- retry_last_ts BIGINT,
- retry_interval INTEGER
-);
diff --git a/synapse/storage/schema/full_schemas/11/users.sql b/synapse/storage/schema/full_schemas/11/users.sql
deleted file mode 100644
index 6c1d4c34a1..0000000000
--- a/synapse/storage/schema/full_schemas/11/users.sql
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2014-2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS users(
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT,
- password_hash TEXT,
- creation_ts BIGINT,
- admin SMALLINT DEFAULT 0 NOT NULL,
- UNIQUE(name)
-);
-
-CREATE TABLE IF NOT EXISTS access_tokens(
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id TEXT NOT NULL,
- device_id TEXT,
- token TEXT NOT NULL,
- last_used BIGINT,
- UNIQUE(token)
-);
-
-CREATE TABLE IF NOT EXISTS user_ips (
- user TEXT NOT NULL,
- access_token TEXT NOT NULL,
- device_id TEXT,
- ip TEXT NOT NULL,
- user_agent TEXT NOT NULL,
- last_seen BIGINT NOT NULL,
- UNIQUE (user, access_token, ip, user_agent)
-);
-
-CREATE INDEX user_ips_user ON user_ips(user);
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/schema/full_schemas/16/application_services.sql
index aee0e68473..883fcd10b2 100644
--- a/synapse/storage/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/schema/full_schemas/16/application_services.sql
@@ -13,22 +13,11 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS application_services(
- id BIGINT PRIMARY KEY,
- url TEXT,
- token TEXT,
- hs_token TEXT,
- sender TEXT,
- UNIQUE(token)
-);
+/* We used to create tables called application_services and
+ * application_services_regex, but these are no longer used and are removed in
+ * delta 54.
+ */
-CREATE TABLE IF NOT EXISTS application_services_regex(
- id BIGINT PRIMARY KEY,
- as_id BIGINT NOT NULL,
- namespace INTEGER, /* enum[room_id|room_alias|user_id] */
- regex TEXT,
- FOREIGN KEY(as_id) REFERENCES application_services(id)
-);
CREATE TABLE IF NOT EXISTS application_services_state(
as_id TEXT PRIMARY KEY,
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql
index 6b5a5a88fa..10ce2aa7a0 100644
--- a/synapse/storage/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/schema/full_schemas/16/event_edges.sql
@@ -13,6 +13,11 @@
* limitations under the License.
*/
+/* We used to create tables called event_destinations and
+ * state_forward_extremities, but these are no longer used and are removed in
+ * delta 54.
+ */
+
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
@@ -54,31 +59,6 @@ CREATE TABLE IF NOT EXISTS room_depth(
CREATE INDEX room_depth_room ON room_depth(room_id);
-
-create TABLE IF NOT EXISTS event_destinations(
- event_id TEXT NOT NULL,
- destination TEXT NOT NULL,
- delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
- UNIQUE (event_id, destination)
-);
-
-CREATE INDEX event_destinations_id ON event_destinations(event_id);
-
-
-CREATE TABLE IF NOT EXISTS state_forward_extremities(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- UNIQUE (event_id, room_id)
-);
-
-CREATE INDEX st_extrem_keys ON state_forward_extremities(
- room_id, type, state_key
-);
-CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
-
-
CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL,
auth_id TEXT NOT NULL,
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/schema/full_schemas/16/event_signatures.sql
index 00ce85980e..95826da431 100644
--- a/synapse/storage/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/schema/full_schemas/16/event_signatures.sql
@@ -13,15 +13,9 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS event_content_hashes (
- event_id TEXT,
- algorithm TEXT,
- hash bytea,
- UNIQUE (event_id, algorithm)
-);
-
-CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
-
+ /* We used to create tables called event_content_hashes and event_edge_hashes,
+ * but these are no longer used and are removed in delta 54.
+ */
CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT,
@@ -42,14 +36,3 @@ CREATE TABLE IF NOT EXISTS event_signatures (
);
CREATE INDEX event_signatures_id ON event_signatures(event_id);
-
-
-CREATE TABLE IF NOT EXISTS event_edge_hashes(
- event_id TEXT,
- prev_event_id TEXT,
- algorithm TEXT,
- hash bytea,
- UNIQUE (event_id, prev_event_id, algorithm)
-);
-
-CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql
index 5f5cb8d01d..a1a2aa8e5b 100644
--- a/synapse/storage/schema/full_schemas/16/im.sql
+++ b/synapse/storage/schema/full_schemas/16/im.sql
@@ -13,6 +13,10 @@
* limitations under the License.
*/
+/* We used to create tables called room_hosts and feedback,
+ * but these are no longer used and are removed in delta 54.
+ */
+
CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY,
topological_ordering BIGINT NOT NULL,
@@ -91,15 +95,6 @@ CREATE TABLE IF NOT EXISTS room_memberships(
CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
-CREATE TABLE IF NOT EXISTS feedback(
- event_id TEXT NOT NULL,
- feedback_type TEXT,
- target_event_id TEXT,
- sender TEXT,
- room_id TEXT,
- UNIQUE (event_id)
-);
-
CREATE TABLE IF NOT EXISTS topics(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
@@ -123,11 +118,3 @@ CREATE TABLE IF NOT EXISTS rooms(
is_public BOOL,
creator TEXT
);
-
-CREATE TABLE IF NOT EXISTS room_hosts(
- room_id TEXT NOT NULL,
- host TEXT NOT NULL,
- UNIQUE (room_id, host)
-);
-
-CREATE INDEX room_hosts_room_id ON room_hosts (room_id);
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/schema/full_schemas/16/keys.sql
index ca0ca1b694..11cdffdbb3 100644
--- a/synapse/storage/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/schema/full_schemas/16/keys.sql
@@ -12,14 +12,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS server_tls_certificates(
- server_name TEXT, -- Server name.
- fingerprint TEXT, -- Certificate fingerprint.
- from_server TEXT, -- Which key server the certificate was fetched from.
- ts_added_ms BIGINT, -- When the certifcate was added.
- tls_certificate bytea, -- DER encoded x509 certificate.
- UNIQUE (server_name, fingerprint)
-);
+
+-- we used to create a table called server_tls_certificates, but this is no
+-- longer used, and is removed in delta 54.
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql
index 283136df20..01d2d8f833 100644
--- a/synapse/storage/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/schema/full_schemas/16/presence.sql
@@ -28,13 +28,5 @@ CREATE TABLE IF NOT EXISTS presence_allow_inbound(
UNIQUE (observed_user_id, observer_user_id)
);
--- For each of /my/ users (watcher), which possibly-remote users are they
--- watching?
-CREATE TABLE IF NOT EXISTS presence_list(
- user_id TEXT NOT NULL,
- observed_user_id TEXT NOT NULL, -- a UserID,
- accepted BOOLEAN NOT NULL,
- UNIQUE (user_id, observed_user_id)
-);
-
-CREATE INDEX presence_list_user_id ON presence_list (user_id);
+-- We used to create a table called presence_list, but this is no longer used
+-- and is removed in delta 54.
\ No newline at end of file
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index c6420b2374..226f8f1b7e 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -30,10 +30,10 @@ from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
-SearchEntry = namedtuple('SearchEntry', [
- 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
- 'origin_server_ts',
-])
+SearchEntry = namedtuple(
+ 'SearchEntry',
+ ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
+)
class SearchStore(BackgroundUpdateStore):
@@ -53,8 +53,7 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
self.register_background_update_handler(
- self.EVENT_SEARCH_ORDER_UPDATE_NAME,
- self._background_reindex_search_order
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
# we used to have a background update to turn the GIN index into a
@@ -62,13 +61,10 @@ class SearchStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
- self.register_noop_background_update(
- self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
- )
+ self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
self.register_background_update_handler(
- self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
- self._background_reindex_gin_search
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
@defer.inlineCallbacks
@@ -138,21 +134,23 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
- event_search_rows.append(SearchEntry(
- key=key,
- value=value,
- event_id=event_id,
- room_id=room_id,
- stream_ordering=stream_ordering,
- origin_server_ts=origin_server_ts,
- ))
+ event_search_rows.append(
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ origin_server_ts=origin_server_ts,
+ )
+ )
self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(event_search_rows)
+ "rows_inserted": rows_inserted + len(event_search_rows),
}
self._background_update_progress_txn(
@@ -191,6 +189,7 @@ class SearchStore(BackgroundUpdateStore):
# doesn't support CREATE INDEX IF EXISTS so we just catch the
# exception and ignore it.
import psycopg2
+
try:
c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx"
@@ -198,14 +197,11 @@ class SearchStore(BackgroundUpdateStore):
)
except psycopg2.ProgrammingError as e:
logger.warn(
- "Ignoring error %r when trying to switch from GIST to GIN",
- e
+ "Ignoring error %r when trying to switch from GIST to GIN", e
)
# we should now be able to delete the GIST index.
- c.execute(
- "DROP INDEX IF EXISTS event_search_fts_idx_gist"
- )
+ c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
finally:
conn.set_session(autocommit=False)
@@ -223,6 +219,7 @@ class SearchStore(BackgroundUpdateStore):
have_added_index = progress['have_added_indexes']
if not have_added_index:
+
def create_index(conn):
conn.rollback()
conn.set_session(autocommit=True)
@@ -248,7 +245,8 @@ class SearchStore(BackgroundUpdateStore):
yield self.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_update_progress_txn,
- self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME,
+ pg,
)
def reindex_search_txn(txn):
@@ -302,14 +300,16 @@ class SearchStore(BackgroundUpdateStore):
"""
self.store_search_entries_txn(
txn,
- (SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),),
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
)
def store_search_entries_txn(self, txn, entries):
@@ -329,10 +329,17 @@ class SearchStore(BackgroundUpdateStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
- args = ((
- entry.event_id, entry.room_id, entry.key, entry.value,
- entry.stream_ordering, entry.origin_server_ts,
- ) for entry in entries)
+ args = (
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ entry.value,
+ entry.stream_ordering,
+ entry.origin_server_ts,
+ )
+ for entry in entries
+ )
# inserts to a GIN index are normally batched up into a pending
# list, and then all committed together once the list gets to a
@@ -363,9 +370,10 @@ class SearchStore(BackgroundUpdateStore):
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
- args = ((
- entry.event_id, entry.room_id, entry.key, entry.value,
- ) for entry in entries)
+ args = (
+ (entry.event_id, entry.room_id, entry.key, entry.value)
+ for entry in entries
+ )
txn.executemany(sql, args)
else:
@@ -394,9 +402,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append(
- "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
- )
+ clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
local_clauses = []
@@ -404,9 +410,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?")
args.append(key)
- clauses.append(
- "(%s)" % (" OR ".join(local_clauses),)
- )
+ clauses.append("(%s)" % (" OR ".join(local_clauses),))
count_args = args
count_clauses = clauses
@@ -452,18 +456,13 @@ class SearchStore(BackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self._execute(
- "search_msgs", self.cursor_to_dict, sql, *args
- )
+ results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
- event_map = {
- ev.event_id: ev
- for ev in events
- }
+ event_map = {ev.event_id: ev for ev in events}
highlights = None
if isinstance(self.database_engine, PostgresEngine):
@@ -477,18 +476,17 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue({
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- })
+ defer.returnValue(
+ {
+ "results": [
+ {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
+ )
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -513,9 +511,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append(
- "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
- )
+ clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
local_clauses = []
@@ -523,9 +519,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?")
args.append(key)
- clauses.append(
- "(%s)" % (" OR ".join(local_clauses),)
- )
+ clauses.append("(%s)" % (" OR ".join(local_clauses),))
# take copies of the current args and clauses lists, before adding
# pagination clauses to main query.
@@ -607,18 +601,13 @@ class SearchStore(BackgroundUpdateStore):
args.append(limit)
- results = yield self._execute(
- "search_rooms", self.cursor_to_dict, sql, *args
- )
+ results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
- event_map = {
- ev.event_id: ev
- for ev in events
- }
+ event_map = {ev.event_id: ev for ev in events}
highlights = None
if isinstance(self.database_engine, PostgresEngine):
@@ -632,21 +621,22 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue({
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s" % (
- r["origin_server_ts"], r["stream_ordering"]
- ),
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- })
+ defer.returnValue(
+ {
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s"
+ % (r["origin_server_ts"], r["stream_ordering"]),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
+ )
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
@@ -662,6 +652,7 @@ class SearchStore(BackgroundUpdateStore):
Returns:
deferred : A set of strings.
"""
+
def f(txn):
highlight_words = set()
for event in events:
@@ -689,13 +680,15 @@ class SearchStore(BackgroundUpdateStore):
stop_sel += ">"
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
- _to_postgres_options({
- "StartSel": start_sel,
- "StopSel": stop_sel,
- "MaxFragments": "50",
- })
+ _to_postgres_options(
+ {
+ "StartSel": start_sel,
+ "StopSel": stop_sel,
+ "MaxFragments": "50",
+ }
+ )
)
- txn.execute(query, (value, search_query,))
+ txn.execute(query, (value, search_query))
headline, = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
@@ -714,9 +707,7 @@ class SearchStore(BackgroundUpdateStore):
def _to_postgres_options(options_dict):
- return "'%s'" % (
- ",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
- )
+ return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
def _parse_query(database_engine, search_term):
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 158e9dbe7b..6bd81e84ad 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -39,8 +39,9 @@ class SignatureWorkerStore(SQLBaseStore):
# to use its cache
raise NotImplementedError()
- @cachedList(cached_method_name="get_event_reference_hash",
- list_name="event_ids", num_args=1)
+ @cachedList(
+ cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
+ )
def get_event_reference_hashes(self, event_ids):
def f(txn):
return {
@@ -48,21 +49,13 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
- return self.runInteraction(
- "get_event_reference_hashes",
- f
- )
+ return self.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
- hashes = yield self.get_event_reference_hashes(
- event_ids
- )
+ hashes = yield self.get_event_reference_hashes(event_ids)
hashes = {
- e_id: {
- k: encode_base64(v) for k, v in h.items()
- if k == "sha256"
- }
+ e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
}
@@ -81,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
" FROM event_reference_hashes"
" WHERE event_id = ?"
)
- txn.execute(query, (event_id, ))
+ txn.execute(query, (event_id,))
return {k: v for k, v in txn}
@@ -98,14 +91,12 @@ class SignatureStore(SignatureWorkerStore):
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append({
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": db_binary_type(ref_hash_bytes),
- })
-
- self._simple_insert_many_txn(
- txn,
- table="event_reference_hashes",
- values=vals,
- )
+ vals.append(
+ {
+ "event_id": event.event_id,
+ "algorithm": ref_alg,
+ "hash": db_binary_type(ref_hash_bytes),
+ }
+ )
+
+ self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6ddc4055d2..0bfe1b4550 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
+
__slots__ = []
def __len__(self):
@@ -70,10 +73,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {
- k: v for k, v in iteritems(self.types)
- if v is not None
- }
+ self.types = {k: v for k, v in iteritems(self.types) if v is not None}
@staticmethod
def all():
@@ -130,10 +130,7 @@ class StateFilter(object):
Returns:
StateFilter
"""
- return StateFilter(
- types={EventTypes.Member: set(members)},
- include_others=True,
- )
+ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
"""Creates a new StateFilter where type wild cards have been removed
@@ -243,9 +240,7 @@ class StateFilter(object):
if where_clause:
where_clause += " OR "
- where_clause += "type NOT IN (%s)" % (
- ",".join(["?"] * len(self.types)),
- )
+ where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
where_args.extend(self.types)
return where_clause, where_args
@@ -305,12 +300,8 @@ class StateFilter(object):
bool
"""
- return (
- self.include_others
- or any(
- state_keys is None
- for state_keys in itervalues(self.types)
- )
+ return self.include_others or any(
+ state_keys is None for state_keys in itervalues(self.types)
)
def concrete_types(self):
@@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache")
+ 50000 * get_cache_factor_for("stateGroupCache"),
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache")
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
)
@defer.inlineCallbacks
@@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: dict of (type, state_key) -> event_id
"""
+
def _get_current_state_ids_txn(txn):
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
- (room_id,)
+ (room_id,),
)
return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
- return self.runInteraction(
- "get_current_state_ids",
- _get_current_state_ids_txn,
- )
+ return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
return self.runInteraction(
- "get_filtered_current_state_ids",
- _get_filtered_current_state_ids_txn,
+ "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@defer.inlineCallbacks
@@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[str|None]: The canonical alias, if any
"""
- state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
- [(EventTypes.CanonicalAlias, "")]
- ))
+ state = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
@@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
(prev_group, delta_ids), where both may be None.
"""
+
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
- keyvalues={
- "state_group": state_group,
- },
+ keyvalues={"state_group": state_group},
retcol="prev_state_group",
allow_none=True,
)
@@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
- keyvalues={
- "state_group": state_group,
- },
- retcols=("type", "state_key", "event_id",)
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
)
- return _GetStateGroupDelta(prev_group, {
- (row["type"], row["state_key"]): row["event_id"]
- for row in delta_ids
- })
- return self.runInteraction(
- "get_state_group_delta",
- _get_state_group_delta_txn,
- )
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_ids:
defer.returnValue({})
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
@@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in itervalues(group_to_ids)
+ ev_id
+ for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
- get_prev_content=False
+ get_prev_content=False,
)
- defer.returnValue({
- group: [
- state_event_map[v] for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- })
+ defer.returnValue(
+ {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
+ )
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = {}
- chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, state_filter,
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all(),
+ self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
@@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
- args
+ args,
)
results[group].update(
((typ, state_key), event_id)
@@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
- max_entries_returned is not None and
- len(results[group]) == max_entries_returned
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
):
break
@@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
- get_prev_content=False
+ get_prev_content=False,
)
event_to_state = {
@@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self._get_state_group_for_events(
- event_ids,
- )
+ event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
@@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
- keyvalues={
- "event_id": event_id,
- },
+ keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
- @cachedList(cached_method_name="_get_state_group_for_event",
- list_name="event_ids", num_args=1, inlineCallbacks=True)
+ @cachedList(
+ cached_method_name="_get_state_group_for_event",
+ list_name="event_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
column="event_id",
iterable=event_ids,
keyvalues={},
- retcols=("event_id", "state_group",),
+ retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
)
@@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Now we look them up in the member and non-member caches
non_member_state, incomplete_groups_nm, = (
yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache,
- state_filter=non_member_filter,
+ groups, self._state_group_cache, state_filter=non_member_filter
)
)
member_state, incomplete_groups_m, = (
yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache,
- state_filter=member_filter,
+ groups, self._state_group_members_cache, state_filter=member_filter
)
)
@@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups),
- state_filter=db_state_filter,
+ list(incomplete_groups), state_filter=db_state_filter
)
# Now lets update the caches
@@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(state)
- def _get_state_for_groups_using_cache(
- self, groups, cache, state_filter,
- ):
+ def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results, incomplete_groups
- def _insert_into_cache(self, group_to_state_dict, state_filter,
- cache_seq_num_members, cache_seq_num_non_members):
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
"""Inserts results from querying the database into the relevant cache.
Args:
@@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
fetched_keys=non_member_types,
)
- def store_state_group(self, event_id, room_id, prev_group, delta_ids,
- current_state_ids):
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
+
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
@@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._simple_insert_txn(
txn,
table="state_groups",
- values={
- "id": state_group,
- "room_id": room_id,
- "event_id": event_id,
- },
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
# We persist as a delta if we can, while also ensuring the chain
@@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
% (prev_group,)
)
- potential_hops = self._count_state_group_hops_txn(
- txn, prev_group
- )
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
+ values={"state_group": state_group, "prev_state_group": prev_group},
)
self._simple_insert_many_txn(
@@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
- sql = ("""
+ sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
@@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
- """)
+ """
txn.execute(sql, (state_group,))
row = txn.fetchone()
@@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self._background_deduplicate_state,
)
self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn,
table="event_to_state_groups",
values=[
- {
- "state_group": state_group_id,
- "event_id": event_id,
- }
+ {"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
- self._get_state_group_for_event.prefill,
- (event_id,), state_group_id
+ self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
@@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
if max_group is None:
rows = yield self._execute(
- "_background_deduplicate_state", None,
+ "_background_deduplicate_state",
+ None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
@@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC"
" LIMIT 1",
- (new_last_state_group, max_group,)
+ (new_last_state_group, max_group),
)
row = txn.fetchone()
if row:
@@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT state_group FROM state_group_edges"
" WHERE state_group = ?",
- (state_group,)
+ (state_group,),
)
# If we reach a point where we've already started inserting
@@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?",
- (state_group, room_id,)
+ (state_group, room_id),
)
prev_group, = txn.fetchone()
new_last_state_group = state_group
if prev_group:
- potential_hops = self._count_state_group_hops_txn(
- txn, prev_group
- )
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long,#
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group],
+ txn, [prev_group]
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group],
+ txn, [state_group]
)
curr_state = curr_state[state_group]
@@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys
delta_state = {
- key: value for key, value in iteritems(curr_state)
+ key: value
+ for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
self._simple_delete_txn(
txn,
table="state_group_edges",
- keyvalues={
- "state_group": state_group,
- }
+ keyvalues={"state_group": state_group},
)
self._simple_insert_txn(
@@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
values={
"state_group": state_group,
"prev_state_group": prev_group,
- }
+ },
)
self._simple_delete_txn(
txn,
table="state_groups_state",
- keyvalues={
- "state_group": state_group,
- }
+ keyvalues={"state_group": state_group},
)
self._simple_insert_many_txn(
@@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
)
if finished:
- yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
+ yield self._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
- txn.execute(
- "DROP INDEX IF EXISTS state_groups_state_id"
- )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
finally:
conn.set_session(autocommit=False)
else:
@@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
- txn.execute(
- "DROP INDEX IF EXISTS state_groups_state_id"
- )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.runWithConnection(reindex_txn)
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 57bc45cdb9..56e42f583d 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -21,10 +21,11 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
-
def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id)
- if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(
+ prev_stream_id
+ ):
return []
def get_current_state_deltas_txn(txn):
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn)
return self.runInteraction(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 580fafeb3a..9cd1e0f9fe 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
-_EventDictReturn = namedtuple("_EventDictReturn", (
- "event_id", "topological_ordering", "stream_ordering",
-))
+_EventDictReturn = namedtuple(
+ "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
+)
def lower_bound(token, engine, inclusive=False):
@@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
- token.topological, token.stream, inclusive,
- "topological_ordering", "stream_ordering",
+ token.topological,
+ token.stream,
+ inclusive,
+ "topological_ordering",
+ "stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
- token.topological, "topological_ordering",
- token.topological, "topological_ordering",
- token.stream, inclusive, "stream_ordering",
+ token.topological,
+ "topological_ordering",
+ token.topological,
+ "topological_ordering",
+ token.stream,
+ inclusive,
+ "stream_ordering",
)
@@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
- token.topological, token.stream, inclusive,
- "topological_ordering", "stream_ordering",
+ token.topological,
+ token.stream,
+ inclusive,
+ "topological_ordering",
+ "stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
- token.topological, "topological_ordering",
- token.topological, "topological_ordering",
- token.stream, inclusive, "stream_ordering",
+ token.topological,
+ "topological_ordering",
+ token.topological,
+ "topological_ordering",
+ token.stream,
+ inclusive,
+ "stream_ordering",
)
@@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
args = []
if event_filter.types:
- clauses.append(
- "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
- )
+ clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
args.extend(event_filter.types)
for typ in event_filter.not_types:
@@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
args.append(typ)
if event_filter.senders:
- clauses.append(
- "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
- )
+ clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
@@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
args.append(sender)
if event_filter.rooms:
- clauses.append(
- "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
- )
+ clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
@@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
+ db_conn,
+ "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
+ "EventsRoomStreamChangeCache",
+ min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
+ "MembershipStreamChangeCache", events_max
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
@@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotImplementedError()
@defer.inlineCallbacks
- def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
- order='DESC'):
+ def get_room_events_stream_for_rooms(
+ self, room_ids, from_key, to_key, limit=0, order='DESC'
+ ):
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(
- self.get_room_events_stream_for_room,
- room_id, from_key, to_key, limit, order=order,
+ for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
+ res = yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.get_room_events_stream_for_room,
+ room_id,
+ from_key,
+ to_key,
+ limit,
+ order=order,
+ )
+ for room_id in rm_ids
+ ],
+ consumeErrors=True,
)
- for room_id in rm_ids
- ], consumeErrors=True))
+ )
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
- room_id for room_id in room_ids
+ room_id
+ for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks
- def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
- order='DESC'):
+ def get_room_events_stream_for_room(
+ self, room_id, from_key, to_key, limit=0, order='DESC'
+ ):
"""Get new room events in stream ordering since `from_key`.
@@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self._get_events(
- [r.event_id for r in rows],
- get_prev_content=True
- )
+ ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
- txn.execute(sql, (user_id, from_id, to_id,))
+ txn.execute(sql, (user_id, from_id, to_id))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
- ret = yield self._get_events(
- [r.event_id for r in rows],
- get_prev_content=True
- )
+ ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=False)
@@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
rows, token = yield self.get_recent_event_ids_for_room(
- room_id, limit, end_token,
+ room_id, limit, end_token
)
logger.debug("stream before")
events = yield self._get_events(
- [r.event_id for r in rows],
- get_prev_content=True
+ [r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction(
- "get_recent_event_ids_for_room", self._paginate_room_events_txn,
- room_id, from_token=end_token, limit=limit,
+ "get_recent_event_ids_for_room",
+ self._paginate_room_events_txn,
+ room_id,
+ from_token=end_token,
+ limit=limit,
)
# We want to return the results in ascending order.
@@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
"""
+
def _f(txn):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
@@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering"
" LIMIT 1"
)
- txn.execute(sql, (room_id, stream_ordering, ))
+ txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.runInteraction(
- "get_room_event_after_stream_ordering", _f,
- )
+ return self.runInteraction("get_room_event_after_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
- "_get_max_topological_txn", self._get_max_topological_txn,
- room_id,
+ "_get_max_topological_txn", self._get_max_topological_txn, room_id
)
defer.returnValue("t%d-%d" % (topo, token))
@@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": event_id},
- retcol="stream_ordering",
+ table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
@@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(lambda row: "t%d-%d" % (
- row["topological_ordering"], row["stream_ordering"],)
+ ).addCallback(
+ lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
def get_max_topological_token(self, room_id, stream_key):
@@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
- "get_max_topological_token", None,
- sql, room_id, stream_key,
- ).addCallback(
- lambda r: r[0][0] if r else 0
- )
+ "get_max_topological_token", None, sql, room_id, stream_key
+ ).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
- "SELECT MAX(topological_ordering) FROM events"
- " WHERE room_id = ?",
- (room_id,)
+ "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+ (room_id,),
)
rows = txn.fetchall()
@@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
- internal.order = (
- int(topo) if topo else 0,
- int(stream),
- )
+ internal.order = (int(topo) if topo else 0, int(stream))
@defer.inlineCallbacks
def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None,
+ self, room_id, event_id, before_limit, after_limit, event_filter=None
):
"""Retrieve events and pagination tokens around a given event in a
room.
@@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = yield self.runInteraction(
- "get_events_around", self._get_events_around_txn,
- room_id, event_id, before_limit, after_limit, event_filter,
+ "get_events_around",
+ self._get_events_around_txn,
+ room_id,
+ event_id,
+ before_limit,
+ after_limit,
+ event_filter,
)
events_before = yield self._get_events(
- [e for e in results["before"]["event_ids"]],
- get_prev_content=True
+ [e for e in results["before"]["event_ids"]], get_prev_content=True
)
events_after = yield self._get_events(
- [e for e in results["after"]["event_ids"]],
- get_prev_content=True
+ [e for e in results["after"]["event_ids"]], get_prev_content=True
)
- defer.returnValue({
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- })
+ defer.returnValue(
+ {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": results["before"]["token"],
+ "end": results["after"]["token"],
+ }
+ )
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter,
+ self, txn, room_id, event_id, before_limit, after_limit, event_filter
):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
@@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = self._simple_select_one_txn(
txn,
"events",
- keyvalues={
- "event_id": event_id,
- "room_id": room_id,
- },
+ keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
)
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- results["topological_ordering"] - 1,
- results["stream_ordering"],
+ results["topological_ordering"] - 1, results["stream_ordering"]
)
after_token = RoomStreamToken(
- results["topological_ordering"],
- results["stream_ordering"],
+ results["topological_ordering"], results["stream_ordering"]
)
rows, start_token = self._paginate_room_events_txn(
- txn, room_id, before_token, direction='b', limit=before_limit,
+ txn,
+ room_id,
+ before_token,
+ direction='b',
+ limit=before_limit,
event_filter=event_filter,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
- txn, room_id, after_token, direction='f', limit=after_limit,
+ txn,
+ room_id,
+ after_token,
+ direction='f',
+ limit=after_limit,
event_filter=event_filter,
)
events_after = [r.event_id for r in rows]
return {
- "before": {
- "event_ids": events_before,
- "token": start_token,
- },
- "after": {
- "event_ids": events_after,
- "token": end_token,
- },
+ "before": {"event_ids": events_before, "token": start_token},
+ "after": {"event_ids": events_after, "token": end_token},
}
@defer.inlineCallbacks
@@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
- "get_all_new_events_stream", get_all_new_events_stream_txn,
+ "get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self._get_events(event_ids)
@@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
- desc="get_federation_out_pos"
+ desc="get_federation_out_pos",
)
def update_federation_out_pos(self, typ, stream_id):
@@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
- def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
- direction='b', limit=-1, event_filter=None):
+ def _paginate_room_events_txn(
+ self,
+ txn,
+ room_id,
+ from_token,
+ to_token=None,
+ direction='b',
+ limit=-1,
+ event_filter=None,
+ ):
"""Returns list of events before or after a given token.
Args:
@@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(
- from_token, self.database_engine
- )
+ bounds = upper_bound(from_token, self.database_engine)
if to_token:
- bounds = "%s AND %s" % (bounds, lower_bound(
- to_token, self.database_engine
- ))
+ bounds = "%s AND %s" % (
+ bounds,
+ lower_bound(to_token, self.database_engine),
+ )
else:
order = "ASC"
- bounds = lower_bound(
- from_token, self.database_engine
- )
+ bounds = lower_bound(from_token, self.database_engine)
if to_token:
- bounds = "%s AND %s" % (bounds, upper_bound(
- to_token, self.database_engine
- ))
+ bounds = "%s AND %s" % (
+ bounds,
+ upper_bound(to_token, self.database_engine),
+ )
filter_clause, filter_args = filter_to_clause(event_filter)
@@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
- ) % {
- "bounds": bounds,
- "order": order,
- }
+ ) % {"bounds": bounds, "order": order}
txn.execute(sql, args)
@@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, str(next_token),
+ return rows, str(next_token)
@defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
+ def paginate_room_events(
+ self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
+ ):
"""Returns list of events before or after a given token.
Args:
@@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction(
- "paginate_room_events", self._paginate_room_events_txn,
- room_id, from_key, to_key, direction, limit, event_filter,
+ "paginate_room_events",
+ self._paginate_room_events_txn,
+ room_id,
+ from_key,
+ to_key,
+ direction,
+ limit,
+ event_filter,
)
events = yield self._get_events(
- [r.event_id for r in rows],
- get_prev_content=True
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(events, rows)
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 0f657b2bd3..e88f8ea35f 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -84,9 +84,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
def get_tag_content(txn, tag_ids):
sql = (
- "SELECT tag, content"
- " FROM room_tags"
- " WHERE user_id=? AND room_id=?"
+ "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
)
results = []
for stream_id, user_id, room_id in tag_ids:
@@ -105,7 +103,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
- tag_ids[i:i + batch_size],
+ tag_ids[i : i + batch_size],
)
results.extend(tags)
@@ -123,6 +121,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
+
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
@@ -138,9 +137,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
defer.returnValue({})
- room_ids = yield self.runInteraction(
- "get_updated_tags", get_updated_tags_txn
- )
+ room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
results = {}
if room_ids:
@@ -163,9 +160,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
- ).addCallback(lambda rows: {
- row["tag"]: json.loads(row["content"]) for row in rows
- })
+ ).addCallback(
+ lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
+ )
class TagsStore(TagsWorkerStore):
@@ -186,14 +183,8 @@ class TagsStore(TagsWorkerStore):
self._simple_upsert_txn(
txn,
table="room_tags",
- keyvalues={
- "user_id": user_id,
- "room_id": room_id,
- "tag": tag,
- },
- values={
- "content": content_json,
- }
+ keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
+ values={"content": content_json},
)
self._update_revision_txn(txn, user_id, room_id, next_id)
@@ -211,6 +202,7 @@ class TagsStore(TagsWorkerStore):
Returns:
A deferred that completes once the tag has been removed
"""
+
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
@@ -238,8 +230,7 @@ class TagsStore(TagsWorkerStore):
"""
txn.call_after(
- self._account_data_stream_cache.entity_has_changed,
- user_id, next_id
+ self._account_data_stream_cache.entity_has_changed, user_id, next_id
)
update_max_id_sql = (
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index d8bf953ec0..b1188f6bcb 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -38,16 +38,12 @@ logger = logging.getLogger(__name__)
_TransactionRow = namedtuple(
- "_TransactionRow", (
- "id", "transaction_id", "destination", "ts", "response_code",
- "response_json",
- )
+ "_TransactionRow",
+ ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
)
_UpdateTransactionRow = namedtuple(
- "_TransactionRow", (
- "response_code", "response_json",
- )
+ "_TransactionRow", ("response_code", "response_json")
)
SENTINEL = object()
@@ -84,19 +80,22 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"get_received_txn_response",
- self._get_received_txn_response, transaction_id, origin
+ self._get_received_txn_response,
+ transaction_id,
+ origin,
)
def _get_received_txn_response(self, txn, transaction_id, origin):
result = self._simple_select_one_txn(
txn,
table="received_transactions",
- keyvalues={
- "transaction_id": transaction_id,
- "origin": origin,
- },
+ keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=(
- "transaction_id", "origin", "ts", "response_code", "response_json",
+ "transaction_id",
+ "origin",
+ "ts",
+ "response_code",
+ "response_json",
"has_been_referenced",
),
allow_none=True,
@@ -108,8 +107,7 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_received_txn_response(self, transaction_id, origin, code,
- response_dict):
+ def set_received_txn_response(self, transaction_id, origin, code, response_dict):
"""Persist the response we returened for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
@@ -135,8 +133,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
- def prep_send_transaction(self, transaction_id, destination,
- origin_server_ts):
+ def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@@ -182,7 +179,9 @@ class TransactionStore(SQLBaseStore):
result = yield self.runInteraction(
"get_destination_retry_timings",
- self._get_destination_retry_timings, destination)
+ self._get_destination_retry_timings,
+ destination,
+ )
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
@@ -193,9 +192,7 @@ class TransactionStore(SQLBaseStore):
result = self._simple_select_one_txn(
txn,
table="destinations",
- keyvalues={
- "destination": destination,
- },
+ keyvalues={"destination": destination},
retcols=("destination", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -205,8 +202,7 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(self, destination,
- retry_last_ts, retry_interval):
+ def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
@@ -225,8 +221,9 @@ class TransactionStore(SQLBaseStore):
retry_interval,
)
- def _set_destination_retry_timings(self, txn, destination,
- retry_last_ts, retry_interval):
+ def _set_destination_retry_timings(
+ self, txn, destination, retry_last_ts, retry_interval
+ ):
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@@ -235,9 +232,7 @@ class TransactionStore(SQLBaseStore):
prev_row = self._simple_select_one_txn(
txn,
table="destinations",
- keyvalues={
- "destination": destination,
- },
+ keyvalues={"destination": destination},
retcols=("retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -250,15 +245,13 @@ class TransactionStore(SQLBaseStore):
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
- }
+ },
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
self._simple_update_one_txn(
txn,
"destinations",
- keyvalues={
- "destination": destination,
- },
+ keyvalues={"destination": destination},
updatevalues={
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
@@ -273,8 +266,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
- "get_destinations_needing_retry",
- self._get_destinations_needing_retry
+ "get_destinations_needing_retry", self._get_destinations_needing_retry
)
def _get_destinations_needing_retry(self, txn):
@@ -288,7 +280,7 @@ class TransactionStore(SQLBaseStore):
def _start_cleanup_transactions(self):
return run_as_background_process(
- "cleanup_transactions", self._cleanup_transactions,
+ "cleanup_transactions", self._cleanup_transactions
)
def _cleanup_transactions(self):
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 4d60a5726f..83466e25d9 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -194,7 +194,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
room_id
)
- users_with_profile = yield state.get_current_user_in_room(room_id)
+ users_with_profile = yield state.get_current_users_in_room(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
index be013f4427..1815fdc0dd 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/user_erasure_store.py
@@ -40,9 +40,7 @@ class UserErasureWorkerStore(SQLBaseStore):
).addCallback(operator.truth)
@cachedList(
- cached_method_name="is_user_erased",
- list_name="user_ids",
- inlineCallbacks=True,
+ cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
"""
@@ -61,16 +59,13 @@ class UserErasureWorkerStore(SQLBaseStore):
def _get_erased_users(txn):
txn.execute(
- "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
- ",".join("?" * len(user_ids))
- ),
+ "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
+ % (",".join("?" * len(user_ids))),
user_ids,
)
return set(r[0] for r in txn)
- erased_users = yield self.runInteraction(
- "are_users_erased", _get_erased_users,
- )
+ erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids)
defer.returnValue(res)
@@ -82,22 +77,16 @@ class UserErasureStore(UserErasureWorkerStore):
Args:
user_id (str): full user_id to be erased
"""
+
def f(txn):
# first check if they are already in the list
- txn.execute(
- "SELECT 1 FROM erased_users WHERE user_id = ?",
- (user_id, )
- )
+ txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if txn.fetchone():
return
# they are not already there: do the insert.
- txn.execute(
- "INSERT INTO erased_users (user_id) VALUES (?)",
- (user_id, )
- )
+ txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,))
+
+ self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- self._invalidate_cache_and_stream(
- txn, self.is_user_erased, (user_id,)
- )
return self.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d6160d5e4d..f1c8d99419 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -43,9 +43,9 @@ def _load_current_id(db_conn, table, column, step=1):
"""
cur = db_conn.cursor()
if step == 1:
- cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
- cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
+ cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
val, = cur.fetchone()
cur.close()
current_id = int(val) if val else step
@@ -77,6 +77,7 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
+
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock()
@@ -84,8 +85,7 @@ class StreamIdGenerator(object):
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
- self._current,
- _load_current_id(db_conn, table, column, step)
+ self._current, _load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@@ -121,7 +121,7 @@ class StreamIdGenerator(object):
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
- self._step
+ self._step,
)
self._current += n * self._step
|