diff --git a/CHANGES.md b/CHANGES.md
index 650dc8487d..5de819ea1e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -5,7 +5,7 @@ Bugfixes
--------
- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386))
-- Fix URLs being accidentally escaped in Jinja2 templates. Broke in v1.20.0. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
+- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
Synapse 1.20.0 (2020-09-22)
diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature
new file mode 100644
index 0000000000..4ee5b6a56e
--- /dev/null
+++ b/changelog.d/8345.feature
@@ -0,0 +1 @@
+Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang.
diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc
new file mode 100644
index 0000000000..a56e36de4b
--- /dev/null
+++ b/changelog.d/8372.misc
@@ -0,0 +1 @@
+Add type annotations to `SimpleHttpClient`.
diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix
new file mode 100644
index 0000000000..155bc3404f
--- /dev/null
+++ b/changelog.d/8374.bugfix
@@ -0,0 +1 @@
+Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers.
diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix
new file mode 100644
index 0000000000..24983a1e95
--- /dev/null
+++ b/changelog.d/8386.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail.
diff --git a/changelog.d/8387.feature b/changelog.d/8387.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8387.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc
new file mode 100644
index 0000000000..aaaef88b66
--- /dev/null
+++ b/changelog.d/8388.misc
@@ -0,0 +1 @@
+Add `EventStreamPosition` type.
diff --git a/changelog.d/8396.feature b/changelog.d/8396.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8396.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix
new file mode 100644
index 0000000000..e432aeebf1
--- /dev/null
+++ b/changelog.d/8398.bugfix
@@ -0,0 +1 @@
+Fix "Re-starting finished log context" warning when receiving an event we already had over federation.
diff --git a/changelog.d/8405.feature b/changelog.d/8405.feature
new file mode 100644
index 0000000000..f3c4a74bc7
--- /dev/null
+++ b/changelog.d/8405.feature
@@ -0,0 +1 @@
+Consolidate the SSO error template across all configuration.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 53c67c6b9a..76f588fa9f 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1864,6 +1864,11 @@ oidc_config:
#
#skip_verification: true
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 56edee5db8..267faa2743 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -630,6 +630,7 @@ class Porter(object):
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
+ await self._setup_events_stream_seqs()
self.progress.done()
except Exception as e:
@@ -806,6 +807,29 @@ class Porter(object):
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
+ def _setup_events_stream_seqs(self):
+ def r(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ curr_id = txn.fetchone()[0]
+ if curr_id:
+ next_id = curr_id + 1
+ txn.execute(
+ "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
+ )
+
+ txn.execute("SELECT -MIN(stream_ordering) FROM events")
+ curr_id = txn.fetchone()[0]
+ if curr_id:
+ next_id = curr_id + 1
+ txn.execute(
+ "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
+ (next_id,),
+ )
+
+ return self.postgres_store.db_pool.runInteraction(
+ "_setup_events_stream_seqs", r
+ )
+
##############################################
# The following is simply UI stuff
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 1514c0f691..c526c28b93 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri, {})
+ info = await self.get_json(uri)
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index e0939bce84..70fc8a2f62 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+ self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
@@ -158,6 +159,11 @@ class OIDCConfig(Config):
#
#skip_verification: true
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index bd4b47b341..99aa8b3bf1 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -169,12 +169,6 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
- # We enable autoescape here as the message may potentially come from a
- # remote resource
- self.saml2_error_html_template = self.read_templates(
- ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True
- )[0]
-
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 42e4087a92..c04ad77cf9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -42,7 +42,6 @@ from synapse.api.errors import (
)
from synapse.logging.context import (
PreserveLoggingContext,
- current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -233,8 +232,6 @@ class Keyring:
"""
try:
- ctx = current_context()
-
# map from server name to a set of outstanding request ids
server_to_request_ids = {}
@@ -265,12 +262,8 @@ class Keyring:
# if there are no more requests for this server, we can drop the lock.
if not server_requests:
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name)
-
- # ... but not immediately, as that can cause stack explosions if
- # we get a long queue of lookups.
- self.clock.call_later(0, drop_server_lock, server_name)
+ logger.debug("Releasing key lookup lock on %s", server_name)
+ drop_server_lock(server_name)
return res
@@ -335,20 +328,32 @@ class Keyring:
)
# look for any requests which weren't satisfied
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- verify_request.key_ready.errback(
- SynapseError(
- 401,
- "No key for %s with ids in %s (min_validity %i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- ),
- Codes.UNAUTHORIZED,
- )
+ while remaining_requests:
+ verify_request = remaining_requests.pop()
+ rq_str = (
+ "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
)
+ )
+
+ # If we run the errback immediately, it may cancel our
+ # loggingcontext while we are still in it, so instead we
+ # schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we
+ # has a massive queue of lookups waiting for this server).
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.errback,
+ SynapseError(
+ 401,
+ "Failed to find any key to satisfy %s" % (rq_str,),
+ Codes.UNAUTHORIZED,
+ ),
+ )
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
@@ -410,10 +415,23 @@ class Keyring:
# key was not valid at this point
continue
- with PreserveLoggingContext():
- verify_request.key_ready.callback(
- (server_name, key_id, fetch_key_result.verify_key)
- )
+ # we have a valid key for this request. If we run the callback
+ # immediately, it may cancel our loggingcontext while we are still in
+ # it, so instead we schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we had
+ # a massive queue of lookups waiting for this server).
+ logger.debug(
+ "Found key %s:%s for %s",
+ server_name,
+ key_id,
+ verify_request.request_name,
+ )
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.callback,
+ (server_name, key_id, fetch_key_result.verify_key),
+ )
completed.append(verify_request)
break
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7b0c0021db..81134fcdf6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
JsonDict,
MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
StateMap,
UserID,
get_domain_from_id,
@@ -2966,7 +2968,7 @@ class FederationHandler(BaseHandler):
)
return result["max_stream_id"]
else:
- max_stream_id = await self.storage.persistence.persist_events(
+ max_stream_token = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
@@ -2977,12 +2979,12 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- await self._notify_persisted_event(event, max_stream_id)
+ await self._notify_persisted_event(event, max_stream_token)
- return max_stream_id
+ return max_stream_token.stream
async def _notify_persisted_event(
- self, event: EventBase, max_stream_id: int
+ self, event: EventBase, max_stream_token: RoomStreamToken
) -> None:
"""Checks to see if notifier/pushers should be notified about the
event or not.
@@ -3008,9 +3010,11 @@ class FederationHandler(BaseHandler):
elif event.internal_metadata.is_outlier():
return
- event_stream_id = event.internal_metadata.stream_ordering
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
async def _clean_room_for_join(self, room_id: str) -> None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index dfd414d886..8949343801 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1141,7 +1141,7 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+ event_pos, max_stream_token = await self.storage.persistence.persist_event(
event, context=context
)
@@ -1152,7 +1152,7 @@ class EventCreationHandler:
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -1164,7 +1164,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
- return event_stream_id
+ return event_pos.stream
async def _bump_active_time(self, user: UserID) -> None:
try:
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4230dbaf99..0e06e4408d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -114,6 +114,7 @@ class OidcHandler:
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
+ self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
@@ -849,7 +850,8 @@ class OidcHandler:
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
- If a user already exists with the mxid we've mapped, raise an exception.
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
Args:
userinfo: an object representing the user
@@ -905,21 +907,31 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname)
- if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
- # This mxid is taken
- raise MappingException(
- "mxid '{}' is already taken".format(user_id.to_string())
+ user_id = UserID(localpart, self._hostname).to_string()
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ if self._allow_existing_users:
+ if len(users) == 1:
+ registered_user_id = next(iter(users))
+ elif user_id in users:
+ registered_user_id = user_id
+ else:
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, list(users.keys())
+ )
+ )
+ else:
+ # This mxid is taken
+ raise MappingException("mxid '{}' is already taken".format(user_id))
+ else:
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
-
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9b3a4f638b..e948efef2e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -967,7 +967,7 @@ class SyncHandler:
raise NotImplementedError()
else:
joined_room_ids = await self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id
+ user_id, now_token.room_key
)
sync_result_builder = SyncResultBuilder(
sync_config,
@@ -1916,7 +1916,7 @@ class SyncHandler:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at(
- self, user_id: str, stream_ordering: int
+ self, user_id: str, room_key: RoomStreamToken
) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
@@ -1942,15 +1942,15 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
- for room_id, membership_stream_ordering in joined_rooms:
- if membership_stream_ordering <= stream_ordering:
+ for room_id, event_pos in joined_rooms:
+ if not event_pos.persisted_after(room_key):
joined_room_ids.add(room_id)
continue
logger.info("User joined room after current token: %s", room_id)
extrems = await self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering
+ room_id, event_pos.stream
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
import logging
import urllib
from io import BytesIO
+from typing import (
+ Any,
+ BinaryIO,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import treq
from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can either be Lists or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
@@ -285,13 +311,26 @@ class SimpleHttpClient:
ip_blacklist=self._ip_blacklist,
)
- async def request(self, method, uri, data=None, headers=None):
+ async def request(
+ self,
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
"""
Args:
- method (str): HTTP method to use.
- uri (str): URI to query.
- data (bytes): Data to send in the request body, if applicable.
- headers (t.w.http_headers.Headers): Request headers.
+ method: HTTP method to use.
+ uri: URI to query.
+ data: Data to send in the request body, if applicable.
+ headers: Request headers.
+
+ Returns:
+ Response object, once the headers have been read.
+
+ Raises:
+ RequestTimedOutError if the request times out before the headers are read
+
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
headers=headers,
**self._extra_treq_args
)
+ # we use our own timeout mechanism rather than treq's as a workaround
+ # for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred,
60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
set_tag("error_reason", e.args[0])
raise
- async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(
+ self,
+ uri: str,
+ args: Mapping[str, Union[str, List[str]]] = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
"""
Args:
- uri (str):
- args (dict[str, str|List[str]]): query params
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: uri to query
+ args: parameters to be url-encoded in the body
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(
+ self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+ ) -> Any:
"""
Args:
- uri (str):
- post_json (object):
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: URI to query.
+ post_json: request body, to be encoded as json
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_json(self, uri, args={}, headers=None):
- """ Gets some json from the given URI.
+ async def get_json(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Gets some json from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query string
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
- async def put_json(self, uri, json_body, args={}, headers=None):
- """ Puts some json to the given URI.
+ async def put_json(
+ self,
+ uri: str,
+ json_body: Any,
+ args: QueryParams = {},
+ headers: RawHeaders = None,
+ ) -> Any:
+ """Puts some json to the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- json_body (dict): The JSON to put in the HTTP body,
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ json_body: The JSON to put in the HTTP body,
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_raw(self, uri, args={}, headers=None):
- """ Gets raw text from the given URI.
+ async def get_raw(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ ) -> bytes:
+ """Gets raw text from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- async def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(
+ self,
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
- url (str): The URL to GET
- output_stream (file): File to write the response body to.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ url: The URL to GET
+ output_stream: File to write the response body to.
+ headers: A map from header name to a list of values for that header
Returns:
- A (int,dict,string,int) tuple of the file length, dict of the response
+ A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
+
+ Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
+ SynapseError: if the response is not a 2xx, the remote file is too large, or
+ another exception happens during the download.
"""
actual_headers = {b"User-Agent": [self.user_agent]}
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a8fd3ef886..441b3d15e2 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StreamToken,
+ UserID,
+)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -187,7 +193,7 @@ class Notifier:
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
- ) # type: List[Tuple[int, EventBase, Collection[UserID]]]
+ ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -246,8 +252,8 @@ class Notifier:
def on_new_room_event(
self,
event: EventBase,
- room_stream_id: int,
- max_room_stream_id: int,
+ event_pos: PersistedEventPosition,
+ max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
@@ -261,16 +267,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((room_stream_id, event, extra_users))
- self._notify_pending_new_room_events(max_room_stream_id)
+ self.pending_new_room_events.append((event_pos, event, extra_users))
+ self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
- def _notify_pending_new_room_events(self, max_room_stream_id: int):
+ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
- max_room_stream_id: The highest stream_id below which all
+ max_room_stream_token: The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
@@ -279,11 +285,9 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
- for room_stream_id, event, extra_users in pending:
- if room_stream_id > max_room_stream_id:
- self.pending_new_room_events.append(
- (room_stream_id, event, extra_users)
- )
+ for event_pos, event, extra_users in pending:
+ if event_pos.persisted_after(max_room_stream_token):
+ self.pending_new_room_events.append((event_pos, event, extra_users))
else:
if (
event.type == EventTypes.Member
@@ -296,39 +300,38 @@ class Notifier:
if users or rooms:
self.on_new_event(
- "room_key",
- RoomStreamToken(None, max_room_stream_id),
- users=users,
- rooms=rooms,
+ "room_key", max_room_stream_token, users=users, rooms=rooms,
)
- self._on_updated_room_token(max_room_stream_id)
+ self._on_updated_room_token(max_room_stream_token)
- def _on_updated_room_token(self, max_room_stream_id: int):
+ def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
"""Poke services that might care that the room position has been
updated.
"""
# poke any interested application service.
run_as_background_process(
- "_notify_app_services", self._notify_app_services, max_room_stream_id
+ "_notify_app_services", self._notify_app_services, max_room_stream_token
)
run_as_background_process(
- "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
+ "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
)
if self.federation_sender:
- self.federation_sender.notify_new_events(max_room_stream_id)
+ self.federation_sender.notify_new_events(max_room_stream_token.stream)
- async def _notify_app_services(self, max_room_stream_id: int):
+ async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
try:
- await self.appservice_handler.notify_interested_services(max_room_stream_id)
+ await self.appservice_handler.notify_interested_services(
+ max_room_stream_token.stream
+ )
except Exception:
logger.exception("Error notifying application services of event")
- async def _notify_pusher_pool(self, max_room_stream_id: int):
+ async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
- await self._pusher_pool.on_new_notifications(max_room_stream_id)
+ await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
except Exception:
logger.exception("Error pusher pool of event")
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d25fa49e1a..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e82b9e386f..55af3d41ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -151,8 +151,14 @@ class ReplicationDataHandler:
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ max_token = RoomStreamToken(
+ None, self.store.get_room_max_stream_ordering()
+ )
+ event_pos = PersistedEventPosition(instance_name, token)
+ self.notifier.on_new_room_event(
+ event, event_pos, max_token, extra_users
+ )
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index af8459719a..944bc9c9ca 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -12,7 +12,7 @@
<p>
There was an error during authentication:
</p>
- <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+ <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ccb3384db9..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,14 +160,20 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+ # caches to invalidate. (This reduces the amount of writes to the DB
+ # that happen).
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
- instance_name="master",
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
)
else:
self._cache_id_gen = None
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index de9e8d1dc6..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="events",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
+ writers=hs.config.worker.writers.events,
)
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="backfill",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
+ writers=hs.config.worker.writers.events,
)
else:
# We shouldn't be running in worker mode with SQLite, but its useful
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a06451b7f0..2ed696cc14 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -472,7 +472,7 @@ class RegistrationWorkerStore(SQLBaseStore):
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
- ) -> str:
+ ) -> Optional[str]:
"""Look up a user by their external auth id
Args:
@@ -480,7 +480,7 @@ class RegistrationWorkerStore(SQLBaseStore):
external_id: id on that system
Returns:
- str|None: the mxid of the user, or None if they are not known
+ the mxid of the user, or None if they are not known
"""
return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4fa8767b01..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
else:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ for room_id, instance, stream_id in txn
+ )
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+-- If the server has never backfilled a room then doing `-MIN(...)` will give
+-- a negative result, hence why we do `GREATEST(...)`
SELECT setval('events_backfill_stream_seq', (
- SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+ SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
));
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+ stream_name TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..603cd7d825 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -190,6 +190,7 @@ class EventsPersistenceStorage:
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -198,7 +199,7 @@ class EventsPersistenceStorage:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> int:
+ ) -> RoomStreamToken:
"""
Write events to the database
Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
- return self.main_store.get_current_events_token()
+ return RoomStreamToken(None, self.main_store.get_current_events_token())
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[int, int]:
+ ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
@@ -247,7 +248,10 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
max_persisted_id = self.main_store.get_current_events_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
+ event_stream_id = event.internal_metadata.stream_ordering
+
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return pos, RoomStreamToken(None, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+ "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b0353ac2dc..4269eaf918 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Union
import attr
from typing_extensions import Deque
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
+ stream_name: A name for the stream.
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ writers: A list of known writers to use to populate current positions
+ on startup. Can be empty if nothing uses `get_current_token` or
+ `get_positions` (e.g. caches stream).
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
self,
db_conn,
db: DatabasePool,
+ stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ writers: List[str],
positive: bool = True,
):
self._db = db
+ self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
+ self._writers = writers
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = self._load_current_ids(
- db_conn, table, instance_column, id_column
- )
+ self._current_positions = {} # type: Dict[str, int]
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
@@ -251,30 +258,84 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ # This goes and fills out the above state from the database.
+ self._load_current_ids(db_conn, table, instance_column, id_column)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
- ) -> Dict[str, int]:
- # If positive stream aggregate via MAX. For negative stream use MIN
- # *and* negate the result to get a positive number.
- sql = """
- SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
- GROUP BY %(instance)s
- """ % {
- "instance": instance_column,
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
-
+ ):
cur = db_conn.cursor()
- cur.execute(sql)
- # `cur` is an iterable over returned rows, which are 2-tuples.
- current_positions = dict(cur)
+ # Load the current positions of all writers for the stream.
+ if self._writers:
+ sql = """
+ SELECT instance_name, stream_id FROM stream_positions
+ WHERE stream_name = ?
+ """
+ sql = self._db.engine.convert_param_style(sql)
- cur.close()
+ cur.execute(sql, (self._stream_name,))
+
+ self._current_positions = {
+ instance: stream_id * self._return_factor
+ for instance, stream_id in cur
+ if instance in self._writers
+ }
+
+ # We set the `_persisted_upto_position` to be the minimum of all current
+ # positions. If empty we use the max stream ID from the DB table.
+ min_stream_id = min(self._current_positions.values(), default=None)
+
+ if min_stream_id is None:
+ # We add a GREATEST here to ensure that the result is always
+ # positive. (This can be a problem for e.g. backfill streams where
+ # the server has never backfilled).
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+ self._persisted_upto_position = stream_id
+ else:
+ # If we have a min_stream_id then we pull out everything greater
+ # than it from the DB so that we can prefill
+ # `_known_persisted_positions` and get a more accurate
+ # `_persisted_upto_position`.
+ #
+ # We also check if any of the later rows are from this instance, in
+ # which case we use that for this instance's current position. This
+ # is to handle the case where we didn't finish persisting to the
+ # stream positions table before restart (or the stream position
+ # table otherwise got out of date).
+
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (min_stream_id,))
+
+ self._persisted_upto_position = min_stream_id
+
+ with self._lock:
+ for (instance, stream_id,) in cur:
+ stream_id = self._return_factor * stream_id
+ self._add_persisted_position(stream_id)
- return current_positions
+ if instance == self._instance_name:
+ self._current_positions[instance] = stream_id
+
+ cur.close()
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
@@ -316,6 +377,21 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persited row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
@@ -447,6 +523,28 @@ class MultiWriterIdGenerator:
# do.
break
+ def _update_stream_positions_table_txn(self, txn):
+ """Update the `stream_positions` table with newly persisted position.
+ """
+
+ if not self._writers:
+ return
+
+ # We upsert the value, ensuring on conflict that we always increase the
+ # value (or decrease if stream goes backwards).
+ sql = """
+ INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (stream_name, instance_name)
+ DO UPDATE SET
+ stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+ """ % {
+ "agg": "GREATEST" if self._positive else "LEAST",
+ }
+
+ pos = (self.get_current_token_for_writer(self._instance_name),)
+ txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
@attr.s(slots=True)
class _AsyncCtxManagerWrapper:
@@ -503,4 +601,16 @@ class _MultiWriterCtxManager:
if exc_type is not None:
return False
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self.id_gen._writers:
+ await self.id_gen._db.runInteraction(
+ "MultiWriterIdGenerator._update_table",
+ self.id_gen._update_stream_positions_table_txn,
+ )
+
return False
diff --git a/synapse/types.py b/synapse/types.py
index 93fee6c92a..07b421077c 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -509,6 +509,21 @@ class StreamToken:
StreamToken.START = StreamToken.from_string("s0_0")
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
+ instance_name = attr.ib(type=str)
+ stream = attr.ib(type=int)
+
+ def persisted_after(self, token: RoomStreamToken) -> bool:
+ return token.stream < self.stream
+
+
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..5cf408f21f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
@@ -68,54 +68,40 @@ class MockPerspectiveServer:
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
+
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- async def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- await persp_deferred
- return persp_resp
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
- self.http_client.post_json.side_effect = get_perspectives
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 89ec5fcb31..5910772aa8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
self.assertEqual(mxid, "@test_user_2:test")
+
+ # Test if the mxid is already taken
+ store = self.hs.get_datastore()
+ user3 = UserID.from_string("@test_user_3:test")
+ self.get_success(
+ store.register_user(user_id=user3.to_string(), password_hash=None)
+ )
+ userinfo = {"sub": "test3", "username": "test_user_3"}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+
+ @override_config({"oidc_config": {"allow_existing_users": True}})
+ def test_map_userinfo_to_existing_user(self):
+ """Existing users can log in with OpenID Connect when allow_existing_users is True."""
+ store = self.hs.get_datastore()
+ user4 = UserID.from_string("@test_user_4:test")
+ self.get_success(
+ store.register_user(user_id=user4.to_string(), password_hash=None)
+ )
+ userinfo = {
+ "sub": "test4",
+ "username": "test_user_4",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_4:test")
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index bc578411d6..c0ee1cfbd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ {(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# the membership change is only any use to us if the room is in the
# joined_rooms list.
if membership_changes:
- self.assertEqual(
- joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
)
+ self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
event_id = 0
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index fb8f5bc255..d4ff55fbff 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
return self.get_success(self.db_pool.runWithConnection(_create))
@@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
@@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+ def test_restart_during_out_of_order_persistence(self):
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+ """Test that changing the writer config correctly works.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # New config removes one of the configs. Note that if the writer is
+ # removed from config we assume that it has been shut down and has
+ # finished persisting, hence why the persisted upto position is 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
positive=False,
)
@@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
@@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
- id_gen_1 = self._create_id_generator("first")
- id_gen_2 = self._create_id_generator("second")
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
async def _get_next_async():
async with id_gen_1.get_next() as stream_id:
|