diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index d1badbdf3b..58b8adbd32 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -66,7 +66,7 @@ class AccountDataServlet(RestServlet):
raise AuthError(403, "Cannot get account data for other users.")
event = await self.store.get_global_account_data_by_type_for_user(
- account_data_type, user_id
+ user_id, account_data_type
)
if event is None:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 8b56c76aed..e3492f9f93 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
+ self.inhibit_user_in_use_error = (
+ hs.config.registration.inhibit_user_in_use_error
+ )
+
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
+ if self.inhibit_user_in_use_error:
+ return 200, {"available": True}
+
ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -418,10 +425,14 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
self._refresh_tokens_enabled = (
hs.config.registration.refreshable_access_token_lifetime is not None
)
+ self._inhibit_user_in_use_error = (
+ hs.config.registration.inhibit_user_in_use_error
+ )
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -564,6 +575,7 @@ class RegisterRestServlet(RestServlet):
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
+ inhibit_user_in_use_error=self._inhibit_user_in_use_error,
)
# Check if the user-interactive authentication flows are complete, if
@@ -627,7 +639,16 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ desired_username = await (
+ self.password_auth_provider.get_username_for_registration(
+ auth_result,
+ params,
+ )
+ )
+
+ if desired_username is None:
+ desired_username = params.get("username", None)
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90bb9142a0..90355e44b2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = await self.room_context_handler.get_event_context(
+ event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
)
- if not results:
+ if not event_context:
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- aggregations = results.pop("aggregations", None)
- results["events_before"] = self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=aggregations
- )
- results["event"] = self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=aggregations
- )
- results["events_after"] = self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=aggregations
- )
- results["state"] = self._event_serializer.serialize_events(
- results["state"], time_now
- )
+ results = {
+ "events_before": self._event_serializer.serialize_events(
+ event_context.events_before,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "event": self._event_serializer.serialize_event(
+ event_context.event,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "events_after": self._event_serializer.serialize_events(
+ event_context.events_after,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "state": self._event_serializer.serialize_events(
+ event_context.state, time_now
+ ),
+ "start": event_context.start,
+ "end": event_context.end,
+ }
return 200, results
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d20ae1421e..f9615da525 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
def serialize(
events: Iterable[EventBase],
- aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+ aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
|