diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 7b96f61d7b..d3b4887f69 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -268,6 +268,9 @@ class MockHomeserver:
def get_instance_name(self) -> str:
return "master"
+ def should_send_federation(self) -> bool:
+ return False
+
class Porter:
def __init__(
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e1d31cabed..2653764119 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -259,3 +259,13 @@ class ReceiptTypes:
READ: Final = "m.read"
READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
FULLY_READ: Final = "m.fully_read"
+
+
+class PublicRoomsFilterFields:
+ """Fields in the search filter for `/publicRooms` that we understand.
+
+ As defined in https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3publicrooms
+ """
+
+ GENERIC_SEARCH_TERM: Final = "generic_search_term"
+ ROOM_TYPES: Final = "org.matrix.msc3827.room_types"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 363ac98ea9..923891ae0d 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -106,7 +106,9 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs)
def start_worker_reactor(
appname: str,
config: HomeServerConfig,
- run_command: Callable[[], None] = reactor.run,
+ # Use a lambda to avoid binding to a given reactor at import time.
+ # (needed when synapse.app.complement_fork_starter is being used)
+ run_command: Callable[[], None] = lambda: reactor.run(),
) -> None:
"""Run the reactor in the main process
@@ -141,7 +143,9 @@ def start_reactor(
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
- run_command: Callable[[], None] = reactor.run,
+ # Use a lambda to avoid binding to a given reactor at import time.
+ # (needed when synapse.app.complement_fork_starter is being used)
+ run_command: Callable[[], None] = lambda: reactor.run(),
) -> None:
"""Run the reactor in the main process
diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py
new file mode 100644
index 0000000000..89eb07df27
--- /dev/null
+++ b/synapse/app/complement_fork_starter.py
@@ -0,0 +1,190 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ## What this script does
+#
+# This script spawns multiple workers, whilst only going through the code loading
+# process once. The net effect is that start-up time for a swarm of workers is
+# reduced, particularly in CPU-constrained environments.
+#
+# Before the workers are spawned, the database is prepared in order to avoid the
+# workers racing.
+#
+# ## Stability
+#
+# This script is only intended for use within the Synapse images for the
+# Complement test suite.
+# There are currently no stability guarantees whatsoever; especially not about:
+# - whether it will continue to exist in future versions;
+# - the format of its command-line arguments; or
+# - any details about its behaviour or principles of operation.
+#
+# ## Usage
+#
+# The first argument should be the path to the database configuration, used to
+# set up the database. The rest of the arguments are used as follows:
+# Each worker is specified as an argument group (each argument group is
+# separated by '--').
+# The first argument in each argument group is the Python module name of the application
+# to start. Further arguments are then passed to that module as-is.
+#
+# ## Example
+#
+# python -m synapse.app.complement_fork_starter path_to_db_config.yaml \
+# synapse.app.homeserver [args..] -- \
+# synapse.app.generic_worker [args..] -- \
+# ...
+# synapse.app.generic_worker [args..]
+#
+import argparse
+import importlib
+import itertools
+import multiprocessing
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet.main import installReactor
+
+
+class ProxiedReactor:
+ """
+ Twisted tracks the 'installed' reactor as a global variable.
+ (Actually, it does some module trickery, but the effect is similar.)
+
+ The default EpollReactor is buggy if it's created before a process is
+ forked, then used in the child.
+ See https://twistedmatrix.com/trac/ticket/4759#comment:17.
+
+ However, importing certain Twisted modules will automatically create and
+ install a reactor if one hasn't already been installed.
+ It's not normally possible to re-install a reactor.
+
+ Given the goal of launching workers with fork() to only import the code once,
+ this presents a conflict.
+ Our work around is to 'install' this ProxiedReactor which prevents Twisted
+ from creating and installing one, but which lets us replace the actual reactor
+ in use later on.
+ """
+
+ def __init__(self) -> None:
+ self.___reactor_target: Any = None
+
+ def _install_real_reactor(self, new_reactor: Any) -> None:
+ """
+ Install a real reactor for this ProxiedReactor to forward lookups onto.
+
+ This method is specific to our ProxiedReactor and should not clash with
+ any names used on an actual Twisted reactor.
+ """
+ self.___reactor_target = new_reactor
+
+ def __getattr__(self, attr_name: str) -> Any:
+ return getattr(self.___reactor_target, attr_name)
+
+
+def _worker_entrypoint(
+ func: Callable[[], None], proxy_reactor: ProxiedReactor, args: List[str]
+) -> None:
+ """
+ Entrypoint for a forked worker process.
+
+ We just need to set up the command-line arguments, create our real reactor
+ and then kick off the worker's main() function.
+ """
+
+ sys.argv = args
+
+ from twisted.internet.epollreactor import EPollReactor
+
+ proxy_reactor._install_real_reactor(EPollReactor())
+ func()
+
+
+def main() -> None:
+ """
+ Entrypoint for the forking launcher.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("db_config", help="Path to database config file")
+ parser.add_argument(
+ "args",
+ nargs="...",
+ help="Argument groups separated by `--`. "
+ "The first argument of each group is a Synapse app name. "
+ "Subsequent arguments are passed through.",
+ )
+ ns = parser.parse_args()
+
+ # Split up the subsequent arguments into each workers' arguments;
+ # `--` is our delimiter of choice.
+ args_by_worker: List[List[str]] = [
+ list(args)
+ for cond, args in itertools.groupby(ns.args, lambda ele: ele != "--")
+ if cond and args
+ ]
+
+ # Prevent Twisted from installing a shared reactor that all the workers will
+ # inherit when we fork(), by installing our own beforehand.
+ proxy_reactor = ProxiedReactor()
+ installReactor(proxy_reactor)
+
+ # Import the entrypoints for all the workers.
+ worker_functions = []
+ for worker_args in args_by_worker:
+ worker_module = importlib.import_module(worker_args[0])
+ worker_functions.append(worker_module.main)
+
+ # We need to prepare the database first as otherwise all the workers will
+ # try to create a schema version table and some will crash out.
+ from synapse._scripts import update_synapse_database
+
+ update_proc = multiprocessing.Process(
+ target=_worker_entrypoint,
+ args=(
+ update_synapse_database.main,
+ proxy_reactor,
+ [
+ "update_synapse_database",
+ "--database-config",
+ ns.db_config,
+ "--run-background-updates",
+ ],
+ ),
+ )
+ print("===== PREPARING DATABASE =====", file=sys.stderr)
+ update_proc.start()
+ update_proc.join()
+ print("===== PREPARED DATABASE =====", file=sys.stderr)
+
+ # At this point, we've imported all the main entrypoints for all the workers.
+ # Now we basically just fork() out to create the workers we need.
+ # Because we're using fork(), all the workers get a clone of this launcher's
+ # memory space and don't need to repeat the work of loading the code!
+ # Instead of using fork() directly, we use the multiprocessing library,
+ # which uses fork() on Unix platforms.
+ processes = []
+ for (func, worker_args) in zip(worker_functions, args_by_worker):
+ process = multiprocessing.Process(
+ target=_worker_entrypoint, args=(func, proxy_reactor, worker_args)
+ )
+ process.start()
+ processes.append(process)
+
+ # Be a good parent and wait for our children to die before exiting.
+ for process in processes:
+ process.join()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 63310c8d07..2db8cfb005 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, Optional
import attr
from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
from ._base import Config, ConfigError
@@ -159,12 +159,7 @@ class CacheConfig(Config):
self.track_memory_usage = cache_config.get("track_memory_usage", False)
if self.track_memory_usage:
- try:
- check_requirements("cache_memory")
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- )
+ check_requirements("cache_memory")
expire_caches = cache_config.get("expire_caches", True)
cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m")
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index c82f3ee7a3..6e11fbdb9a 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -145,7 +145,7 @@ class EmailConfig(Config):
raise ConfigError(
'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". '
- "Please consult the sample config at docs/sample_config.yaml for "
+ "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for "
"details and update your config file."
)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 0a285dba31..ee443cea00 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -87,3 +87,6 @@ class ExperimentalConfig(Config):
# MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
+
+ # MSC3827: Filtering of /publicRooms by room type
+ self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False)
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 49aaca7cf6..a973bb5080 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -15,14 +15,9 @@
from typing import Any
from synapse.types import JsonDict
+from synapse.util.check_dependencies import check_requirements
-from ._base import Config, ConfigError
-
-MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
-
- Install by running:
- pip install synapse[jwt]
- """
+from ._base import Config
class JWTConfig(Config):
@@ -41,13 +36,7 @@ class JWTConfig(Config):
# that the claims exist on the JWT.
self.jwt_issuer = jwt_config.get("issuer")
self.jwt_audiences = jwt_config.get("audiences")
-
- try:
- from authlib.jose import JsonWebToken
-
- JsonWebToken # To stop unused lint.
- except ImportError:
- raise ConfigError(MISSING_AUTHLIB)
+ check_requirements("jwt")
else:
self.jwt_enabled = False
self.jwt_secret = None
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index d636507886..3b42be5b5b 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -18,7 +18,7 @@ from typing import Any, Optional
import attr
from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
from ._base import Config, ConfigError
@@ -57,12 +57,7 @@ class MetricsConfig(Config):
self.sentry_enabled = "sentry" in config
if self.sentry_enabled:
- try:
- check_requirements("sentry")
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- )
+ check_requirements("sentry")
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 98e8cd8b5a..5418a332da 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
-from ..util.check_dependencies import DependencyException, check_requirements
+from ..util.check_dependencies import check_requirements
from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider"
@@ -41,12 +41,7 @@ class OIDCConfig(Config):
if not self.oidc_providers:
return
- try:
- check_requirements("oidc")
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- ) from e
+ check_requirements("oidc")
# check we don't have any duplicate idp_ids now. (The SSO handler will also
# check for duplicates when the REST listeners get registered, but that happens
@@ -146,7 +141,6 @@ OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
"allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
}
-
# the `oidc_providers` list can either be None (as it is in the default config), or
# a list of provider configs, each of which requires an explicit ID and name.
OIDC_PROVIDER_LIST_SCHEMA = {
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index d4090a1f9a..4fc1784efe 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -136,6 +136,11 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.003, "burst_count": 5},
)
+ self.rc_invites_per_issuer = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_issuer", {}),
+ defaults={"per_second": 0.3, "burst_count": 10},
+ )
+
self.rc_third_party_invite = RateLimitConfig(
config.get("rc_third_party_invite", {}),
defaults={
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index aadec1e54e..3c69dd325f 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -21,7 +21,7 @@ import attr
from synapse.config.server import generate_ip_set
from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
@@ -184,13 +184,7 @@ class ContentRepositoryConfig(Config):
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
- try:
- check_requirements("url_preview")
-
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- )
+ check_requirements("url_preview")
proxy_env = getproxies_environment()
if "url_preview_ip_range_blacklist" not in config:
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index bd7c234d31..49ca663dde 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -18,7 +18,7 @@ from typing import Any, List, Set
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
@@ -76,12 +76,7 @@ class SAML2Config(Config):
if not saml2_config.get("sp_config") and not saml2_config.get("config_path"):
return
- try:
- check_requirements("saml2")
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- )
+ check_requirements("saml2")
self.saml2_enabled = True
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 6fbf927f11..c19270c6c5 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -15,7 +15,7 @@
from typing import Any, List, Set
from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
from ._base import Config, ConfigError
@@ -40,12 +40,7 @@ class TracerConfig(Config):
if not self.opentracer_enabled:
return
- try:
- check_requirements("opentracing")
- except DependencyException as e:
- raise ConfigError(
- e.message # noqa: B306, DependencyException.message is a property
- )
+ check_requirements("opentracing")
# The tracer is enabled so sanitize the config
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b7c54e642f..479d936dc0 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1092,20 +1092,19 @@ class FederationEventHandler:
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier
+ context = await self._state_handler.compute_event_context(
+ event,
+ state_ids_before_event=state_ids,
+ )
try:
- context = await self._state_handler.compute_event_context(
- event,
- state_ids_before_event=state_ids,
- )
context = await self._check_event_auth(
origin,
event,
context,
)
except AuthError as e:
- # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it
- # for now
- logger.exception("Unexpected AuthError from _check_event_auth")
+ # This happens only if we couldn't find the auth events. We'll already have
+ # logged a warning, so now we just convert to a FederationError.
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
if not backfilled and not context.rejected:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 189f52fe5a..c6b40a5b7a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -903,6 +903,9 @@ class EventCreationHandler:
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
+ if ratelimit:
+ await self.request_ratelimiter.ratelimit(requester, update=False)
+
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
# a situation where event persistence can't keep up, causing
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 183d4ae3c4..29868eb743 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -25,6 +25,7 @@ from synapse.api.constants import (
GuestAccess,
HistoryVisibility,
JoinRules,
+ PublicRoomsFilterFields,
)
from synapse.api.errors import (
Codes,
@@ -181,6 +182,7 @@ class RoomListHandler:
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
+ "org.matrix.msc3827.room_type": room["room_type"],
}
# Filter out Nones – rather omit the field altogether
@@ -239,7 +241,9 @@ class RoomListHandler:
response["chunk"] = results
response["total_room_count_estimate"] = await self.store.count_public_rooms(
- network_tuple, ignore_non_federatable=from_federation
+ network_tuple,
+ ignore_non_federatable=from_federation,
+ search_filter=search_filter,
)
return response
@@ -508,8 +512,21 @@ class RoomListNextBatch:
def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
- if search_filter and search_filter.get("generic_search_term", None):
- generic_search_term = search_filter["generic_search_term"].upper()
+ """Determines whether the given search filter matches a room entry returned over
+ federation.
+
+ Only used if the remote server does not support MSC2197 remote-filtered search, and
+ hence does not support MSC3827 filtering of `/publicRooms` by room type either.
+
+ In this case, we cannot apply the `room_type` filter since no `room_type` field is
+ returned.
+ """
+ if search_filter and search_filter.get(
+ PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+ ):
+ generic_search_term = search_filter[
+ PublicRoomsFilterFields.GENERIC_SEARCH_TERM
+ ].upper()
if generic_search_term in room_entry.get("name", "").upper():
return True
elif generic_search_term in room_entry.get("topic", "").upper():
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bf6bae1232..5648ab4bf4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -101,19 +101,33 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ # Ratelimiter for invites, keyed by room (across all issuers, all
+ # recipients).
self._invites_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
)
- self._invites_per_user_limiter = Ratelimiter(
+
+ # Ratelimiter for invites, keyed by recipient (across all rooms, all
+ # issuers).
+ self._invites_per_recipient_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
+ # Ratelimiter for invites, keyed by issuer (across all rooms, all
+ # recipients).
+ self._invites_per_issuer_limiter = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count,
+ )
+
self._third_party_invite_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
@@ -258,7 +272,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if room_id:
await self._invites_per_room_limiter.ratelimit(requester, room_id)
- await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
+ await self._invites_per_recipient_limiter.ratelimit(requester, invitee_user_id)
+ if requester is not None:
+ await self._invites_per_issuer_limiter.ratelimit(requester)
async def _local_membership_update(
self,
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index f45e06eb0e..5c01482acf 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -271,6 +271,9 @@ class StatsHandler:
room_state["is_federatable"] = (
event_content.get(EventContentFields.FEDERATE, True) is True
)
+ room_type = event_content.get(EventContentFields.ROOM_TYPE)
+ if isinstance(room_type, str):
+ room_state["room_type"] = room_type
elif typ == EventTypes.JoinRules:
room_state["join_rules"] = event_content.get("join_rule")
elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 903ec40c86..50c57940f9 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,6 +164,7 @@ Gotchas
with an active span?
"""
import contextlib
+import enum
import inspect
import logging
import re
@@ -268,7 +269,7 @@ try:
_reporter: Reporter = attr.Factory(Reporter)
- def set_process(self, *args, **kwargs):
+ def set_process(self, *args: Any, **kwargs: Any) -> None:
return self._reporter.set_process(*args, **kwargs)
def report_span(self, span: "opentracing.Span") -> None:
@@ -319,7 +320,11 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
# Util methods
-Sentinel = object()
+
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
P = ParamSpec("P")
@@ -339,12 +344,12 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
return _only_if_tracing_inner
-def ensure_active_span(message, ret=None):
+def ensure_active_span(message: str, ret=None):
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
Args:
- message (str): Message which fills in "There was no active span when trying to %s"
+ message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
ret (object): return value if opentracing is None or there is no active span.
@@ -402,7 +407,7 @@ def init_tracer(hs: "HomeServer") -> None:
config = JaegerConfig(
config=hs.config.tracing.jaeger_config,
service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
- scope_manager=LogContextScopeManager(hs.config),
+ scope_manager=LogContextScopeManager(),
metrics_factory=PrometheusMetricsFactory(),
)
@@ -451,15 +456,15 @@ def whitelisted_homeserver(destination: str) -> bool:
# Could use kwargs but I want these to be explicit
def start_active_span(
- operation_name,
- child_of=None,
- references=None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
+ operation_name: str,
+ child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
+ references: Optional[List["opentracing.Reference"]] = None,
+ tags: Optional[Dict[str, str]] = None,
+ start_time: Optional[float] = None,
+ ignore_active_span: bool = False,
+ finish_on_close: bool = True,
*,
- tracer=None,
+ tracer: Optional["opentracing.Tracer"] = None,
):
"""Starts an active opentracing span.
@@ -493,11 +498,11 @@ def start_active_span(
def start_active_span_follows_from(
operation_name: str,
contexts: Collection,
- child_of=None,
+ child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
start_time: Optional[float] = None,
*,
- inherit_force_tracing=False,
- tracer=None,
+ inherit_force_tracing: bool = False,
+ tracer: Optional["opentracing.Tracer"] = None,
):
"""Starts an active opentracing span, with additional references to previous spans
@@ -540,7 +545,7 @@ def start_active_span_from_edu(
edu_content: Dict[str, Any],
operation_name: str,
references: Optional[List["opentracing.Reference"]] = None,
- tags: Optional[Dict] = None,
+ tags: Optional[Dict[str, str]] = None,
start_time: Optional[float] = None,
ignore_active_span: bool = False,
finish_on_close: bool = True,
@@ -617,23 +622,27 @@ def set_operation_name(operation_name: str) -> None:
@only_if_tracing
-def force_tracing(span=Sentinel) -> None:
+def force_tracing(
+ span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
+) -> None:
"""Force sampling for the active/given span and its children.
Args:
span: span to force tracing for. By default, the active span.
"""
- if span is Sentinel:
- span = opentracing.tracer.active_span
- if span is None:
+ if isinstance(span, _Sentinel):
+ span_to_trace = opentracing.tracer.active_span
+ else:
+ span_to_trace = span
+ if span_to_trace is None:
logger.error("No active span in force_tracing")
return
- span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+ span_to_trace.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
# also set a bit of baggage, so that we have a way of figuring out if
# it is enabled later
- span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
+ span_to_trace.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
def is_context_forced_tracing(
@@ -789,7 +798,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
-def trace(func=None, opname=None):
+def trace(func=None, opname: Optional[str] = None):
"""
Decorator to trace a function.
Sets the operation name to that of the function's or that given
@@ -822,11 +831,11 @@ def trace(func=None, opname=None):
result = func(*args, **kwargs)
if isinstance(result, defer.Deferred):
- def call_back(result):
+ def call_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
- def err_back(result):
+ def err_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index a26a1a58e7..10877bdfc5 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -16,11 +16,15 @@ import logging
from types import TracebackType
from typing import Optional, Type
-from opentracing import Scope, ScopeManager
+from opentracing import Scope, ScopeManager, Span
import twisted
-from synapse.logging.context import current_context, nested_logging_context
+from synapse.logging.context import (
+ LoggingContext,
+ current_context,
+ nested_logging_context,
+)
logger = logging.getLogger(__name__)
@@ -35,11 +39,11 @@ class LogContextScopeManager(ScopeManager):
but currently that doesn't work due to https://twistedmatrix.com/trac/ticket/10301.
"""
- def __init__(self, config):
+ def __init__(self) -> None:
pass
@property
- def active(self):
+ def active(self) -> Optional[Scope]:
"""
Returns the currently active Scope which can be used to access the
currently active Scope.span.
@@ -48,19 +52,18 @@ class LogContextScopeManager(ScopeManager):
Tracer.start_active_span() time.
Return:
- (Scope) : the Scope that is active, or None if not
- available.
+ The Scope that is active, or None if not available.
"""
ctx = current_context()
return ctx.scope
- def activate(self, span, finish_on_close):
+ def activate(self, span: Span, finish_on_close: bool) -> Scope:
"""
Makes a Span active.
Args
- span (Span): the span that should become active.
- finish_on_close (Boolean): whether Span should be automatically
- finished when Scope.close() is called.
+ span: the span that should become active.
+ finish_on_close: whether Span should be automatically finished when
+ Scope.close() is called.
Returns:
Scope to control the end of the active period for
@@ -112,8 +115,8 @@ class _LogContextScope(Scope):
def __init__(
self,
manager: LogContextScopeManager,
- span,
- logcontext,
+ span: Span,
+ logcontext: LoggingContext,
enter_logcontext: bool,
finish_on_close: bool,
):
@@ -121,13 +124,13 @@ class _LogContextScope(Scope):
Args:
manager:
the manager that is responsible for this scope.
- span (Span):
+ span:
the opentracing span which this scope represents the local
lifetime for.
- logcontext (LogContext):
- the logcontext to which this scope is attached.
+ logcontext:
+ the log context to which this scope is attached.
enter_logcontext:
- if True the logcontext will be exited when the scope is finished
+ if True the log context will be exited when the scope is finished
finish_on_close:
if True finish the span when the scope is closed
"""
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index bfe985939b..f13970b898 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -15,11 +15,11 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomID
from ._base import client_patterns
@@ -104,6 +104,13 @@ class RoomAccountDataServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(
+ 400,
+ f"{room_id} is not a valid room ID",
+ Codes.INVALID_PARAM,
+ )
+
body = parse_json_object_from_request(request)
if account_data_type == "m.fully_read":
@@ -111,6 +118,7 @@ class RoomAccountDataServlet(RestServlet):
405,
"Cannot set m.fully_read through this API."
" Use /rooms/!roomId:server.name/read_markers",
+ Codes.BAD_JSON,
)
await self.handler.add_account_data_to_room(
@@ -130,6 +138,13 @@ class RoomAccountDataServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(
+ 400,
+ f"{room_id} is not a valid room ID",
+ Codes.INVALID_PARAM,
+ )
+
event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c1bd775fec..f4f06563dd 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -95,6 +95,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+ # Supports filtering of /publicRooms by room type MSC3827
+ "org.matrix.msc3827": self.config.experimental.msc3827_enabled,
# Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9d3fe66100..d5cbdb3eef 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -249,8 +249,12 @@ class StateHandler:
partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
+ # we've already taken into account partial state, so no need to wait for
+ # complete state here.
entry = await self.resolve_state_groups_for_events(
- event.room_id, event.prev_event_ids()
+ event.room_id,
+ event.prev_event_ids(),
+ await_full_state=False,
)
state_ids_before_event = entry.state
@@ -335,7 +339,7 @@ class StateHandler:
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: Collection[str]
+ self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -343,6 +347,8 @@ class StateHandler:
Args:
room_id
event_ids
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
Returns:
The resolved state
@@ -350,7 +356,7 @@ class StateHandler:
logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self._state_storage_controller.get_state_group_for_events(
- event_ids
+ event_ids, await_full_state=await_full_state
)
state_group_ids = state_groups.values()
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e8c63cf567..e21ab08515 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -366,10 +366,11 @@ class LoggingTransaction:
*args: P.args,
**kwargs: P.kwargs,
) -> R:
- sql = self._make_sql_one_line(sql)
+ # Generate a one-line version of the SQL to better log it.
+ one_line_sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, one_line_sql)
sql = self.database_engine.convert_param_style(sql)
if args:
@@ -389,7 +390,7 @@ class LoggingTransaction:
"db.query",
tags={
opentracing.tags.DATABASE_TYPE: "sql",
- opentracing.tags.DATABASE_STATEMENT: sql,
+ opentracing.tags.DATABASE_STATEMENT: one_line_sql,
},
):
return func(sql, *args, **kwargs)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 57aaf778ec..a3d31d3737 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -87,7 +87,6 @@ class DataStore(
RoomStore,
RoomBatchStore,
RegistrationStore,
- StreamWorkerStore,
ProfileStore,
PresenceStore,
TransactionWorkerStore,
@@ -112,6 +111,7 @@ class DataStore(
SearchStore,
TagsStore,
AccountDataStore,
+ StreamWorkerStore,
OpenIdStore,
ClientIpWorkerStore,
DeviceStore,
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 505616e210..bb6e104d71 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -25,8 +25,8 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -122,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
return DEFAULT_NOTIF_ACTION
-class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
@@ -218,7 +218,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
retcol="event_id",
)
- stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
@@ -307,12 +307,22 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBas
actions that have been deleted from `event_push_actions` table.
"""
+ # If there have been no events in the room since the stream ordering,
+ # there can't be any push actions either.
+ if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+ return 0, 0
+
clause = ""
args = [user_id, room_id, stream_ordering]
if max_stream_ordering is not None:
clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering)
+ # If the max stream ordering is less than the min stream ordering,
+ # then obviously there are zero push actions in that range.
+ if max_stream_ordering <= stream_ordering:
+ return 0, 0
+
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5760d3428e..d8026e3fac 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
import attr
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ JoinRules,
+ PublicRoomsFilterFields,
+)
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
desc="get_public_room_ids",
)
+ def _construct_room_type_where_clause(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[Union[str, None], List[str]]:
+ if not room_types or not self.config.experimental.msc3827_enabled:
+ return None, []
+ else:
+ # We use None when we want get rooms without a type
+ is_null_clause = ""
+ if None in room_types:
+ is_null_clause = "OR room_type IS NULL"
+ room_types = [value for value in room_types if value is not None]
+
+ list_clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_type", room_types
+ )
+
+ return f"({list_clause} {is_null_clause})", args
+
async def count_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
ignore_non_federatable: bool,
+ search_filter: Optional[dict],
) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Args:
network_tuple
ignore_non_federatable: If true filters out non-federatable rooms
+ search_filter
"""
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
+ {room_type_clause}
AND joined_members > 0
"""
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if ignore_non_federatable:
where_clauses.append("is_federatable")
- if search_filter and search_filter.get("generic_search_term", None):
- search_term = "%" + search_filter["generic_search_term"] + "%"
+ if search_filter and search_filter.get(
+ PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+ ):
+ search_term = (
+ "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+ )
where_clauses.append(
"""
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
search_term.lower(),
]
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ if room_type_clause:
+ where_clauses.append(room_type_clause)
+ query_args += args
+
where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
sql = f"""
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, guest_access, join_rules
+ avatar, history_visibility, guest_access, join_rules, room_type
FROM (
{published_sql}
) published
@@ -1166,6 +1213,7 @@ class _BackgroundUpdates:
POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+ ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column,
)
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ self._background_add_room_type_column,
+ )
+
# BG updates to change the type of room_depth.min_depth
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
+ async def _background_add_room_type_column(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update to go and add room_type information to `room_stats_state`
+ table from `event_json` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_room_type_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+ sql = """
+ SELECT state.room_id, json FROM event_json
+ INNER JOIN current_state_events AS state USING (event_id)
+ WHERE state.room_id > ? AND type = 'm.room.create'
+ ORDER BY state.room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_id_to_create_event_results = txn.fetchall()
+
+ new_last_room_id = None
+ for room_id, event_json in room_id_to_create_event_results:
+ event_dict = db_to_json(event_json)
+
+ room_type = event_dict.get("content", {}).get(
+ EventContentFields.ROOM_TYPE, None
+ )
+ if isinstance(room_type, str):
+ self.db_pool.simple_update_txn(
+ txn,
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_type": room_type},
+ )
+
+ new_last_room_id = room_id
+
+ if new_last_room_id is None:
+ return True
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ {"room_id": new_last_room_id},
+ )
+
+ return False
+
+ end = await self.db_pool.runInteraction(
+ "_background_add_room_type_column",
+ _background_add_room_type_column_txn,
+ )
+
+ if end:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+ )
+
+ return batch_size
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def __init__(
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82851ffa95..b4c652acf3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing_extensions import Counter
@@ -238,6 +238,7 @@ class StatsStore(StateDeltasStore):
* avatar
* canonical_alias
* guest_access
+ * room_type
A is_federatable key can also be included with a boolean value.
@@ -263,6 +264,7 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
"guest_access",
+ "room_type",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
@@ -572,7 +574,7 @@ class StatsStore(StateDeltasStore):
state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
- room_state = {
+ room_state: Dict[str, Union[None, bool, str]] = {
"join_rules": None,
"history_visibility": None,
"encryption": None,
@@ -581,6 +583,7 @@ class StatsStore(StateDeltasStore):
"avatar": None,
"canonical_alias": None,
"is_federatable": True,
+ "room_type": None,
}
for event in state_event_map.values():
@@ -604,6 +607,9 @@ class StatsStore(StateDeltasStore):
room_state["is_federatable"] = (
event.content.get(EventContentFields.FEDERATE, True) is True
)
+ room_type = event.content.get(EventContentFields.ROOM_TYPE)
+ if isinstance(room_type, str):
+ room_state["room_type"] = room_type
await self.update_room_state(room_id, room_state)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..3a1df7776c 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -46,10 +46,12 @@ from typing import (
Set,
Tuple,
cast,
+ overload,
)
import attr
from frozendict import frozendict
+from typing_extensions import Literal
from twisted.internet import defer
@@ -795,6 +797,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return RoomStreamToken(topo, stream_ordering)
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: Literal[False] = False,
+ ) -> int:
+ ...
+
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ ...
+
def get_stream_id_for_event_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
new file mode 100644
index 0000000000..d5e0765471
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_state ADD room_type TEXT;
+
+INSERT INTO background_updates (update_name, progress_json)
+ VALUES ('add_room_type_column', '{}');
|