summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/account_data.py2
-rw-r--r--synapse/rest/client/register.py23
-rw-r--r--synapse/rest/client/room.py39
-rw-r--r--synapse/rest/client/sync.py3
4 files changed, 49 insertions, 18 deletions
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index d1badbdf3b..58b8adbd32 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -66,7 +66,7 @@ class AccountDataServlet(RestServlet):
             raise AuthError(403, "Cannot get account data for other users.")
 
         event = await self.store.get_global_account_data_by_type_for_user(
-            account_data_type, user_id
+            user_id, account_data_type
         )
 
         if event is None:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 8b56c76aed..e3492f9f93 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet):
             ),
         )
 
+        self.inhibit_user_in_use_error = (
+            hs.config.registration.inhibit_user_in_use_error
+        )
+
     async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         if not self.hs.config.registration.enable_registration:
             raise SynapseError(
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
+        if self.inhibit_user_in_use_error:
+            return 200, {"available": True}
+
         ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
@@ -418,10 +425,14 @@ class RegisterRestServlet(RestServlet):
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
+        self.password_auth_provider = hs.get_password_auth_provider()
         self._registration_enabled = self.hs.config.registration.enable_registration
         self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
+        self._inhibit_user_in_use_error = (
+            hs.config.registration.inhibit_user_in_use_error
+        )
 
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
@@ -564,6 +575,7 @@ class RegisterRestServlet(RestServlet):
                 desired_username,
                 guest_access_token=guest_access_token,
                 assigned_user_id=registered_user_id,
+                inhibit_user_in_use_error=self._inhibit_user_in_use_error,
             )
 
         # Check if the user-interactive authentication flows are complete, if
@@ -627,7 +639,16 @@ class RegisterRestServlet(RestServlet):
             if not password_hash:
                 raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
-            desired_username = params.get("username", None)
+            desired_username = await (
+                self.password_auth_provider.get_username_for_registration(
+                    auth_result,
+                    params,
+                )
+            )
+
+            if desired_username is None:
+                desired_username = params.get("username", None)
+
             guest_access_token = params.get("guest_access_token", None)
 
             if desired_username is not None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90bb9142a0..90355e44b2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
         else:
             event_filter = None
 
-        results = await self.room_context_handler.get_event_context(
+        event_context = await self.room_context_handler.get_event_context(
             requester, room_id, event_id, limit, event_filter
         )
 
-        if not results:
+        if not event_context:
             raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
 
         time_now = self.clock.time_msec()
-        aggregations = results.pop("aggregations", None)
-        results["events_before"] = self._event_serializer.serialize_events(
-            results["events_before"], time_now, bundle_aggregations=aggregations
-        )
-        results["event"] = self._event_serializer.serialize_event(
-            results["event"], time_now, bundle_aggregations=aggregations
-        )
-        results["events_after"] = self._event_serializer.serialize_events(
-            results["events_after"], time_now, bundle_aggregations=aggregations
-        )
-        results["state"] = self._event_serializer.serialize_events(
-            results["state"], time_now
-        )
+        results = {
+            "events_before": self._event_serializer.serialize_events(
+                event_context.events_before,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "event": self._event_serializer.serialize_event(
+                event_context.event,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "events_after": self._event_serializer.serialize_events(
+                event_context.events_after,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "state": self._event_serializer.serialize_events(
+                event_context.state, time_now
+            ),
+            "start": event_context.start,
+            "end": event_context.end,
+        }
 
         return 200, results
 
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d20ae1421e..f9615da525 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
 
         def serialize(
             events: Iterable[EventBase],
-            aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+            aggregations: Optional[Dict[str, BundledAggregations]] = None,
         ) -> List[JsonDict]:
             return self._event_serializer.serialize_events(
                 events,