diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6492fa011f..77be1b2952 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -1,3 +1,16 @@
+Upgrading to v1.20.0
+====================
+
+Shared rooms endpoint (MSC2666)
+-------------------------------
+
+This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*`
+for fetching rooms one user has in common with another. This feature requires the
+`update_user_directory` config flag to be `True`. If you are you are using a `synapse.app.user_dir`
+worker, requests to this endpoint must be handled by that worker.
+See `docs/workers.md <docs/workers.md>`_ for more details.
+
+
Upgrading Synapse
=================
diff --git a/changelog.d/7785.feature b/changelog.d/7785.feature
new file mode 100644
index 0000000000..c7e51c9320
--- /dev/null
+++ b/changelog.d/7785.feature
@@ -0,0 +1 @@
+Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666).
diff --git a/changelog.d/8059.feature b/changelog.d/8059.feature
new file mode 100644
index 0000000000..feb02be234
--- /dev/null
+++ b/changelog.d/8059.feature
@@ -0,0 +1 @@
+Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
diff --git a/changelog.d/8156.misc b/changelog.d/8156.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8156.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8170.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8189.doc b/changelog.d/8189.doc
new file mode 100644
index 0000000000..800ff89dc5
--- /dev/null
+++ b/changelog.d/8189.doc
@@ -0,0 +1 @@
+Explain better what GDPR-erased means when deactivating a user.
diff --git a/changelog.d/8196.misc b/changelog.d/8196.misc
new file mode 100644
index 0000000000..c42baf0e56
--- /dev/null
+++ b/changelog.d/8196.misc
@@ -0,0 +1 @@
+Fix `wait_for_stream_position` to allow multiple waiters on same stream ID.
diff --git a/changelog.d/8197.misc b/changelog.d/8197.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8197.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8199.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8200.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8201.misc b/changelog.d/8201.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8201.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8202.misc b/changelog.d/8202.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8202.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8203.misc b/changelog.d/8203.misc
new file mode 100644
index 0000000000..9fe2224aaa
--- /dev/null
+++ b/changelog.d/8203.misc
@@ -0,0 +1 @@
+Make `MultiWriterIDGenerator` work for streams that use negative values.
diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8204.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc
new file mode 100644
index 0000000000..fb8fd83278
--- /dev/null
+++ b/changelog.d/8205.misc
@@ -0,0 +1 @@
+ Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8207.misc b/changelog.d/8207.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8207.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8213.misc b/changelog.d/8213.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8213.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8214.misc b/changelog.d/8214.misc
new file mode 100644
index 0000000000..e26764dea1
--- /dev/null
+++ b/changelog.d/8214.misc
@@ -0,0 +1 @@
+ Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8222.misc b/changelog.d/8222.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8222.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8223.bugfix b/changelog.d/8223.bugfix
new file mode 100644
index 0000000000..60655ce3e1
--- /dev/null
+++ b/changelog.d/8223.bugfix
@@ -0,0 +1 @@
+Fixes a longstanding bug where user directory updates could break when unexpected profile data was included in events.
diff --git a/changelog.d/8224.misc b/changelog.d/8224.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8224.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8225.misc b/changelog.d/8225.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8225.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8226.bugfix b/changelog.d/8226.bugfix
new file mode 100644
index 0000000000..60bdff576d
--- /dev/null
+++ b/changelog.d/8226.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where stats updates could break when unexpected profile data was included in events.
diff --git a/changelog.d/8231.misc b/changelog.d/8231.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8231.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8232.misc b/changelog.d/8232.misc
new file mode 100644
index 0000000000..3a7a352c4f
--- /dev/null
+++ b/changelog.d/8232.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
diff --git a/changelog.d/8235.misc b/changelog.d/8235.misc
new file mode 100644
index 0000000000..3a7a352c4f
--- /dev/null
+++ b/changelog.d/8235.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
diff --git a/changelog.d/8237.misc b/changelog.d/8237.misc
new file mode 100644
index 0000000000..29d946cde6
--- /dev/null
+++ b/changelog.d/8237.misc
@@ -0,0 +1 @@
+Fix type hints in `SyncHandler`.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index d6e3194cda..e21c78a9c6 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -214,9 +214,11 @@ Deactivate Account
This API deactivates an account. It removes active access tokens, resets the
password, and deletes third-party IDs (to prevent the user requesting a
-password reset). It can also mark the user as GDPR-erased (stopping their data
-from distributed further, and deleting it entirely if there are no other
-references to it).
+password reset).
+
+It can also mark the user as GDPR-erased. This means messages sent by the
+user will still be visible by anyone that was in the room when these messages
+were sent, but hidden from users joining the room afterwards.
The api is::
diff --git a/docs/workers.md b/docs/workers.md
index bfec745897..7a8f5c89fc 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -380,6 +380,7 @@ Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
+ ^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
When using this worker you must also set `update_user_directory: False` in the
shared configuration file to stop the main synapse running background
diff --git a/mypy.ini b/mypy.ini
index 4213e31b03..ae3290d5bb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,6 +28,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
+ synapse/http/federation/well_known_resolver.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging/,
@@ -42,6 +43,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
+ synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2b2cd795e0..a43dc5b2c9 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
can run out of file descriptors and infinite loop if we attempt to do too
many DNS queries at once
+
+ XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
+ you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
+ backed by the reactor's default threadpool (which is limited to 10 threads). So
+ (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
+ understand why we would run out of FDs if we did too many lookups at once.
+ -- richvdh 2020/08/29
"""
new_resolver = _LimitedHostnameResolver(
reactor.nameResolver, max_dns_requests_in_flight
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..b6c9085670 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
pass
-@defer.inlineCallbacks
-def export_data_command(hs, args):
+async def export_data_command(hs, args):
"""Export data for a user.
Args:
@@ -91,10 +90,8 @@ def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = yield defer.ensureDeferred(
- hs.get_handlers().admin_handler.export_user_data(
- user_id, FileExfiltrationWriter(user_id, directory=directory)
- )
+ res = await hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@@ -232,14 +229,15 @@ def start(config_options):
# We also make sure that `_base.start` gets run before we actually run the
# command.
- @defer.inlineCallbacks
- def run(_reactor):
+ async def run():
with LoggingContext("command"):
- yield _base.start(ss, [])
- yield args.func(ss, args)
+ _base.start(ss, [])
+ await args.func(ss, args)
_base.start_worker_reactor(
- "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+ "synapse-admin-cmd",
+ config,
+ run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 98d0d14a12..6014adc850 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -411,26 +411,24 @@ def setup(config_options):
return provision
- @defer.inlineCallbacks
- def reprovision_acme():
+ async def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
- reprovisioned = yield defer.ensureDeferred(do_acme())
+ reprovisioned = await do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
- @defer.inlineCallbacks
- def start():
+ async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
- yield defer.ensureDeferred(acme.start_listening())
- yield defer.ensureDeferred(do_acme())
+ await acme.start_listening()
+ await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -439,8 +437,8 @@ def setup(config_options):
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
- yield defer.ensureDeferred(oidc.load_metadata())
- yield defer.ensureDeferred(oidc.load_jwks())
+ await oidc.load_metadata()
+ await oidc.load_jwks()
_base.start(hs, config.listeners)
@@ -456,7 +454,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
return hs
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e72a0b9ac0..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,18 +14,20 @@
# limitations under the License.
import logging
import urllib
+from typing import TYPE_CHECKING, Optional
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.appservice import ApplicationService
+
logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
@@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex)
return []
- def get_3pe_protocol(self, service, protocol):
+ async def get_3pe_protocol(
+ self, service: "ApplicationService", protocol: str
+ ) -> Optional[JsonDict]:
if service.url is None:
return {}
- @defer.inlineCallbacks
- def _get():
+ async def _get() -> Optional[JsonDict]:
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
- info = yield defer.ensureDeferred(self.get_json(uri, {}))
+ info = await self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning(
@@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
return None
key = (service.id, protocol)
- return self.protocol_meta_cache.wrap(key, _get)
+ return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, txn_id=None):
if service.url is None:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1417487427..73f0717b0d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
-
- # If multiple instances are not defined we always return true.
+ # If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True
+ return self.get_instance(key) == instance_name
+
+ def get_instance(self, key: str) -> str:
+ """Get the instance responsible for handling the given key.
+
+ Note: For things like federation sending the config for which instance
+ is sending is known only to the sender instance if there is only one.
+ Therefore `should_handle` should be used where possible.
+ """
+
+ if not self.instances:
+ return "master"
+
+ if len(self.instances) == 1:
+ return self.instances[0]
+
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
@@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
- return self.instances[remainder] == instance_name
+ return self.instances[remainder]
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
+ def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Union
+
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
+def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
+ """Helper for allowing parsing a string or list of strings to a config
+ option expecting a list of strings.
+ """
+
+ if isinstance(obj, str):
+ return [obj]
+ return obj
+
+
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
- events: The instance that writes to the event and backfill streams.
- events: The instance that writes to the typing stream.
+ events: The instances that write to the event and backfill streams.
+ typing: The instance that writes to the typing stream.
"""
- events = attr.ib(default="master", type=str)
+ events = attr.ib(
+ default=["master"], type=List[str], converter=_instance_to_list_converter
+ )
typing = attr.ib(default="master", type=str)
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
- # Check that the configured writer for events and typing also appears in
+ # Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
- instance = getattr(self.writers, stream)
- if instance != "master" and instance not in self.instance_map:
- raise ConfigError(
- "Instance %r is configured to write %s but does not appear in `instance_map` config."
- % (instance, stream)
- )
+ instances = _instance_to_list_converter(getattr(self.writers, stream))
+ for instance in instances:
+ if instance != "master" and instance not in self.instance_map:
+ raise ConfigError(
+ "Instance %r is configured to write %s but does not appear in `instance_map` config."
+ % (instance, stream)
+ )
+
+ self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 67db763dbf..62ea44fa49 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@
import abc
import os
from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
# be here
before = DictProperty("before") # type: str
after = DictProperty("after") # type: str
- order = DictProperty("order") # type: int
+ order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict:
return dict(self._dict)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 9ed24380dd..7878cd7044 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
from nacl.signing import SigningKey
@@ -97,14 +97,14 @@ class EventBuilder(object):
def is_state(self):
return self._state_key is not None
- async def build(self, prev_event_ids):
+ async def build(self, prev_event_ids: List[str]) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
- prev_event_ids (list[str]): The event IDs to use as the prev events
+ prev_event_ids: The event IDs to use as the prev events
Returns:
- FrozenEvent
+ The signed and hashed event.
"""
state_ids = await self._state.get_current_state_ids(
@@ -114,8 +114,13 @@ class EventBuilder(object):
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
- auth_events = await self._store.add_event_hashes(auth_ids)
- prev_events = await self._store.add_event_hashes(prev_event_ids)
+ # The types of auth/prev events changes between event versions.
+ auth_events = await self._store.add_event_hashes(
+ auth_ids
+ ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+ prev_events = await self._store.add_event_hashes(
+ prev_event_ids
+ ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
auth_events = auth_ids
prev_events = prev_event_ids
@@ -138,7 +143,7 @@ class EventBuilder(object):
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
- }
+ } # type: Dict[str, Any]
if self.is_state():
event_dict["state_key"] = self._state_key
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60de..ee4666337a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
return result
async def on_federation_query_user_devices(self, user_id):
- stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+ stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
+ user_id
+ )
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d8def45e38..dfd1c78549 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
- devices = await self.store.get_e2e_device_keys([(user_id, None)])
+ devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 16389a0dca..bd8efbb768 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -923,7 +923,8 @@ class FederationHandler(BaseHandler):
)
)
- await self._handle_new_events(dest, ev_infos, backfilled=True)
+ if ev_infos:
+ await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
- destination, event_infos,
+ destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
- origin, auth_chain, state, event, room_version_obj
+ origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
- #
- # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
- self.config.worker.writers.events, "events", max_stream_id
+ self.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify([(event, context)])
+ await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify([(event, context)])
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event, stream_id
@@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
- [(event, context)], backfilled=backfilled
+ event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
+ room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
+ room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
+ room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
+ room_id,
auth_events
state
event
@@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
+ room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ]
+ ],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
- return await self.persist_events_and_notify([(event, new_event_context)])
+ return await self.persist_events_and_notify(
+ room_id, [(event, new_event_context)]
+ )
async def _prep_event(
self,
@@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
+ room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler):
necessary.
Args:
- event_and_contexts:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
- if self.config.worker.writers.events != self._instance_name:
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
result = await self._send_events(
- instance_name=self.config.worker.writers.events,
+ instance_name=instance,
store=self.store,
+ room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a48c69163..0016af44be 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import (
- Collection,
- Requester,
- RoomAlias,
- StreamToken,
- UserID,
- create_requester,
-)
+from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
@@ -383,9 +376,8 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
- self._is_event_writer = (
- self.config.worker.writers.events == hs.get_instance_name()
- )
+ self._events_shard_config = self.config.worker.events_shard_config
+ self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
@@ -448,7 +440,7 @@ class EventCreationHandler(object):
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@@ -788,7 +780,7 @@ class EventCreationHandler(object):
self,
builder: EventBuilder,
requester: Optional[Requester] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -913,9 +905,10 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
- if not self._is_event_writer:
+ writer_instance = self._events_shard_config.get_instance(event.room_id)
+ if writer_instance != self._instance_name:
result = await self.send_event(
- instance_name=self.config.worker.writers.events,
+ instance_name=writer_instance,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -983,7 +976,9 @@ class EventCreationHandler(object):
This should only be run on the instance in charge of persisting events.
"""
- assert self._is_event_writer
+ assert self._events_shard_config.should_handle(
+ self._instance_name, event.room_id
+ )
if ratelimit:
# We check if this is a room admin redacting an event so that we
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ac3418d69d..5a1aa7d830 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Dict, Optional
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -247,15 +250,16 @@ class PaginationHandler(object):
)
return purge_id
- async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+ async def _purge_history(
+ self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
"""Carry out a history purge on a room.
Args:
- purge_id (str): The id for this purge
- room_id (str): The room to purge from
- token (str): topological token to delete events before
- delete_local_events (bool): True to delete local events as well as
- remote ones
+ purge_id: The id for this purge
+ room_id: The room to purge from
+ token: topological token to delete events before
+ delete_local_events: True to delete local events as well as remote ones
"""
self._purges_in_progress_by_room.add(room_id)
try:
@@ -291,9 +295,9 @@ class PaginationHandler(object):
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
- with (await self.pagination_lock.write(room_id)):
+ with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
@@ -307,23 +311,22 @@ class PaginationHandler(object):
async def get_messages(
self,
- requester,
- room_id=None,
- pagin_config=None,
- as_client_event=True,
- event_filter=None,
- ):
+ requester: Requester,
+ room_id: Optional[str] = None,
+ pagin_config: Optional[PaginationConfig] = None,
+ as_client_event: bool = True,
+ event_filter: Optional[Filter] = None,
+ ) -> Dict[str, Any]:
"""Get messages in a room.
Args:
- requester (Requester): The user requesting messages.
- room_id (str): The room they want messages from.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config rules to apply, if any.
- as_client_event (bool): True to get events in client-server format.
- event_filter (Filter): Filter to apply to results or None
+ requester: The user requesting messages.
+ room_id: The room they want messages from.
+ pagin_config: The pagination config rules to apply, if any.
+ as_client_event: True to get events in client-server format.
+ event_filter: Filter to apply to results or None
Returns:
- dict: Pagination API results
+ Pagination API results
"""
user_id = requester.user.to_string()
@@ -343,7 +346,7 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (await self.pagination_lock.read(room_id)):
+ with await self.pagination_lock.read(room_id):
(
membership,
member_event_id,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 96c9d6bab4..0cb8fad89a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -161,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
Codes.FORBIDDEN,
)
+ if not isinstance(new_displayname, str):
+ raise SynapseError(400, "Invalid displayname")
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -235,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
+ if not isinstance(new_avatar_url, str):
+ raise SynapseError(400, "Invalid displayname")
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9d5b1828df..55794c3057 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to progate before returning
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", last_stream_id
+ self.hs.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ last_stream_id,
)
return result, last_stream_id
@@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object):
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
- #
- # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
+ self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+ "events",
+ stream_id,
)
else:
new_room_id = None
@@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object):
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
+ self.hs.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1017ae6b19..ed1d1bd83d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
-from synapse.types import (
- Collection,
- JsonDict,
- Requester,
- RoomAlias,
- RoomID,
- StateMap,
- UserID,
-)
+from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -91,13 +83,6 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
- self._event_stream_writer_instance = hs.config.worker.writers.events
- self._is_on_event_persistence_instance = (
- self._event_stream_writer_instance == hs.get_instance_name()
- )
- if self._is_on_event_persistence_instance:
- self.persist_event_storage = hs.get_storage().persistence
-
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -185,7 +170,7 @@ class RoomMemberHandler(object):
target: UserID,
room_id: str,
membership: str,
- prev_event_ids: Collection[str],
+ prev_event_ids: List[str],
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 8118206f8e..c281ff163a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,7 +16,7 @@
import itertools
import logging
-from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -96,7 +99,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
@@ -105,6 +113,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
+ unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -239,7 +248,7 @@ class SyncResult:
class SyncHandler(object):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -714,9 +723,8 @@ class SyncHandler(object):
]
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
- missing_hero_state = missing_hero_state.values()
- for s in missing_hero_state:
+ for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
@@ -934,7 +942,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> Optional[Dict[str, str]]:
+ ) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -942,15 +950,10 @@ class SyncHandler(object):
receipt_type="m.read",
)
- if last_unread_event_id:
- notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
- room_id, sync_config.user.to_string(), last_unread_event_id
- )
- return notifs
-
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- return None
+ notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ return notifs
async def generate_sync_result(
self,
@@ -1773,7 +1776,7 @@ class SyncHandler(object):
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
- tags: Optional[List[JsonDict]],
+ tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
@@ -1889,7 +1892,7 @@ class SyncHandler(object):
)
if room_builder.rtype == "joined":
- unread_notifications = {} # type: Dict[str, str]
+ unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1898,14 +1901,16 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
+ unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- if notifs is not None:
- unread_notifications["notification_count"] = notifs["notify_count"]
- unread_notifications["highlight_count"] = notifs["highlight_count"]
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+ room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 521b6d620d..e21f8dbc58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def _handle_room_publicity_change(
self, room_id, prev_event_id, event_id, typ
):
- """Handle a room having potentially changed from/to world_readable/publically
+ """Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
@@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
+ # If the new name is an unexpected form, do not update the directory.
+ if not isinstance(new_name, str):
+ new_name = prev_name
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
+ # If the new avatar is an unexpected form, do not update the directory.
+ if not isinstance(new_avatar, str):
+ new_avatar = prev_avatar
if prev_name != new_name or prev_avatar != new_avatar:
await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 369bf9c2fc..782d39d4ca 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
and not _is_ip_literal(parsed_uri.hostname)
and not parsed_uri.port
):
- well_known_result = yield self._well_known_resolver.get_well_known(
- parsed_uri.hostname
+ well_known_result = yield defer.ensureDeferred(
+ self._well_known_resolver.get_well_known(parsed_uri.hostname)
)
delegated_server = well_known_result.delegated_server
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index e701dcc961..37c29c008a 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -16,6 +16,7 @@
import logging
import random
import time
+from typing import Callable, Dict, Optional, Tuple
import attr
@@ -23,6 +24,7 @@ from twisted.internet import defer
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -99,15 +101,14 @@ class WellKnownResolver(object):
self._well_known_agent = RedirectAgent(agent)
self.user_agent = user_agent
- @defer.inlineCallbacks
- def get_well_known(self, server_name):
+ async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
"""Attempt to fetch and parse a .well-known file for the given server
Args:
- server_name (bytes): name of the server, from the requested url
+ server_name: name of the server, from the requested url
Returns:
- Deferred[WellKnownLookupResult]: The result of the lookup
+ The result of the lookup
"""
if server_name == b"kde.org":
@@ -128,7 +129,9 @@ class WellKnownResolver(object):
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
- result, cache_period = yield self._fetch_well_known(server_name)
+ result, cache_period = await self._fetch_well_known(
+ server_name
+ ) # type: Tuple[Optional[bytes], float]
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@@ -157,18 +160,17 @@ class WellKnownResolver(object):
return WellKnownLookupResult(delegated_server=result)
- @defer.inlineCallbacks
- def _fetch_well_known(self, server_name):
+ async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
"""Actually fetch and parse a .well-known, without checking the cache
Args:
- server_name (bytes): name of the server, from the requested url
+ server_name: name of the server, from the requested url
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns:
- Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+ The lookup result and cache period.
"""
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -176,7 +178,7 @@ class WellKnownResolver(object):
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permenant
# errors (such as getting a 404 response).
- response, body = yield self._make_well_known_request(
+ response, body = await self._make_well_known_request(
server_name, retry=had_valid_well_known
)
@@ -219,20 +221,20 @@ class WellKnownResolver(object):
return result, cache_period
- @defer.inlineCallbacks
- def _make_well_known_request(self, server_name, retry):
+ async def _make_well_known_request(
+ self, server_name: bytes, retry: bool
+ ) -> Tuple[IResponse, bytes]:
"""Make the well known request.
This will retry the request if requested and it fails (with unable
to connect or receives a 5xx error).
Args:
- server_name (bytes)
- retry (bool): Whether to retry the request if it fails.
+ server_name: name of the server, from the requested url
+ retry: Whether to retry the request if it fails.
Returns:
- Deferred[tuple[IResponse, bytes]] Returns the response object and
- body. Response may be a non-200 response.
+ Returns the response object and body. Response may be a non-200 response.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
@@ -247,12 +249,12 @@ class WellKnownResolver(object):
logger.info("Fetching %s", uri_str)
try:
- response = yield make_deferred_yieldable(
+ response = await make_deferred_yieldable(
self._well_known_agent.request(
b"GET", uri, headers=Headers(headers)
)
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -269,21 +271,24 @@ class WellKnownResolver(object):
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
# Sleep briefly in the hopes that they come back up
- yield self._clock.sleep(0.5)
+ await self._clock.sleep(0.5)
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+ headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
cache_controls = _parse_cache_control(headers)
if b"no-store" in cache_controls:
return 0
if b"max-age" in cache_controls:
- try:
- max_age = int(cache_controls[b"max-age"])
- return max_age
- except ValueError:
- pass
+ max_age = cache_controls[b"max-age"]
+ if max_age:
+ try:
+ return int(max_age)
+ except ValueError:
+ pass
expires = headers.getRawHeaders(b"expires")
if expires is not None:
@@ -299,7 +304,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
return None
-def _parse_cache_control(headers):
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []):
for directive in hdr.split(b","):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..e7fa02b78b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
)
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+ EventTypes.Topic,
+ EventTypes.Name,
+ EventTypes.RoomAvatar,
+ EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+ # Exclude rejected and soft-failed events.
+ if context.rejected or event.internal_metadata.is_soft_failed():
+ return False
+
+ # Exclude notices.
+ if (
+ not event.is_state()
+ and event.type == EventTypes.Message
+ and event.content.get("msgtype") == "m.notice"
+ ):
+ return False
+
+ # Exclude edits.
+ relates_to = event.content.get("m.relates_to", {})
+ if relates_to.get("rel_type") == RelationTypes.REPLACE:
+ return False
+
+ # Mark events that have a non-empty string body as unread.
+ body = event.content.get("body")
+ if isinstance(body, str) and body:
+ return True
+
+ # Mark some state events as unread.
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+ return True
+
+ # Mark encrypted events as unread.
+ if not event.is_state() and event.type == EventTypes.Encrypted:
+ return True
+
+ return False
+
+
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None:
- """Given an event and context, evaluate the push rules and insert the
- results into the event_push_actions_staging table.
+ """Given an event and context, evaluate the push rules, check if the message
+ should increment the unread count, and insert the results into the
+ event_push_actions_staging table.
"""
+ count_as_unread = _should_count_as_unread(event, context)
+
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
+ actions_by_user[uid] = []
+
for rule in rules:
if "enabled" in rule and not rule["enabled"]:
continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
- await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+ await self.store.add_push_actions_to_staging(
+ event.event_id, actions_by_user, count_as_unread,
+ )
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (list): List of event ids for membership events that
- have happened since the last time we filled rules_by_user
+ member_event_ids (dict): Dict of user id to event id for membership events
+ that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
"""
@@ -390,34 +441,19 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- interested_in_user_ids = {
+ user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", interested_in_user_ids)
-
- if_users_with_pushers = await self.store.get_if_users_have_pushers(
- interested_in_user_ids, on_invalidate=self.invalidate_all_cb
- )
-
- user_ids = {
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- }
-
- logger.debug("With pushers: %r", user_ids)
-
- users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
- self.room_id, on_invalidate=self.invalidate_all_cb
- )
-
- logger.debug("With receipts: %r", users_with_receipts)
+ logger.debug("Joined: %r", user_ids)
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in interested_in_user_ids:
- user_ids.add(uid)
+ # Previously we only considered users with pushers or read receipts in that
+ # room. We can't do this anymore because we use push actions to calculate unread
+ # counts, which don't rely on the user having pushers or sent a read receipt into
+ # the room. Therefore we just need to filter for local users here.
+ user_ids = list(filter(self.is_mine_id, user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..f7a25571f3 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ badge += 1 if notifs["unread_count"] else 0
return badge
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index dd77a44b8d..2d995ec456 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,7 +66,9 @@ REQUIREMENTS = [
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"prometheus_client>=0.0.18,<0.9.0",
- # we use attr.validators.deep_iterable, which arrived in 19.1.0
+ # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
+ # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
+ # is out in November.)
"attrs>=19.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5c8be747e1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
- async def _serialize_payload(store, event_and_contexts, backfilled):
+ async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
"""
Args:
store
+ room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
- payload = {"events": event_payloads, "backfilled": backfilled}
+ payload = {
+ "events": event_payloads,
+ "backfilled": backfilled,
+ "room_id": room_id,
+ }
return payload
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
+ room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
- event_and_contexts, backfilled
+ room_id, event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 596c72eb92..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index fcf8ebf1e7..d6ecf5b327 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-import heapq
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
@@ -219,9 +218,8 @@ class ReplicationDataHandler:
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
- # We insert into the list using heapq as it is more efficient than
- # pushing then resorting each time.
- heapq.heappush(waiting_list, (position, deferred))
+ waiting_list.append((position, deferred))
+ waiting_list.sort(key=lambda t: t[0])
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
self._streams_to_replicate.append(stream)
continue
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..3705618b4f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self._store.get_current_events_token),
+ self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..87f927890c 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
+ shared_rooms,
sync,
tags,
thirdparty,
@@ -125,3 +126,6 @@ class ClientRestResource(JsonResource):
synapse.rest.admin.register_servlets_for_client_rest_resource(
hs, client_resource
)
+
+ # unstable
+ shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+ """
+ GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ """
+
+ PATTERNS = client_patterns(
+ "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ releases=(), # This is an unstable feature
+ )
+
+ def __init__(self, hs):
+ super(UserSharedRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.user_directory_active = hs.config.update_user_directory
+
+ async def on_GET(self, request, user_id):
+
+ if not self.user_directory_active:
+ raise SynapseError(
+ code=400,
+ msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ UserID.from_string(user_id)
+
+ requester = await self.auth.get_user_by_req(request)
+ if user_id == requester.user.to_string():
+ raise SynapseError(
+ code=400,
+ msg="You cannot request a list of shared rooms with yourself",
+ errcode=Codes.FORBIDDEN,
+ )
+ rooms = await self.store.get_shared_rooms_for_users(
+ requester.user.to_string(), user_id
+ )
+
+ return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+ UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 96488b131a..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..24ac57f35d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Implements additional endpoints as described in MSC2666
+ "uk.half-shot.msc2666": True,
},
},
)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 0db900fa0e..67a89cd51a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
"background_updates", keyvalues={"update_name": update_name}
)
- def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
@@ -441,7 +441,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7ab370efef..78ca6d8346 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ cast,
overload,
)
@@ -35,7 +36,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -507,8 +507,9 @@ class DatabasePool(object):
self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
- @defer.inlineCallbacks
- def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ async def runInteraction(
+ self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Starts a transaction on the database and runs a given function
Arguments:
@@ -521,7 +522,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
@@ -530,16 +531,14 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield defer.ensureDeferred(
- self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
- )
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -549,7 +548,7 @@ class DatabasePool(object):
after_callback(*after_args, **after_kwargs)
raise
- return result
+ return cast(R, result)
async def runWithConnection(
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
@@ -604,6 +603,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
+ @overload
+ async def execute(
+ self, desc: str, decoder: Literal[None], query: str, *args: Any
+ ) -> List[Tuple[Any, ...]]:
+ ...
+
+ @overload
+ async def execute(
+ self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+ ) -> R:
+ ...
+
async def execute(
self,
desc: str,
@@ -1088,6 +1099,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Any:
+ ...
+
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
+ ...
+
async def simple_select_one_onecol(
self,
table: str,
@@ -1116,6 +1149,30 @@ class DatabasePool(object):
allow_none=allow_none,
)
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[False] = False,
+ ) -> Any:
+ ...
+
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: Iterable[str],
+ allow_none: Literal[True] = True,
+ ) -> Optional[Any]:
+ ...
+
@classmethod
def simple_select_one_onecol_txn(
cls,
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..c73d54fb67 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..99890ffbf3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
import calendar
import logging
import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@@ -291,16 +294,16 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- def count_daily_users(self):
+ async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
- def count_monthly_users(self):
+ async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
@@ -308,7 +311,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -327,15 +330,15 @@ class DataStore(
(count,) = txn.fetchone()
return count
- def count_r30_users(self):
+ async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
- Returns counts globaly for a given user as well as breaking
- by platform
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
@@ -408,7 +411,7 @@ class DataStore(
return results
- return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -418,7 +421,7 @@ class DataStore(
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
- def generate_user_daily_visits(self):
+ async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/ retention analysis
"""
@@ -473,7 +476,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@@ -497,22 +500,28 @@ class DataStore(
desc="get_users",
)
- def get_users_paginate(
- self, start, limit, user_id=None, name=None, guests=True, deactivated=False
- ):
+ async def get_users_paginate(
+ self,
+ start: int,
+ limit: int,
+ user_id: Optional[str] = None,
+ name: Optional[str] = None,
+ guests: bool = True,
+ deactivated: bool = False,
+ ) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
- start (int): start number to begin the query from
- limit (int): number of rows to retrieve
- user_id (string): search for user_id. ignored if name is not None
- name (string): search for local part of user_id or display name
- guests (bool): whether to in include guest users
- deactivated (bool): whether to include deactivated users
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ user_id: search for user_id. ignored if name is not None
+ name: search for local part of user_id or display name
+ guests: whether to in include guest users
+ deactivated: whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]], int
+ A tuple of a list of mappings from user to information and a count of total users.
"""
def get_users_paginate_txn(txn):
@@ -555,7 +564,7 @@ class DataStore(
users = self.db_pool.cursor_to_dict(txn)
return users, count
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 04042a2c98..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
import abc
import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
- def get_account_data_for_user(self, user_id):
+ async def get_account_data_for_user(
+ self, user_id: str
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
- user_id(str): The user to get the account_data for.
+ user_id: The user to get the account_data for.
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A 2-tuple of a dict of global account_data and a dict mapping from
+ room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
- def get_account_data_for_room(self, user_id, room_id):
+ async def get_account_data_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
Returns:
- A deferred dict of the room account_data
+ A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
- def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ async def get_account_data_for_room_and_type(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
- account_data_type (str): The account data type to get.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
+ account_data_type: The account data type to get.
Returns:
- A deferred of the room account_data for that type, or None if
- there isn't any set.
+ The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ async def get_updated_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
- user_id(str): The user to get the account_data for.
- stream_id(int): The point in the stream since which to get updates
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return defer.succeed(({}, {}))
+ return ({}, {})
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id: int):
+ async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+ await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 4e2b2a85ee..d568789124 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
- def _update_client_ips_batch(self):
+ async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..f8fe948122 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU
"""
devices = (
- await self.db_pool.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- def _get_last_device_update_for_remote_user(
+ async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
- ):
+ ) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+ return await self.db_pool.runInteraction(
+ "get_last_device_update_for_remote_user", f
+ )
- def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+ async def mark_as_sent_devices_by_remote(
+ self, destination: str, stream_id: int
+ ) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
@@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+ async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
@@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- def update_remote_device_list_cache_entry(
+ async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
- ):
+ ) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: ID of decivice being updated
content: new data on this device
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(
+ async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 405b5eafa5..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(
+ async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
- ):
+ ) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..cc0b15ae07 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,8 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@@ -23,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -31,19 +34,67 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by get_e2e_device_keys_and_signatures"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a signature of the structure by the ed25519 key)
+ key_json = attr.ib(type=Optional[str])
+
+ # cross-signing sigs
+ signatures = attr.ib(type=Optional[Dict], default=None)
+
+
class EndToEndKeyWorkerStore(SQLBaseStore):
+ async def get_e2e_device_keys_for_federation_query(
+ self, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ now_stream_id = self.get_device_stream_token()
+
+ devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ key_json = device.key_json
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.display_name
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
- async def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -53,13 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
+ results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@@ -67,13 +112,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
- r = db_to_json(device_info.pop("key_json"))
+ r = db_to_json(device_info.key_json)
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
+ if device_info.signatures:
+ for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@@ -82,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv
@trace
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: List[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Fetch a list of device keys, together with their cross-signatures.
+
+ Args:
+ query_list: List of pairs of user_ids and device_ids. Device id can be None
+ to indicate "all devices for this user"
+
+ include_all_devices: whether to return devices without device keys
+
+ include_deleted_devices: whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data.
+ """
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
+ result = await self.db_pool.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_and_signatures_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ log_kv(result)
+ return result
+
+ def _get_e2e_device_keys_and_signatures_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = []
query_params = []
signature_query_clauses = []
@@ -119,7 +197,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -130,13 +208,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, key_json
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@@ -167,13 +246,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
- target_device_signatures = target_device_result.setdefault("signatures", {})
+ target_device_signatures = target_device_result.signatures
+ if target_device_signatures is None:
+ target_device_signatures = target_device_result.signatures = {}
+
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
- log_kv(result)
return result
async def get_e2e_one_time_keys(
@@ -252,10 +333,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
+ async def count_e2e_one_time_keys(
+ self, user_id: str, device_id: str
+ ) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
+ A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@@ -270,7 +353,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@@ -308,7 +391,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
- def _get_bare_e2e_cross_signing_keys_bulk(
+ async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@@ -316,16 +399,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
- user_ids (list[str]): the users whose keys are being requested
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A mapping from user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict, or
+ their user ID will map to None.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -541,9 +623,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@@ -579,12 +668,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
+ async def claim_e2e_one_time_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ """Take a list of one time keys out of the database.
+
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ """
@trace
def _claim_e2e_one_time_keys(txn):
@@ -620,11 +718,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- def delete_e2e_keys_by_device(self, user_id, device_id):
+ async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@@ -647,7 +745,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..4c3c162acf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
- raise StoreError(400, "stream_ordering too old")
+ raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e8834b2162..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
+ self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ ) -> Dict[str, int]:
+ """Get the notification count, the highlight count and the unread message count
+ for a given user in a given room after the given read receipt.
+
+ Note that this function assumes the user to be a current member of the room,
+ since it's either called by the sync handler to handle joined room entries, or by
+ the HTTP pusher to calculate the badge of unread joined rooms.
+
+ Args:
+ room_id: The room to retrieve the counts in.
+ user_id: The user to retrieve the counts for.
+ last_read_event_id: The event associated with the latest read receipt for
+ this user in this room. None if no receipt for this user in this room.
+
+ Returns
+ A dict containing the counts mentioned earlier in this docstring,
+ respectively under the keys "notify_count", "highlight_count" and
+ "unread_count".
+ """
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id
+ self, txn, room_id, user_id, last_read_event_id,
):
- sql = (
- "SELECT stream_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(sql, (room_id, last_read_event_id))
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
+ stream_ordering = None
+
+ if last_read_event_id is not None:
+ stream_ordering = self.get_stream_id_for_event_txn(
+ txn, last_read_event_id, allow_none=True,
+ )
+
+ if stream_ordering is None:
+ # Either last_read_event_id is None, or it's an event we don't have (e.g.
+ # because it's been purged), in which case retrieve the stream ordering for
+ # the latest membership event from this user in this room (which we assume is
+ # a join).
+ event_id = self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ retcol="event_id",
+ )
- stream_ordering = results[0][0]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
sql = (
- "SELECT count(*)"
+ "SELECT"
+ " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+ " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+ " COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
+ " WHERE user_id = ?"
+ " AND room_id = ?"
+ " AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- notify_count = row[0] if row else 0
+
+ (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+ if row:
+ (notif_count, highlight_count, unread_count) = row
txn.execute(
"""
- SELECT notif_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """,
+ SELECT notif_count, unread_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """,
(room_id, user_id, stream_ordering),
)
- rows = txn.fetchall()
- if rows:
- notify_count += rows[0][0]
-
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
-
- txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- highlight_count = row[0] if row else 0
- return {"notify_count": notify_count, "highlight_count": highlight_count}
+ if row:
+ notif_count += row[0]
+ unread_count += row[1]
+
+ return {
+ "notify_count": notif_count,
+ "unread_count": unread_count,
+ "highlight_count": highlight_count,
+ }
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
- def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+ async def get_if_maybe_push_in_range_for_user(
+ self, user_id: str, min_stream_ordering: int
+ ) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
- user_id (str)
- min_stream_ordering (int)
+ user_id
+ min_stream_ordering
Returns:
- Deferred[bool]: True if there may be push to process, False if
- there definitely isn't.
+ True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
- WHERE user_id = ? AND stream_ordering > ?
+ WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
- async def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(
+ self,
+ event_id: str,
+ user_id_actions: Dict[str, List[Union[dict, str]]],
+ count_as_unread: bool,
+ ) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
- event_id (str)
- user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
- user_id to list of push actions, where an action can either be
- a string or dict.
-
- Returns:
- Deferred
+ event_id
+ user_id_actions: A mapping of user_id to list of push actions, where
+ an action can either be a string or dict.
+ count_as_unread: Whether this event should increment unread counts.
"""
-
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
- # can be used to inert into the `event_push_actions_staging` table.
+ # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
+ notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
- 1, # notif column
+ notif, # notif column
is_highlight, # highlight column
+ int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight)
- VALUES (?, ?, ?, ?, ?)
+ (event_id, user_id, actions, notif, highlight, unread)
+ VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
- def find_first_stream_ordering_after_ts(self, ts):
+ async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
- ts (int): timestamp in millis
+ ts: timestamp in millis
Returns:
- Deferred[int]: stream ordering of the first event received on/after
- the timestamp
+ stream ordering of the first event received on/after the timestamp
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
- coalesce(old.notif_count, 0) + upd.notif_count,
+ coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
- SELECT user_id, room_id, count(*) as notif_count,
+ SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
+ AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
- rows = txn.fetchall()
+ # First get the count of unread messages.
+ txn.execute(
+ sql % ("unread_count", "unread"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # We need to merge results from the two requests (the one that retrieves the
+ # unread count and the one that retrieves the notifications count) into a single
+ # object because we might not have the same amount of rows in each of them. To do
+ # this, we use a dict indexed on the user ID and room ID to make it easier to
+ # populate.
+ summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ for row in txn:
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=row[2],
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=0,
+ )
+
+ # Then get the count of notifications.
+ txn.execute(
+ sql % ("notif_count", "notif"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ for row in txn:
+ if (row[0], row[1]) in summaries:
+ summaries[(row[0], row[1])].notif_count = row[2]
+ else:
+ # Because the rules on notifying are different than the rules on marking
+ # a message unread, we might end up with messages that notify but aren't
+ # marked unread, so we might not have a summary for this (user, room)
+ # tuple to complete.
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=0,
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=row[2],
+ )
- logger.info("Rotating notifications, handling %d rows", len(rows))
+ logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
- "user_id": row[0],
- "room_id": row[1],
- "notif_count": row[2],
- "stream_ordering": row[3],
+ "user_id": user_id,
+ "room_id": room_id,
+ "notif_count": summary.notif_count,
+ "unread_count": summary.unread_count,
+ "stream_ordering": summary.stream_ordering,
}
- for row in rows
- if row[4] is None
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is None
],
)
txn.executemany(
"""
- UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ UPDATE event_push_summary
+ SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+ (
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ user_id,
+ room_id,
+ )
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is not None
+ ),
)
txn.execute(
@@ -881,3 +961,15 @@ def _action_has_highlight(actions):
pass
return False
+
+
+@attr.s
+class _EventPushSummary:
+ """Summary of pending event push actions for a given user in a given room.
+ Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+ """
+
+ unread_count = attr.ib(type=int)
+ stream_ordering = attr.ib(type=int)
+ old_user_id = attr.ib(type=str)
+ notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6313b41eef..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -97,6 +97,7 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -108,7 +109,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write
assert (
- hs.config.worker.writers.events == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -800,6 +801,7 @@ class PersistEventsStore:
table="events",
values=[
{
+ "instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@@ -1296,9 +1298,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
+ topological_ordering, notif, highlight, unread
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e6247d682d..17f5997b89 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
- if hs.config.worker.writers.events == hs.get_instance_name():
- # We are the process in charge of generating stream ids for events,
- # so instantiate ID generators based on the database
- self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ if isinstance(database.engine, PostgresEngine):
+ # If we're using Postgres than we can use `MultiWriterIdGenerator`
+ # regardless of whether this process writes to the streams or not.
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_stream_seq",
)
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
)
else:
- # Another process is in charge of persisting events and generating
- # stream IDs: rely on the replication streams to let us know which
- # IDs we can process.
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+ # We shouldn't be running in worker mode with SQLite, but its useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer than we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering"
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
self._get_event_cache = Cache(
"*getEvent*",
@@ -823,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
- def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ def _maybe_redact_event_row(
+ self,
+ original_ev: EventBase,
+ redactions: Iterable[str],
+ event_map: Dict[str, EventBase],
+ ) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
- original_ev (EventBase):
- redactions (iterable[str]): list of event ids of potential redaction events
- event_map (dict[str, EventBase]): other events which have been fetched, in
- which we can look up the redaaction events. Map from event id to event.
+ original_ev: The original event.
+ redactions: list of event ids of potential redaction events
+ event_map: other events which have been fetched, in which we can
+ look up the redaaction events. Map from event id to event.
Returns:
- Deferred[EventBase|None]: if the event should be redacted, a pruned
- event object. Otherwise, None.
+ If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@@ -946,17 +978,17 @@ class EventsWorkerStore(SQLBaseStore):
row = txn.fetchone()
return row[0] if row else 0
- def get_current_state_event_counts(self, room_id):
+ async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- Deferred[int]
+ The current number of state events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@@ -991,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_forward_event_rows(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@@ -999,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1020,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
- def get_ex_outlier_stream_rows(self, last_id, current_id):
+ async def get_ex_outlier_stream_rows(
+ self, last_id: int, current_id: int
+ ) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1054,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
@@ -1226,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_next_event_to_expire(self):
+ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
- Returns: Deferred[Optional[Tuple[str, int]]]
+ Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@@ -1246,6 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- def add_user_filter(self, user_localpart, user_filter):
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db_pool.runInteraction("add_user_filter", _do_txn)
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3919ecad69..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- def get_url_cache(self, url, ts):
+ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+ return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ async def update_cached_last_access_time(
+ self,
+ local_media: Iterable[str],
+ remote_media: Iterable[Tuple[str, str]],
+ time_ms: int,
+ ):
"""Updates the last access time of the given media
Args:
- local_media (iterable[str]): Set of media_ids
- remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ local_media: Set of media_ids
+ remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
- def delete_remote_media(self, media_origin, media_id):
+ async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
- def get_expired_url_cache(self, now_ts):
+ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
- def get_url_cache_media_before(self, before_ts):
+ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index 4db8949da7..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from synapse.storage._base import SQLBaseStore
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
- def get_user_id_for_open_id_token(self, token, ts_now_ms):
+ async def get_user_id_for_open_id_token(
+ self, token: str, ts_now_ms: int
+ ) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 301875a672..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
- def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ async def get_remote_profile_cache_entries_that_expire(
+ self, last_checked: int
+ ) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..ea833829ae 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> Set[int]:
"""Deletes room history before a certain point
Args:
- room_id (str):
-
- token (str): A topological token to delete events before
-
- delete_local_events (bool):
+ room_id:
+ token: A topological token to delete events before
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ The set of state groups that are referenced by deleted events.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
return referenced_state_groups
- def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> List[int]:
"""Deletes all record of a room
Args:
- room_id (str)
+ room_id
Returns:
- Deferred[List[int]]: The list of state groups to delete.
+ The list of state groups to delete.
"""
-
- return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return await self.db_pool.runInteraction(
+ "purge_room", self._purge_room_txn, room_id
+ )
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2fb5b02d7d..0de802a86b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
-from twisted.internet import defer
-
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- def have_push_rules_changed_for_user(self, user_id, last_id):
+ async def have_push_rules_changed_for_user(
+ self, user_id: str, last_id: int
+ ) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
+ return False
else:
def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 436f22ad2d..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+ async def get_users_sent_receipts_between(
+ self, last_id: int, current_id: int
+ ) -> List[str]:
"""Get all users who sent receipts between `last_id` exclusive and
`current_id` inclusive.
Returns:
- Deferred[List[str]]
+ The list of users.
"""
if last_id == current_id:
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
)
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
- def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.db_pool.runInteraction(
+ async def insert_graph_receipt(
+ self, room_id, receipt_type, user_id, event_ids, data
+ ):
+ return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 12689f4308..01f20c03c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -84,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
- def get_user_by_access_token(self, token):
+ async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
- token (str): The access token of a user.
+ token: The access token of a user.
Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@cached()
- async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
+ async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
@@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
- def set_server_admin(self, user, admin):
+ async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
- admin (bool): true iff the user is to be a server admin,
- false otherwise.
+ user: user ID of the user to test
+ admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
- return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
- def get_users_by_id_case_insensitive(self, user_id):
+ async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
- Returns a mapping of user_id -> password_hash.
+
+ Returns:
+ A mapping of user_id -> password_hash.
"""
def f(txn):
@@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+ return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
- def count_daily_user_type(self):
+ async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
- def get_threepid_validation_session(
- self, medium, client_secret, address=None, sid=None, validated=True
- ):
+ async def get_threepid_validation_session(
+ self,
+ medium: Optional[str],
+ client_secret: str,
+ address: Optional[str] = None,
+ sid: Optional[str] = None,
+ validated: Optional[bool] = True,
+ ) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str): A unique string provided by the client to help identify this
+ medium: The medium of the 3PID
+ client_secret: A unique string provided by the client to help identify this
validation attempt
- validated (bool|None): Whether sessions should be filtered by
+ address: The address of the 3PID
+ sid: The ID of the validation session
+ validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
- Deferred[dict|None]: A dict containing the following:
+ A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
- def delete_threepid_session(self, session_id):
+ async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
- session_id (str): The ID of the session to delete
+ session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
- def register_user(
+ async def register_user(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Attempts to register an account.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
- upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
- false to add a regular user account.
- appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode): Optionally create a profile for
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Whether this is a guest account being upgraded to a
+ non-guest account.
+ make_guest: True if the the new user should be guest, false to add a
+ regular user account.
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
- api.constants.UserTypes, or None for a normal user.
- shadow_banned (bool): Whether the user is shadow-banned,
- i.e. they may be told their requests succeeded but we ignore them.
+ admin: is an admin user?
+ user_type: type of user. One of the values from api.constants.UserTypes,
+ or None for a normal user.
+ shadow_banned: Whether the user is shadow-banned, i.e. they may be
+ told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
-
- Returns:
- Deferred
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- def user_set_password_hash(self, user_id, password_hash):
+ async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
- def user_set_consent_version(self, user_id, consent_version):
+ async def user_set_consent_version(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy the user has consented
- to
+ user_id: full mxid of the user to update
+ consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_version", f)
+ await self.db_pool.runInteraction("user_set_consent_version", f)
- def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ async def user_set_consent_server_notice_sent(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy we have notified the
- user about
+ user_id: full mxid of the user to update
+ consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+ await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+ async def user_delete_access_tokens(
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
+ ) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str): list of access_tokens IDs which should
- *not* be deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_tokens ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
- defer.Deferred[list[str, int, str|None, int]]: a list of
- (token, token id, device id) for each of the deleted tokens
+ A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.db_pool.runInteraction("user_delete_access_tokens", f)
+ return await self.db_pool.runInteraction("user_delete_access_tokens", f)
- def delete_access_token(self, access_token):
+ async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.db_pool.runInteraction("delete_access_token", f)
+ await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
- def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+ async def validate_threepid_session(
+ self, session_id: str, client_secret: str, token: str, current_ts: int
+ ) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
- session_id (str): The id of a validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- token (str): A validation token
- current_ts (int): The current unix time in milliseconds. Used for
- checking token expiry status
+ session_id: The id of a validation session
+ client_secret: A unique string provided by the client to help identify
+ this validation attempt
+ token: A validation token
+ current_ts: The current unix time in milliseconds. Used for checking
+ token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
- deferred str|None: A str representing a link to redirect the user
- to if there is one.
+ A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def start_or_continue_validation_session(
+ async def start_or_continue_validation_session(
self,
- medium,
- address,
- session_id,
- client_secret,
- send_attempt,
- next_link,
- token,
- token_expires,
- ):
+ medium: str,
+ address: str,
+ session_id: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str],
+ token: str,
+ token_expires: int,
+ ) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- session_id (str): The id of this validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- next_link (str|None): The link to redirect the user to upon
- successful validation
- token (str): The validation token
- token_expires (int): The timestamp for which after the token
- will no longer be valid
+ medium: The medium of the 3PID
+ address: The address of the 3PID
+ session_id: The id of this validation session
+ client_secret: A unique string provided by the client to help
+ identify this validation attempt
+ send_attempt: The latest send_attempt on this session
+ next_link: The link to redirect the user to upon successful validation
+ token: The validation token
+ token_expires: The timestamp for which after the token will no
+ longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
- def cull_expired_threepid_validation_tokens(self):
+ async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
- return txn.execute(sql, (ts,))
+ txn.execute(sql, (ts,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
- def get_relations_for_event(
+ async def get_relations_for_event(
self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[RelationPaginationToken] = None,
+ to_token: Optional[RelationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
+ event_id: Fetch events that relate to this event ID.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
+ List of event IDs that match relations requested. The rows are of
+ the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@cached(tree=True)
- def get_aggregation_groups_for_event(
+ async def get_aggregation_groups_for_event(
self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ event_type: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[AggregationPaginationToken] = None,
+ to_token: Optional[AggregationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
on an event.
Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
+ event_id: Fetch events that relate to this event ID.
+ event_type: Only fetch events with this event type, if given.
+ limit: Only fetch the `limit` groups.
+ direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
+ List of groups of annotations that match. Each row is a dict with
+ `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ async def has_user_annotated_event(
+ self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+ ) -> bool:
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
+ parent_id: The event being annotated
+ event_type: The event type of the annotation
+ aggregation_key: The aggregation key of the annotation
+ sender: The sender of the annotation
Returns:
- Deferred[bool]
+ True if the event is already annotated.
"""
sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a92641c339..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_room_with_stats(self, room_id: str):
+ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
- def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ async def count_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ ignore_non_federatable: bool,
+ ) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
- network_tuple (ThirdPartyInstanceID|None)
- ignore_non_federatable (bool): If true filters out non-federatable rooms
+ network_tuple
+ ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
- def get_media_mxcs_in_room(self, room_id):
+ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
- room_id (str)
+ room_id
Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
+ The local and remote media as a lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ async def quarantine_media_ids_in_room(
+ self, room_id: str, quarantined_by: str
+ ) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- def quarantine_media_by_id(
+ async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
- ):
+ ) -> int:
"""quarantines a single local or remote media id
Args:
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
- def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ async def quarantine_media_ids_by_user(
+ self, user_id: str, quarantined_by: str
+ ) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_room_count(self):
- """Retrieve a list of all rooms
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
"""
def f(txn):
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db_pool.runInteraction("get_rooms", f)
+ return await self.db_pool.runInteraction("get_rooms", f)
async def add_event_report(
self,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 161edbeccb..c46f5cd524 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id: str):
- return self.db_pool.runInteraction(
+ async def get_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id: str):
+ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
+ dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+ return await self.db_pool.runInteraction(
+ "get_room_summary", _get_room_summary_txn
+ )
@cached()
- def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
- A awaitable list of RoomsForUser.
+ A list of RoomsForUser.
"""
- return self.get_rooms_for_local_user_where_membership_is(
+ return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@@ -297,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: List[str]
- ) -> Optional[List[RoomsForUser]]:
+ self, user_id: str, membership_list: Collection[str]
+ ) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -313,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
- return None
+ return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
@@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+ async def get_rooms_for_user_with_stream_ordering(
+ self, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
+ Returns the rooms the user is in currently, along with the stream
+ ordering of the most recent join for that user and room.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+ def _get_rooms_for_user_with_stream_ordering_txn(
+ self, txn, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
- return results
+ return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id: str):
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
- user_id
+ user_id: The user ID to query the rooms of.
Returns:
- Deferred[set[str]]
+ The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id: str, room_id: str):
+ async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db_pool.runInteraction("forget_membership", f)
+ await self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+ SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+ SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index dcbdeab36e..9c5f0229c1 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- def _find_highlights_in_postgres(self, search_query, events):
+ async def _find_highlights_in_postgres(
+ self, search_query: str, events: List[EventBase]
+ ) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
- search_query (str)
- events (list): A list of events
+ search_query
+ events: A list of events
Returns:
- deferred : A set of strings.
+ A set of strings.
"""
def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db_pool.runInteraction("_find_highlights", f)
+ return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Tuple
+
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
- def get_event_reference_hashes(self, event_ids):
+ async def get_event_reference_hashes(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Dict[str, bytes]]:
+ """Get all hashes for given events.
+
+ Args:
+ event_ids: The event IDs to get hashes for.
+
+ Returns:
+ A mapping of event ID to a mapping of algorithm to hash.
+ """
+
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
- return self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return await self.db_pool.runInteraction("get_event_reference_hashes", f)
- async def add_event_hashes(self, event_ids):
+ async def add_event_hashes(
+ self, event_ids: Iterable[str]
+ ) -> List[Tuple[str, Dict[str, str]]]:
+ """
+
+ Args:
+ event_ids: The event IDs
+
+ Returns:
+ A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+ """
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
- def _get_event_reference_hashes_txn(self, txn, event_id):
+ def _get_event_reference_hashes_txn(
+ self, txn: Cursor, event_id: str
+ ) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
- txn (cursor):
- event_id (str): Id for the Event.
+ txn:
+ event_id: Id for the Event.
Returns:
- A dict[unicode, bytes] of algorithm -> hash.
+ A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9b9bc304a8..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
)
async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
- """
+ """Update the state of a room.
+
+ fields can contain the following keys with string values:
+ * join_rules
+ * history_visibility
+ * encryption
+ * name
+ * topic
+ * avatar
+ * canonical_alias
+
+ A is_federatable key can also be included with a boolean value.
+
Args:
- room_id
- fields
+ room_id: The room ID to update the state of.
+ fields: The fields to update. This can include a partial list of the
+ above fields to only update some room information.
"""
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
+ # Ensure that the values to update are valid, they should be strings and
+ # not contain any null bytes.
+ #
+ # Invalid data gets overwritten with null.
+ #
+ # Note that a missing value should not be overwritten (it keeps the
+ # previous value).
+ sentinel = object()
for col in (
"join_rules",
"history_visibility",
@@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
- field = fields.get(col)
- if field and "\0" in field:
+ field = fields.get(col, sentinel)
+ if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
await self.db_pool.simple_upsert(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -202,7 +209,7 @@ def _make_generic_sql_bound(
)
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
- room_ids: Iterable[str],
+ room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
@@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def get_rooms_that_changed(self, room_ids, from_key):
+ def get_rooms_that_changed(
+ self, room_ids: Collection[str], from_key: str
+ ) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
- room_ids (list)
- from_key (str): The room_key portion of a StreamToken
+ room_ids
+ from_key: The room_key portion of a StreamToken
"""
- from_key = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
- if self._events_stream_cache.has_entity_changed(room_id, from_key)
+ if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
async def get_room_events_stream_for_room(
@@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(
+ self, user_id: str, from_key: str, to_key: str
+ ) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A stream ID.
"""
- return await self.db_pool.simple_select_one_onecol(
- table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+ return await self.db_pool.runInteraction(
+ "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+ )
+
+ def get_stream_id_for_event_txn(
+ self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ ) -> int:
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="stream_ordering",
+ allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
@@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return row[0][0] if row else 0
- def _get_max_topological_txn(self, txn, room_id):
+ def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn(
self,
- txn,
+ txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
@@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
+ # This cannot happen as `allow_none=False`.
+ assert results is not None
+
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn) -> None:
+ def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
- min_positions = dict(txn) # Map from type -> min position
+ min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn(
self,
- txn,
+ txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 0c34bbf21a..96ffe26cc9 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- tags_by_room = {}
+ tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
@@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, List[str]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 9eef8e57c5..b89668d561 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
class UIAuthStore(UIAuthWorkerStore):
- def delete_old_ui_auth_sessions(self, expiration_time: int):
+ async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f2e93614..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ async def update_profile_in_user_dir(
+ self, user_id: str, display_name: str, avatar_url: str
+ ) -> None:
"""
Update or add a user's profile in the user directory.
"""
+ # If the display name or avatar URL are unexpected types, overwrite them.
+ if not isinstance(display_name, str):
+ display_name = None
+ if not isinstance(avatar_url, str):
+ avatar_url = None
def _update_profile_in_user_dir_txn(txn):
new_entry = self.db_pool.simple_upsert_txn(
@@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ async def add_users_who_share_private_room(
+ self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ room_id
+ user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def add_users_in_public_rooms(self, room_id, user_ids):
+ async def add_users_in_public_rooms(
+ self, room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_ids (list[str])
+ room_id
+ user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def delete_all_from_user_dir(self):
+ async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
- def remove_from_user_dir(self, user_id):
+ async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
- def remove_user_who_share_room(self, user_id, room_id):
+ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
- user_id (str)
- room_id (str)
+ user_id
+ room_id
"""
def _remove_user_who_share_room_txn(txn):
@@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
+ @cached()
+ async def get_shared_rooms_for_users(
+ self, user_id: str, other_user_id: str
+ ) -> Set[str]:
+ """
+ Returns the rooms that a local user shares with another local or remote user.
+
+ Args:
+ user_id: The MXID of a local user
+ other_user_id: The MXID of the other user
+
+ Returns:
+ A set of room ID's that the users share.
+ """
+
+ def _get_shared_rooms_for_users_txn(txn):
+ txn.execute(
+ """
+ SELECT p1.room_id
+ FROM users_in_public_rooms as p1
+ INNER JOIN users_in_public_rooms as p2
+ ON p1.room_id = p2.room_id
+ AND p1.user_id = ?
+ AND p2.user_id = ?
+ UNION
+ SELECT room_id
+ FROM users_who_share_private_rooms
+ WHERE
+ user_id = ?
+ AND other_user_id = ?
+ """,
+ (user_id, other_user_id, user_id, other_user_id),
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows
+
+ rows = await self.db_pool.runInteraction(
+ "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ )
+
+ return {row["room_id"] for row in rows}
+
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index e3547e53b3..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id: str) -> None:
+ async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_erased", f)
+ await self.db_pool.runInteraction("mark_user_erased", f)
- def mark_user_not_erased(self, user_id: str) -> None:
+ async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db_pool.runInteraction("mark_user_not_erased", f)
+ await self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b27a4843d0..8fd21c2bf8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ positive: Whether the IDs are positive (true) or negative (false).
+ When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
def __init__(
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
instance_column: str,
id_column: str,
sequence_name: str,
+ positive: bool = True,
):
self._db = db
self._instance_name = instance_name
+ self._positive = positive
+ self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
+ # Note: If we are a negative stream then we still store all the IDs as
+ # positive to make life easier for us, and simply negate the IDs when we
+ # return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
@@ -223,8 +231,12 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
+ #
+ # We start at 1 here as a) the first generated stream ID will be 2, and
+ # b) other parts of the code assume that stream IDs are strictly greater
+ # than 0.
self._persisted_upto_position = (
- min(self._current_positions.values()) if self._current_positions else 0
+ min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
@@ -233,13 +245,16 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
+ # If positive stream aggregate via MAX. For negative stream use MIN
+ # *and* negate the result to get a positive number.
sql = """
- SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
+ "agg": "MAX" if self._positive else "-MIN",
}
cur = db_conn.cursor()
@@ -269,15 +284,16 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token_for_writer(self._instance_name) < next_id
-
with self._lock:
+ assert self._current_positions.get(self._instance_name, 0) < next_id
+
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
- yield next_id
+ # Multiply by the return factor so that the ID has correct sign.
+ yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
@@ -296,15 +312,15 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
- assert max(self.get_positions().values(), default=0) < min(next_ids)
-
with self._lock:
+ assert max(self._current_positions.values(), default=0) < min(next_ids)
+
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
- yield next_ids
+ yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
@@ -327,7 +343,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
- return next_id
+ return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
@@ -350,29 +366,32 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
- # Currently we don't support this operation, as it's not obvious how to
- # condense the stream positions of multiple writers into a single int.
- raise NotImplementedError()
+ return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
with self._lock:
- return self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""
with self._lock:
- return dict(self._current_positions)
+ return {
+ name: self._return_factor * i
+ for name, i in self._current_positions.items()
+ }
def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""
+ new_id *= self._return_factor
+
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +409,7 @@ class MultiWriterIdGenerator:
"""
with self._lock:
- return self._persisted_upto_position
+ return self._return_factor * self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..dfefbd996d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -20,6 +20,7 @@ from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
import attr
+from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -338,11 +339,11 @@ class Linearizer(object):
class ReadWriteLock(object):
- """A deferred style read write lock.
+ """An async read write lock.
Example:
- with (yield read_write_lock.read("test_key")):
+ with await read_write_lock.read("test_key"):
# do some work
"""
@@ -365,8 +366,7 @@ class ReadWriteLock(object):
# Latest writer queued
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
- @defer.inlineCallbacks
- def read(self, key):
+ async def read(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -376,7 +376,8 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
- yield make_deferred_yieldable(curr_writer)
+ if curr_writer:
+ await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@@ -388,8 +389,7 @@ class ReadWriteLock(object):
return _ctx_manager()
- @defer.inlineCallbacks
- def write(self, key):
+ async def write(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
@@ -405,7 +405,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
- yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
+ await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 69945a8f98..eb78ab412a 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# The resolver may retry a few times, so fonx all requests that come along
attempts = 0
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
# Repated the request, this time it should fail if the lookup fails.
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0b5204654c..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
)
)
return event, context
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+ """
+ Tests the UserSharedRoomsServlet.
+ """
+
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ shared_rooms.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["update_user_directory"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def _get_shared_rooms(self, token, other_user):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ % other_user,
+ access_token=token,
+ )
+ self.render(request)
+ return request, channel
+
+ def test_shared_room_list_public(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is public.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_private(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is private.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_mixed(self):
+ """
+ The shared room list between two users should contain both public and private
+ rooms.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+ self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+ self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+ self.helper.join(room_public, user=u2, tok=u2_token)
+ self.helper.join(room_private, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 2)
+ self.assertTrue(room_public in channel.json_body["joined"])
+ self.assertTrue(room_private in channel.json_body["joined"])
+
+ def test_shared_room_list_after_leave(self):
+ """
+ A room should no longer be considered shared if the other
+ user has left it.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Assert user directory is not empty
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ self.helper.leave(room, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u2_token, u1)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: True):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 261bf5b08b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ event.event_id, {user_id: action}, False,
)
)
yield defer.ensureDeferred(
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 14ce21c786..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+ federation works corretly.
+ """
+ id_gen_1 = self._create_id_generator("first")
+ id_gen_2 = self._create_id_generator("second")
+
+ with self.get_success(id_gen_1.get_next()) as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen_2.get_next()) as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8522c6fc09..fb1ca90336 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Collection
"""
Utility functions for poking events into the storage of the server under test.
@@ -58,7 +57,7 @@ async def inject_member_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -80,7 +79,7 @@ async def inject_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
if room_version is None:
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.async_helpers import ReadWriteLock
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
rwlock.read(key), # 5
rwlock.write(key), # 6
]
+ ds = [defer.ensureDeferred(d) for d in ds]
self._assert_called_before_not_after(ds, 2)
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
with ds[6].result:
pass
- d = rwlock.write(key)
+ d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
- d = rwlock.read(key)
+ d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
|