diff --git a/changelog.d/14978.feature b/changelog.d/14978.feature
new file mode 100644
index 0000000000..14f6fee658
--- /dev/null
+++ b/changelog.d/14978.feature
@@ -0,0 +1 @@
+Add the ability to enable/disable registrations when in the OIDC flow.
\ No newline at end of file
diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature
index 68b289b0cc..5ce0c029ce 100644
--- a/changelog.d/15314.feature
+++ b/changelog.d/15314.feature
@@ -1 +1 @@
-Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
+Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).
diff --git a/changelog.d/15321.feature b/changelog.d/15321.feature
new file mode 100644
index 0000000000..5ce0c029ce
--- /dev/null
+++ b/changelog.d/15321.feature
@@ -0,0 +1 @@
+Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).
diff --git a/changelog.d/15334.misc b/changelog.d/15334.misc
new file mode 100644
index 0000000000..0c30818ed0
--- /dev/null
+++ b/changelog.d/15334.misc
@@ -0,0 +1 @@
+Speed up unit tests when using SQLite3.
diff --git a/changelog.d/15349.bugfix b/changelog.d/15349.bugfix
new file mode 100644
index 0000000000..65ea7ae7eb
--- /dev/null
+++ b/changelog.d/15349.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where some to_device messages could be dropped when using workers.
diff --git a/changelog.d/15350.misc b/changelog.d/15350.misc
new file mode 100644
index 0000000000..2dea23784f
--- /dev/null
+++ b/changelog.d/15350.misc
@@ -0,0 +1 @@
+Make the `thread_id` column on `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` non-null.
diff --git a/changelog.d/15351.bugfix b/changelog.d/15351.bugfix
new file mode 100644
index 0000000000..e68023c671
--- /dev/null
+++ b/changelog.d/15351.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.70.0 where the background sync from a faster join could spin for hours when one of the events involved had been marked for backoff.
diff --git a/changelog.d/15352.bugfix b/changelog.d/15352.bugfix
new file mode 100644
index 0000000000..36d6615cac
--- /dev/null
+++ b/changelog.d/15352.bugfix
@@ -0,0 +1 @@
+Fix missing app variable in mail subject for password resets. Contributed by Cyberes.
diff --git a/changelog.d/15354.misc b/changelog.d/15354.misc
new file mode 100644
index 0000000000..862444edfb
--- /dev/null
+++ b/changelog.d/15354.misc
@@ -0,0 +1 @@
+Add some clarification to the doc/comments regarding TCP replication.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index 15df949deb..083cda8413 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -25,7 +25,7 @@ position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where
the format of `<row>` is defined by the individual streams. The
`<instance_name>` is the name of the Synapse process that generated the data
-(usually "master").
+(usually "master"). We expect an RDATA for every row in the DB.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -107,7 +107,7 @@ reconnect, following the steps above.
If the server sends messages faster than the client can consume them the
server will first buffer a (fairly large) number of commands and then
disconnect the client. This ensures that we don't queue up an unbounded
-number of commands in memory and gives us a potential oppurtunity to
+number of commands in memory and gives us a potential opportunity to
squawk loudly. When/if the client recovers it can reconnect to the
server and ask for missed messages.
@@ -122,7 +122,7 @@ since these include tokens which can be used to restart the stream on
connection errors.
The client should keep track of the token in the last RDATA command
-received for each stream so that on reconneciton it can start streaming
+received for each stream so that on reconnection it can start streaming
from the correct place. Note: not all RDATA have valid tokens due to
batching. See `RdataCommand` for more details.
@@ -188,7 +188,8 @@ client (C):
Two positions are included, the "new" position and the last position sent respectively.
This allows servers to tell instances that the positions have advanced but no
data has been written, without clients needlessly checking to see if they
- have missed any updates.
+ have missed any updates. Instances will only fetch stuff if there is a gap between
+ their current position and the given last position.
#### ERROR (S, C)
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 060d0d5e69..c5c2c2b615 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3100,6 +3100,11 @@ Options for each entry include:
match a pre-existing account instead of failing. This could be used if
switching from password logins to OIDC. Defaults to false.
+* `enable_registration`: set to 'false' to disable automatic registration of new
+ users. This allows the OIDC SSO flow to be limited to sign in only, rather than
+ automatically registering users that have a valid SSO login but do not have
+ a pre-registered account. Defaults to true.
+
* `user_mapping_provider`: Configuration for how attributes returned from a OIDC
provider are mapped onto a matrix user. This setting has the following
sub-properties:
@@ -3216,6 +3221,7 @@ oidc_providers:
userinfo_endpoint: "https://accounts.example.com/userinfo"
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
skip_verification: true
+ enable_registration: true
user_mapping_provider:
config:
subject_claim: "id"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 8c6822f3c6..f2d6f9ab2d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -27,7 +27,7 @@ from synapse.util import json_decoder
if typing.TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
- from synapse.types import JsonDict
+ from synapse.types import JsonDict, StrCollection
logger = logging.getLogger(__name__)
@@ -682,18 +682,27 @@ class FederationPullAttemptBackoffError(RuntimeError):
Attributes:
event_id: The event_id which we are refusing to pull
message: A custom error message that gives more context
+ retry_after_ms: The remaining backoff interval, in milliseconds
"""
- def __init__(self, event_ids: List[str], message: Optional[str]):
- self.event_ids = event_ids
+ def __init__(
+ self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int
+ ):
+ event_ids = list(event_ids)
if message:
error_message = message
else:
- error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)."
+ error_message = (
+ f"Not attempting to pull event_ids={event_ids} because we already "
+ "tried to pull them recently (backing off)."
+ )
super().__init__(error_message)
+ self.event_ids = event_ids
+ self.retry_after_ms = retry_after_ms
+
class HttpResponseException(CodeMessageException):
"""
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 51ee0e79df..b27eedef99 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -30,7 +30,7 @@ from prometheus_client import Counter
from typing_extensions import TypeGuard
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
-from synapse.api.errors import CodeMessageException
+from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeysCount,
@@ -38,7 +38,7 @@ from synapse.appservice import (
)
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
-from synapse.http.client import SimpleHttpClient
+from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
@@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient):
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from an application service.
+ Note that any error (including a timeout) is treated as the application
+ service having no information.
+
Args:
+ service: The application service to query.
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
@@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
- except CodeMessageException as e:
+ except HttpResponseException as e:
# The appservice doesn't support this endpoint.
- if e.code == 404 or e.code == 405:
+ if is_unknown_endpoint(e):
return {}, query
logger.warning("claim_keys to %s received %s", uri, e.code)
return {}, query
@@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient):
return response, missing
+ async def query_keys(
+ self, service: "ApplicationService", query: Dict[str, List[str]]
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Query the application service for keys.
+
+ Note that any error (including a timeout) is treated as the application
+ service having no information.
+
+ Args:
+ service: The application service to query.
+ query: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
+
+ device_keys is a map of user ID -> a map device ID -> device info.
+ """
+ if service.url is None:
+ return {}
+
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
+ uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
+ try:
+ response = await self.post_json_get_json(
+ uri,
+ query,
+ headers={"Authorization": [f"Bearer {service.hs_token}"]},
+ )
+ except HttpResponseException as e:
+ # The appservice doesn't support this endpoint.
+ if is_unknown_endpoint(e):
+ return {}
+ logger.warning("query_keys to %s received %s", uri, e.code)
+ return {}
+ except Exception as ex:
+ logger.warning("query_keys to %s threw exception %s", uri, ex)
+ return {}
+
+ return response
+
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53e6fc2b54..7687c80ea0 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -79,6 +79,11 @@ class ExperimentalConfig(Config):
"msc3983_appservice_otk_claims", False
)
+ # MSC3984: Proxying key queries to exclusive ASes.
+ self.msc3984_appservice_key_query: bool = experimental.get(
+ "msc3984_appservice_key_query", False
+ )
+
# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index df8c422043..77c1d1dc8e 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -136,6 +136,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "array",
"items": SsoAttributeRequirement.JSON_SCHEMA,
},
+ "enable_registration": {"type": "boolean"},
},
}
@@ -306,6 +307,7 @@ def _parse_oidc_config_dict(
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
attribute_requirements=attribute_requirements,
+ enable_registration=oidc_config.get("enable_registration", True),
)
@@ -405,3 +407,6 @@ class OidcProviderConfig:
# required attributes to require in userinfo to allow login/registration
attribute_requirements: List[SsoAttributeRequirement]
+
+ # Whether automatic registrations are enabled in the ODIC flow. Defaults to True
+ enable_registration: bool
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7d04560dca..4cf4957a42 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
event_from_pdu_json,
)
from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
@@ -759,43 +760,6 @@ class FederationClient(FederationBase):
return signed_auth
- def _is_unknown_endpoint(
- self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
- ) -> bool:
- """
- Returns true if the response was due to an endpoint being unimplemented.
-
- Args:
- e: The error response received from the remote server.
- synapse_error: The above error converted to a SynapseError. This is
- automatically generated if not provided.
-
- """
- if synapse_error is None:
- synapse_error = e.to_synapse_error()
- # MSC3743 specifies that servers should return a 404 or 405 with an errcode
- # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
- # to an unknown method, respectively.
- #
- # Older versions of servers don't properly handle this. This needs to be
- # rather specific as some endpoints truly do return 404 errors.
- return (
- # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
- (e.code == 404 or e.code == 405)
- and (
- # Older Dendrites returned a text or empty body.
- # Older Conduit returned an empty body.
- not e.response
- or e.response == b"404 page not found"
- # The proper response JSON with M_UNRECOGNIZED errcode.
- or synapse_error.errcode == Codes.UNRECOGNIZED
- )
- ) or (
- # Older Synapses returned a 400 error.
- e.code == 400
- and synapse_error.errcode == Codes.UNRECOGNIZED
- )
-
async def _try_destination_list(
self,
description: str,
@@ -887,7 +851,7 @@ class FederationClient(FederationBase):
elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
failover = True
- elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
+ elif failover_on_unknown_endpoint and is_unknown_endpoint(
e, synapse_error
):
failover = True
@@ -1223,7 +1187,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
- if not self._is_unknown_endpoint(e):
+ if not is_unknown_endpoint(e):
raise
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
@@ -1297,7 +1261,7 @@ class FederationClient(FederationBase):
# fallback to the v1 endpoint if the room uses old-style event IDs.
# Otherwise, consider it a legitimate error and raise.
err = e.to_synapse_error()
- if self._is_unknown_endpoint(e, err):
+ if is_unknown_endpoint(e, err):
if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
raise SynapseError(
400,
@@ -1358,7 +1322,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
- if not self._is_unknown_endpoint(e):
+ if not is_unknown_endpoint(e):
raise
logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
@@ -1629,7 +1593,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the unstable endpoint. Otherwise, consider it a
# legitimate error and raise.
- if not self._is_unknown_endpoint(e):
+ if not is_unknown_endpoint(e):
raise
logger.debug(
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 953df4d9cd..da887647d4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,6 +18,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -846,6 +847,10 @@ class ApplicationServicesHandler:
]:
"""Claim one time keys from application services.
+ Users which are exclusively owned by an application service are sent a
+ key claim request to check if the application service provides keys
+ directly.
+
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
@@ -901,3 +906,59 @@ class ApplicationServicesHandler:
missing.extend(result[1])
return claimed_keys, missing
+
+ async def query_keys(
+ self, query: Mapping[str, Optional[List[str]]]
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Query application services for device keys.
+
+ Users which are exclusively owned by an application service are queried
+ for keys to check if the application service provides keys directly.
+
+ Args:
+ query: map from user_id to a list of devices to query
+
+ Returns:
+ A map from user_id -> device_id -> device details
+ """
+ services = self.store.get_app_services()
+
+ # Partition the users by appservice.
+ query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
+ for user_id, device_ids in query.items():
+ if not self.store.get_if_app_services_interested_in_user(user_id):
+ continue
+
+ # Find the associated appservice.
+ for service in services:
+ if service.is_exclusive_user(user_id):
+ query_by_appservice.setdefault(service.id, {})[user_id] = (
+ device_ids or []
+ )
+ continue
+
+ # Query each service in parallel.
+ results = await make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.query_keys,
+ # We know this must be an app service.
+ self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
+ service_query,
+ )
+ for service_id, service_query in query_by_appservice.items()
+ ],
+ consumeErrors=True,
+ )
+ )
+
+ # Patch together the results -- they are all independent (since they
+ # require exclusive control over the users). They get returned as a single
+ # dictionary.
+ key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ for success, result in results:
+ if success:
+ key_queries.update(result)
+
+ return key_queries
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9e7c2c45b5..0073667470 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -91,6 +91,9 @@ class E2eKeysHandler:
self._query_appservices_for_otks = (
hs.config.experimental.msc3983_appservice_otk_claims
)
+ self._query_appservices_for_keys = (
+ hs.config.experimental.msc3984_appservice_key_query
+ )
@trace
@cancellable
@@ -497,6 +500,19 @@ class E2eKeysHandler:
local_query, include_displaynames
)
+ # Check if the application services have any additional results.
+ if self._query_appservices_for_keys:
+ # Query the appservices for any keys.
+ appservice_results = await self._appservice_handler.query_keys(query)
+
+ # Merge results, overriding with what the appservice returned.
+ for user_id, devices in appservice_results.get("device_keys", {}).items():
+ # Copy the appservice device info over the homeserver device info, but
+ # don't completely overwrite it.
+ results.setdefault(user_id, {}).update(devices)
+
+ # TODO Handle cross-signing keys.
+
# Build the result structure
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 80156ef343..65461a0787 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1949,27 +1949,25 @@ class FederationHandler:
)
for event in events:
for attempt in itertools.count():
+ # We try a new destination on every iteration.
try:
- await self._federation_event_handler.update_state_for_partial_state_event(
- destination, event
- )
+ while True:
+ try:
+ await self._federation_event_handler.update_state_for_partial_state_event(
+ destination, event
+ )
+ break
+ except FederationPullAttemptBackoffError as e:
+ # We are in the backoff period for one of the event's
+ # prev_events. Wait it out and try again after.
+ logger.warning(
+ "%s; waiting for %d ms...", e, e.retry_after_ms
+ )
+ await self.clock.sleep(e.retry_after_ms / 1000)
+
+ # Success, no need to try the rest of the destinations.
break
- except FederationPullAttemptBackoffError as exc:
- # Log a warning about why we failed to process the event (the error message
- # for `FederationPullAttemptBackoffError` is pretty good)
- logger.warning("_sync_partial_state_room: %s", exc)
- # We do not record a failed pull attempt when we backoff fetching a missing
- # `prev_event` because not being able to fetch the `prev_events` just means
- # we won't be able to de-outlier the pulled event. But we can still use an
- # `outlier` in the state/auth chain for another event. So we shouldn't stop
- # a downstream event from trying to pull it.
- #
- # This avoids a cascade of backoff for all events in the DAG downstream from
- # one event backoff upstream.
except FederationError as e:
- # TODO: We should `record_event_failed_pull_attempt` here,
- # see https://github.com/matrix-org/synapse/issues/13700
-
if attempt == len(destinations) - 1:
# We have tried every remote server for this event. Give up.
# TODO(faster_joins) giving up isn't the right thing to do
@@ -1986,6 +1984,8 @@ class FederationHandler:
destination,
e,
)
+ # TODO: We should `record_event_failed_pull_attempt` here,
+ # see https://github.com/matrix-org/synapse/issues/13700
raise
# Try the next remote server.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 648843cdbe..982c8d3b2f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -140,6 +140,7 @@ class FederationEventHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -1038,8 +1039,8 @@ class FederationEventHandler:
Raises:
FederationPullAttemptBackoffError if we are are deliberately not attempting
- to pull the given event over federation because we've already done so
- recently and are backing off.
+ to pull one of the given event's `prev_event`s over federation because
+ we've already done so recently and are backing off.
FederationError if we fail to get the state from the remote server after any
missing `prev_event`s.
"""
@@ -1053,13 +1054,22 @@ class FederationEventHandler:
# If we've already recently attempted to pull this missing event, don't
# try it again so soon. Since we have to fetch all of the prev_events, we can
# bail early here if we find any to ignore.
- prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
- room_id, missing_prevs
+ prevs_with_pull_backoff = (
+ await self._store.get_event_ids_to_not_pull_from_backoff(
+ room_id, missing_prevs
+ )
)
- if len(prevs_to_ignore) > 0:
+ if len(prevs_with_pull_backoff) > 0:
raise FederationPullAttemptBackoffError(
- event_ids=prevs_to_ignore,
- message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
+ event_ids=prevs_with_pull_backoff.keys(),
+ message=(
+ f"While computing context for event={event_id}, not attempting to "
+ f"pull missing prev_events={list(prevs_with_pull_backoff.keys())} "
+ "because we already tried to pull recently (backing off)."
+ ),
+ retry_after_ms=(
+ max(prevs_with_pull_backoff.values()) - self._clock.time_msec()
+ ),
)
if not missing_prevs:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 0fc829acf7..e7e0b5e049 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -1239,6 +1239,7 @@ class OidcProvider:
grandfather_existing_users,
extra_attributes,
auth_provider_session_id=sid,
+ registration_enabled=self._config.enable_registration,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 4a27c0f051..c28325323c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -383,6 +383,7 @@ class SsoHandler:
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
+ registration_enabled: bool = True,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -435,6 +436,10 @@ class SsoHandler:
auth_provider_session_id: An optional session ID from the IdP.
+ registration_enabled: An optional boolean to enable/disable automatic
+ registrations of new users. If false and the user does not exist then the
+ flow is aborted. Defaults to true.
+
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@@ -462,8 +467,16 @@ class SsoHandler:
auth_provider_id, remote_user_id, user_id
)
- # Otherwise, generate a new user.
- if not user_id:
+ if not user_id and not registration_enabled:
+ logger.info(
+ "User does not exist and registration are disabled for IdP '%s' and remote_user_id '%s'",
+ auth_provider_id,
+ remote_user_id,
+ )
+ raise MappingException(
+ "User does not exist and registrations are disabled"
+ )
+ elif not user_id: # Otherwise, generate a new user.
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
next_step_url = self._get_url_for_next_new_user_step(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d777d59ccf..5ee55981d9 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
return self
+
+
+def is_unknown_endpoint(
+ e: HttpResponseException, synapse_error: Optional[SynapseError] = None
+) -> bool:
+ """
+ Returns true if the response was due to an endpoint being unimplemented.
+
+ Args:
+ e: The error response received from the remote server.
+ synapse_error: The above error converted to a SynapseError. This is
+ automatically generated if not provided.
+
+ """
+ if synapse_error is None:
+ synapse_error = e.to_synapse_error()
+ # MSC3743 specifies that servers should return a 404 or 405 with an errcode
+ # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
+ # to an unknown method, respectively.
+ #
+ # Older versions of servers don't properly handle this. This needs to be
+ # rather specific as some endpoints truly do return 404 errors.
+ return (
+ # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
+ (e.code == 404 or e.code == 405)
+ and (
+ # Older Dendrites returned a text body or empty body.
+ # Older Conduit returned an empty body.
+ not e.response
+ or e.response == b"404 page not found"
+ # The proper response JSON with M_UNRECOGNIZED errcode.
+ or synapse_error.errcode == Codes.UNRECOGNIZED
+ )
+ ) or (
+ # Older Synapses returned a 400 error.
+ e.code == 400
+ and synapse_error.errcode == Codes.UNRECOGNIZED
+ )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 93b255ced5..491a09b71d 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -149,7 +149,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.password_reset
- % {"server_name": self.hs.config.server.server_name},
+ % {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 56a5c21910..a7248d7b2e 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -14,36 +14,7 @@
"""This module contains the implementation of both the client and server
protocols.
-The basic structure of the protocol is line based, where the initial word of
-each line specifies the command. The rest of the line is parsed based on the
-command. For example, the `RDATA` command is defined as::
-
- RDATA <stream_name> <token> <row_json>
-
-(Note that `<row_json>` may contains spaces, but cannot contain newlines.)
-
-Blank lines are ignored.
-
-# Example
-
-An example iteraction is shown below. Each line is prefixed with '>' or '<' to
-indicate which side is sending, these are *not* included on the wire::
-
- * connection established *
- > SERVER localhost:8823
- > PING 1490197665618
- < NAME synapse.app.appservice
- < PING 1490197665618
- < REPLICATE
- > POSITION events 1
- > POSITION backfill 1
- > POSITION caches 1
- > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
- > RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823",
- "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]]
- < PING 1490197675618
- > ERROR server stopping
- * connection closed by server *
+An explanation of this protocol is available in docs/tcp_replication.md
"""
import fcntl
import logging
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index a4bdb48c0c..c6088a0f99 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -152,8 +152,8 @@ class Stream:
Returns:
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
- position in stream, and `limited` is whether there are more updates
- to fetch.
+ position in stream (ie the highest token returned in the updates),
+ and `limited` is whether there are more updates to fetch.
"""
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 994c50116f..25f70fee84 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -617,14 +617,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
- upper_pos = min(current_id, last_id + limit)
+ upto_token = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
- txn.execute(sql, (last_id, upper_pos))
+ txn.execute(sql, (last_id, upto_token))
updates = [(row[0], row[1:]) for row in txn]
sql = (
@@ -633,19 +633,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
- txn.execute(sql, (last_id, upper_pos))
+ txn.execute(sql, (last_id, upto_token))
updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
updates.sort()
- limited = False
- upto_token = current_id
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
+ return updates, upto_token, upto_token < current_id
return await self.db_pool.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ff3edeb716..a19ba88bf8 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1544,7 +1544,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self,
room_id: str,
event_ids: Collection[str],
- ) -> List[str]:
+ ) -> Dict[str, int]:
"""
Filter down the events to ones that we've failed to pull before recently. Uses
exponential backoff.
@@ -1554,7 +1554,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_ids: A list of events to filter down
Returns:
- List of event_ids that should not be attempted to be pulled
+ A dictionary of event_ids that should not be attempted to be pulled and the
+ next timestamp at which we may try pulling them again.
"""
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
@@ -1570,22 +1571,28 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
current_time = self._clock.time_msec()
- return [
- event_failed_pull_attempt["event_id"]
- for event_failed_pull_attempt in event_failed_pull_attempts
+
+ event_ids_with_backoff = {}
+ for event_failed_pull_attempt in event_failed_pull_attempts:
+ event_id = event_failed_pull_attempt["event_id"]
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
- if current_time
- < event_failed_pull_attempt["last_attempt_ts"]
- + (
- 2
- ** min(
- event_failed_pull_attempt["num_attempts"],
- BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
+ backoff_end_time = (
+ event_failed_pull_attempt["last_attempt_ts"]
+ + (
+ 2
+ ** min(
+ event_failed_pull_attempt["num_attempts"],
+ BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
+ )
)
+ * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
)
- * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
- ]
+
+ if current_time < backoff_end_time: # `backoff_end_time` is exclusive
+ event_ids_with_backoff[event_id] = backoff_end_time
+
+ return event_ids_with_backoff
async def get_missing_events(
self,
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eeccf5db24..6afc51320a 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -100,7 +100,6 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
-from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -289,180 +288,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
unique=True,
)
- self.db_pool.updates.register_background_update_handler(
- "event_push_backfill_thread_id",
- self._background_backfill_thread_id,
- )
-
- # Indexes which will be used to quickly make the thread_id column non-null.
- self.db_pool.updates.register_background_index_update(
- "event_push_actions_thread_id_null",
- index_name="event_push_actions_thread_id_null",
- table="event_push_actions",
- columns=["thread_id"],
- where_clause="thread_id IS NULL",
- )
- self.db_pool.updates.register_background_index_update(
- "event_push_summary_thread_id_null",
- index_name="event_push_summary_thread_id_null",
- table="event_push_summary",
- columns=["thread_id"],
- where_clause="thread_id IS NULL",
- )
-
- # Check ASAP (and then later, every 1s) to see if we have finished
- # background updates the event_push_actions and event_push_summary tables.
- self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
- self._event_push_backfill_thread_id_done = False
-
- @wrap_as_background_process("check_event_push_backfill_thread_id")
- async def _check_event_push_backfill_thread_id(self) -> None:
- """
- Has thread_id finished backfilling?
-
- If not, we need to just-in-time update it so the queries work.
- """
- done = await self.db_pool.updates.has_completed_background_update(
- "event_push_backfill_thread_id"
- )
-
- if done:
- self._event_push_backfill_thread_id_done = True
- else:
- # Reschedule to run.
- self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
-
- async def _background_backfill_thread_id(
- self, progress: JsonDict, batch_size: int
- ) -> int:
- """
- Fill in the thread_id field for event_push_actions and event_push_summary.
-
- This is preparatory so that it can be made non-nullable in the future.
-
- Because all current (null) data is done in an unthreaded manner this
- simply assumes it is on the "main" timeline. Since event_push_actions
- are periodically cleared it is not possible to correctly re-calculate
- the thread_id.
- """
- event_push_actions_done = progress.get("event_push_actions_done", False)
-
- def add_thread_id_txn(
- txn: LoggingTransaction, start_stream_ordering: int
- ) -> int:
- sql = """
- SELECT stream_ordering
- FROM event_push_actions
- WHERE
- thread_id IS NULL
- AND stream_ordering > ?
- ORDER BY stream_ordering
- LIMIT ?
- """
- txn.execute(sql, (start_stream_ordering, batch_size))
-
- # No more rows to process.
- rows = txn.fetchall()
- if not rows:
- progress["event_push_actions_done"] = True
- self.db_pool.updates._background_update_progress_txn(
- txn, "event_push_backfill_thread_id", progress
- )
- return 0
-
- # Update the thread ID for any of those rows.
- max_stream_ordering = rows[-1][0]
-
- sql = """
- UPDATE event_push_actions
- SET thread_id = 'main'
- WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
- """
- txn.execute(
- sql,
- (
- start_stream_ordering,
- max_stream_ordering,
- ),
- )
-
- # Update progress.
- processed_rows = txn.rowcount
- progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
- self.db_pool.updates._background_update_progress_txn(
- txn, "event_push_backfill_thread_id", progress
- )
-
- return processed_rows
-
- def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
- min_user_id = progress.get("max_summary_user_id", "")
- min_room_id = progress.get("max_summary_room_id", "")
-
- # Slightly overcomplicated query for getting the Nth user ID / room
- # ID tuple, or the last if there are less than N remaining.
- sql = """
- SELECT user_id, room_id FROM (
- SELECT user_id, room_id FROM event_push_summary
- WHERE (user_id, room_id) > (?, ?)
- AND thread_id IS NULL
- ORDER BY user_id, room_id
- LIMIT ?
- ) AS e
- ORDER BY user_id DESC, room_id DESC
- LIMIT 1
- """
-
- txn.execute(sql, (min_user_id, min_room_id, batch_size))
- row = txn.fetchone()
- if not row:
- return 0
-
- max_user_id, max_room_id = row
-
- sql = """
- UPDATE event_push_summary
- SET thread_id = 'main'
- WHERE
- (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
- AND thread_id IS NULL
- """
- txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
- processed_rows = txn.rowcount
-
- progress["max_summary_user_id"] = max_user_id
- progress["max_summary_room_id"] = max_room_id
- self.db_pool.updates._background_update_progress_txn(
- txn, "event_push_backfill_thread_id", progress
- )
-
- return processed_rows
-
- # First update the event_push_actions table, then the event_push_summary table.
- #
- # Note that the event_push_actions_staging table is ignored since it is
- # assumed that items in that table will only exist for a short period of
- # time.
- if not event_push_actions_done:
- result = await self.db_pool.runInteraction(
- "event_push_backfill_thread_id",
- add_thread_id_txn,
- progress.get("max_event_push_actions_stream_ordering", 0),
- )
- else:
- result = await self.db_pool.runInteraction(
- "event_push_backfill_thread_id",
- add_thread_id_summary_txn,
- )
-
- # Only done after the event_push_summary table is done.
- if not result:
- await self.db_pool.updates._end_background_update(
- "event_push_backfill_thread_id"
- )
-
- return result
-
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.
@@ -711,25 +536,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
- # First ensure that the existing rows have an updated thread_id field.
- if not self._event_push_backfill_thread_id_done:
- txn.execute(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- (MAIN_TIMELINE, room_id, user_id),
- )
- txn.execute(
- """
- UPDATE event_push_actions
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- (MAIN_TIMELINE, room_id, user_id),
- )
-
# First we pull the counts from the summary table.
#
# We check that `last_receipt_stream_ordering` matches the stream ordering of the
@@ -1545,25 +1351,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering, *thread_args),
)
- # First ensure that the existing rows have an updated thread_id field.
- if not self._event_push_backfill_thread_id_done:
- txn.execute(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- (MAIN_TIMELINE, room_id, user_id),
- )
- txn.execute(
- """
- UPDATE event_push_actions
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- (MAIN_TIMELINE, room_id, user_id),
- )
-
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
unread_counts = self._get_notif_unread_count_for_user_room(
@@ -1698,19 +1485,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
"""
- # Ensure that any new actions have an updated thread_id.
- if not self._event_push_backfill_thread_id_done:
- txn.execute(
- """
- UPDATE event_push_actions
- SET thread_id = ?
- WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
- """,
- (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
- )
-
- # XXX Do we need to update summaries here too?
-
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id, thread_id,
@@ -1773,20 +1547,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
- # Ensure that any updated threads have the proper thread_id.
- if not self._event_push_backfill_thread_id_done:
- txn.execute_batch(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- [
- (MAIN_TIMELINE, room_id, user_id)
- for user_id, room_id, _ in summaries
- ],
- )
-
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 28751e89a5..ca8c59297c 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -34,6 +34,13 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
":memory:",
)
+ # A connection to a database that has already been prepared, to use as a
+ # base for an in-memory connection. This is used during unit tests to
+ # speed up setting up the DB.
+ self._prepped_conn: Optional[sqlite3.Connection] = database_config.get(
+ "_TEST_PREPPED_CONN"
+ )
+
if platform.python_implementation() == "PyPy":
# pypy's sqlite3 module doesn't handle bytearrays, convert them
# back to bytes.
@@ -84,7 +91,15 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
# In memory databases need to be rebuilt each time. Ideally we'd
# reuse the same connection as we do when starting up, but that
# would involve using adbapi before we have started the reactor.
- prepare_database(db_conn, self, config=None)
+ #
+ # If we have a `prepped_conn` we can use that to initialise the DB,
+ # otherwise we need to call `prepare_database`.
+ if self._prepped_conn is not None:
+ # Initialise the new DB from the pre-prepared DB.
+ assert isinstance(db_conn.conn, sqlite3.Connection)
+ self._prepped_conn.backup(db_conn.conn)
+ else:
+ prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index d3103a6c7a..72bbb3a7c2 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -95,9 +95,9 @@ Changes in SCHEMA_VERSION = 74:
SCHEMA_COMPAT_VERSION = (
- # The threads_id column must exist for event_push_actions, event_push_summary,
- # receipts_linearized, and receipts_graph.
- 73
+ # The threads_id column must written to with non-null values event_push_actions,
+ # event_push_actions_staging, and event_push_summary.
+ 74
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/74/02thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/74/02thread_notifications_backfill.sql
new file mode 100644
index 0000000000..ce6f9ff937
--- /dev/null
+++ b/synapse/storage/schema/main/delta/74/02thread_notifications_backfill.sql
@@ -0,0 +1,28 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Force the background updates from 06thread_notifications.sql to run in the
+-- foreground as code will now require those to be "done".
+
+DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id';
+
+-- Overwrite any null thread_id values.
+UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL;
+UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL;
+UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL;
+
+-- Drop the background updates to calculate the indexes used to find null thread_ids.
+DELETE FROM background_updates WHERE update_name = 'event_push_actions_thread_id_null';
+DELETE FROM background_updates WHERE update_name = 'event_push_summary_thread_id_null';
diff --git a/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.postgres
new file mode 100644
index 0000000000..5f68667425
--- /dev/null
+++ b/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop the indexes used to find null thread_ids.
+DROP INDEX IF EXISTS event_push_actions_thread_id_null;
+DROP INDEX IF EXISTS event_push_summary_thread_id_null;
+
+-- The thread_id columns can now be made non-nullable.
+ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL;
+ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL;
+ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL;
diff --git a/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.sqlite
new file mode 100644
index 0000000000..f46b233560
--- /dev/null
+++ b/synapse/storage/schema/main/delta/74/03thread_notifications_not_null.sql.sqlite
@@ -0,0 +1,99 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- The thread_id columns can now be made non-nullable.
+--
+-- SQLite doesn't support modifying columns to an existing table, so it must
+-- be recreated.
+
+-- Create the new tables.
+CREATE TABLE event_push_actions_staging_new (
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ notif SMALLINT NOT NULL,
+ highlight SMALLINT NOT NULL,
+ unread SMALLINT,
+ thread_id TEXT NOT NULL,
+ inserted_ts BIGINT
+);
+
+CREATE TABLE event_push_actions_new (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ profile_tag VARCHAR(32),
+ actions TEXT NOT NULL,
+ topological_ordering BIGINT,
+ stream_ordering BIGINT,
+ notif SMALLINT,
+ highlight SMALLINT,
+ unread SMALLINT,
+ thread_id TEXT NOT NULL,
+ CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
+);
+
+CREATE TABLE event_push_summary_new (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notif_count BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ unread_count BIGINT,
+ last_receipt_stream_ordering BIGINT,
+ thread_id TEXT NOT NULL
+);
+
+-- Copy the data.
+INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts)
+ SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts
+ FROM event_push_actions_staging;
+
+INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id)
+ SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id
+ FROM event_push_actions;
+
+INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id)
+ SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id
+ FROM event_push_summary;
+
+-- Drop the old tables.
+DROP TABLE event_push_actions_staging;
+DROP TABLE event_push_actions;
+DROP TABLE event_push_summary;
+
+-- Rename the tables.
+ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging;
+ALTER TABLE event_push_actions_new RENAME TO event_push_actions;
+ALTER TABLE event_push_summary_new RENAME TO event_push_summary;
+
+-- Recreate the indexes.
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
+
+CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering);
+CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering );
+CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);
+CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id );
+CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering);
+
+CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary (user_id, room_id, thread_id) ;
+
+-- Recreate some indexes in the background, by re-running the background updates
+-- from 72/02event_push_actions_index.sql and 72/06thread_notifications.sql.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7403, 'event_push_summary_unique_index2', '{}')
+ ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7403, 'event_push_actions_stream_highlight_index', '{}')
+ ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 4ff04fc66b..013b9ee550 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -960,7 +960,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
- namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
@@ -1015,3 +1015,122 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
},
)
+
+ @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
+ def test_query_local_devices_appservice(self) -> None:
+ """Test that querying of appservices for keys overrides responses from the database."""
+ local_user = "@boris:" + self.hs.hostname
+ device_1 = "abc"
+ device_2 = "def"
+ device_3 = "ghi"
+
+ # There are 3 devices:
+ #
+ # 1. One which is uploaded to the homeserver.
+ # 2. One which is uploaded to the homeserver, but a newer copy is returned
+ # by the appservice.
+ # 3. One which is only returned by the appservice.
+ device_key_1: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_1,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2a: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ device_key_2b: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ # The device ID is the same (above), but the keys are different.
+ "keys": {
+ "ed25519:xyz": "base64+ed25519+key",
+ "curve25519:xyz": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
+ }
+ device_key_3: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_3,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:jkl": "base64+ed25519+key",
+ "curve25519:jkl": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
+ }
+
+ # Upload keys for devices 1 & 2a.
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_1, {"device_keys": device_key_1}
+ )
+ )
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_2, {"device_keys": device_key_2a}
+ )
+ )
+
+ # Inject an appservice interested in this user.
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+ self.hs.get_datastores().main.services_cache = [appservice]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [appservice]
+ )
+
+ # Setup a response.
+ self.appservice_api.query_keys.return_value = make_awaitable(
+ {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
+ }
+ }
+ )
+
+ # Request all devices.
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
+ self.assertIn(local_user, res)
+ for res_key in res[local_user].values():
+ res_key.pop("unsigned", None)
+ self.assertDictEqual(
+ res,
+ {
+ local_user: {
+ device_1: device_key_1,
+ device_2: device_key_2b,
+ device_3: device_key_3,
+ }
+ },
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 951caaa6b3..0a8bae54fb 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -922,7 +922,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- @override_config({"oidc_config": DEFAULT_CONFIG})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": True}})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
userinfo: dict = {
@@ -975,6 +975,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": False}})
+ def test_map_userinfo_to_user_does_not_register_new_user(self) -> None:
+ """Ensures new users are not registered if the enabled registration flag is disabled."""
+ userinfo: dict = {
+ "sub": "test_user",
+ "username": "test_user",
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "User does not exist and registrations are disabled",
+ )
+
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 46a8e2013e..0f1a8a145f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -54,6 +54,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 01df1be047..b9075e3f20 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -37,11 +37,6 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
# also one global update
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
new file mode 100644
index 0000000000..fb9eac668f
--- /dev/null
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
+from synapse.types import JsonDict
+
+from tests.replication._base import BaseStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class ToDeviceStreamTestCase(BaseStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ ]
+
+ def test_to_device_stream(self) -> None:
+ store = self.hs.get_datastores().main
+
+ user1 = self.register_user("user1", "pass")
+ self.login("user1", "pass", "device")
+ user2 = self.register_user("user2", "pass")
+ self.login("user2", "pass", "device")
+
+ # connect to pull the updates related to users creation/login
+ self.reconnect()
+ self.replicate()
+ self.test_handler.received_rdata_rows.clear()
+ # disconnect so we can accumulate the updates without pulling them
+ self.disconnect()
+
+ msg: JsonDict = {}
+ msg["sender"] = "@sender:example.org"
+ msg["type"] = "m.new_device"
+
+ # add messages to the device inbox for user1 up until the
+ # limit defined for a stream update batch
+ for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
+ msg["content"] = {"device": {}}
+ messages = {user1: {"device": msg}}
+
+ self.get_success(
+ store.add_messages_from_remote_to_device_inbox(
+ "example.org",
+ f"{i}",
+ messages,
+ )
+ )
+
+ # add one more message, for user2 this time
+ # this message would be dropped before fixing #15335
+ msg["content"] = {"device": {}}
+ messages = {user2: {"device": msg}}
+
+ self.get_success(
+ store.add_messages_from_remote_to_device_inbox(
+ "example.org",
+ f"{_STREAM_UPDATE_TARGET_ROW_COUNT}",
+ messages,
+ )
+ )
+
+ # replication is disconnected so we shouldn't get any updates yet
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should receive the fact that we have to_device updates
+ # for user1 and user2
+ received_rows = self.test_handler.received_rdata_rows
+ self.assertEqual(len(received_rows), 2)
+ self.assertEqual(received_rows[0][2].entity, user1)
+ self.assertEqual(received_rows[1][2].entity, user2)
diff --git a/tests/server.py b/tests/server.py
index bb059630fa..b52ff1c463 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ import json
import logging
import os
import os.path
+import sqlite3
import time
import uuid
import warnings
@@ -79,7 +80,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
@@ -104,6 +107,10 @@ P = ParamSpec("P")
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+# A pre-prepared SQLite DB that is used as a template when creating new SQLite
+# DB each test run. This dramatically speeds up test set up when using SQLite.
+PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
+
class TimedOutException(Exception):
"""
@@ -899,6 +906,22 @@ def setup_test_homeserver(
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
+ # Check if we have set up a DB that we can use as a template.
+ global PREPPED_SQLITE_DB_CONN
+ if PREPPED_SQLITE_DB_CONN is None:
+ temp_engine = create_engine(database_config)
+ PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
+ sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
+ )
+
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+ prepare_database(
+ PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
+ )
+
+ database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
+
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3e1984c15c..81e50bdd55 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -1143,19 +1143,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
tok = self.login("alice", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ failure_time = self.clock.time_msec()
self.get_success(
self.store.record_event_failed_pull_attempt(
room_id, "$failed_event_id", "fake cause"
)
)
- event_ids_to_backoff = self.get_success(
+ event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
- self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
+ self.assertEqual(
+ event_ids_with_backoff,
+ # We expect a 2^1 hour backoff after a single failed attempt.
+ {"$failed_event_id": failure_time + 2 * 60 * 60 * 1000},
+ )
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
@@ -1179,14 +1184,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# attempt (2^1 hours).
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
- event_ids_to_backoff = self.get_success(
+ event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
# Since this function only returns events we should backoff from, time has
# elapsed past the backoff range so there is no events to backoff from.
- self.assertEqual(event_ids_to_backoff, [])
+ self.assertEqual(event_ids_with_backoff, {})
@attr.s(auto_attribs=True)
diff --git a/tests/unittest.py b/tests/unittest.py
index f9160faa1d..8a16fd3665 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -146,6 +146,9 @@ class TestCase(unittest.TestCase):
% (current_context(),)
)
+ # Disable GC for duration of test. See below for why.
+ gc.disable()
+
old_level = logging.getLogger().level
if level is not None and old_level != level:
@@ -163,12 +166,19 @@ class TestCase(unittest.TestCase):
return orig()
+ # We want to force a GC to workaround problems with deferreds leaking
+ # logcontexts when they are GCed (see the logcontext docs).
+ #
+ # The easiest way to do this would be to do a full GC after each test
+ # run, but that is very expensive. Instead, we disable GC (above) for
+ # the duration of the test so that we only need to run a gen-0 GC, which
+ # is a lot quicker.
+
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
- # force a GC to workaround problems with deferreds leaking logcontexts when
- # they are GCed (see the logcontext docs)
- gc.collect()
+ gc.collect(0)
+ gc.enable()
set_current_context(SENTINEL_CONTEXT)
return ret
|