Diffstat (limited to 'synapse')
-rw-r--r--  synapse/federation/federation_server.py             |   8
-rw-r--r--  synapse/handlers/admin.py                           |   2
-rw-r--r--  synapse/handlers/device.py                          |   3
-rw-r--r--  synapse/handlers/e2e_keys.py                        |   6
-rw-r--r--  synapse/handlers/room_list.py                       |  43
-rw-r--r--  synapse/handlers/room_summary.py                    |  26
-rw-r--r--  synapse/logging/opentracing.py                      |   7
-rw-r--r--  synapse/rest/admin/media.py                         |   6
-rw-r--r--  synapse/rest/admin/registration_tokens.py           |  13
-rw-r--r--  synapse/rest/admin/rooms.py                         |  11
-rw-r--r--  synapse/rest/admin/users.py                         |  10
-rw-r--r--  synapse/storage/background_updates.py               |  14
-rw-r--r--  synapse/storage/controllers/persist_events.py       | 252
-rw-r--r--  synapse/storage/database.py                         |  15
-rw-r--r--  synapse/storage/databases/main/__init__.py          |  52
-rw-r--r--  synapse/storage/databases/main/devices.py           |  55
-rw-r--r--  synapse/storage/databases/main/events.py            | 400
-rw-r--r--  synapse/storage/databases/main/media_repository.py  |  48
-rw-r--r--  synapse/storage/databases/main/registration.py      |  42
-rw-r--r--  synapse/storage/databases/main/room.py              |  82
-rw-r--r--  synapse/util/__init__.py                            |   2
21 files changed, 640 insertions, 457 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3b27925517..2bb2c64ebe 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
@@ -999,6 +999,12 @@ class FederationServer(FederationBase):
     async def on_claim_client_keys(
         self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
     ) -> Dict[str, Any]:
+        if any(
+            not self.hs.is_mine(UserID.from_string(user_id))
+            for user_id, _, _, _ in query
+        ):
+            raise SynapseError(400, "User is not hosted on this homeserver")
+
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = await self._e2e_keys_handler.claim_local_one_time_keys(
             query, always_include_fallback_keys=always_include_fallback_keys
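The new guard above rejects one-time-key claims that name users this homeserver does not host. A minimal, self-contained sketch of the same check, with hypothetical stand-ins (`RequestError`, `domain_from_user_id`, `check_all_local`) in place of Synapse's `SynapseError`, `UserID`, and `hs.is_mine`:

```python
from typing import List, Tuple


class RequestError(Exception):
    """Stand-in for SynapseError: an HTTP status code plus a message."""

    def __init__(self, code: int, msg: str) -> None:
        super().__init__(msg)
        self.code = code


def domain_from_user_id(user_id: str) -> str:
    # Matrix user IDs look like "@localpart:server.name".
    return user_id.split(":", 1)[1]


def check_all_local(server_name: str, query: List[Tuple[str, str, str, int]]) -> None:
    # Reject the whole request up front if any queried user lives on another
    # homeserver, mirroring the early `raise SynapseError(400, ...)` above.
    if any(domain_from_user_id(user_id) != server_name for user_id, _, _, _ in query):
        raise RequestError(400, "User is not hosted on this homeserver")


check_all_local("example.org", [("@alice:example.org", "DEVICEID", "alg", 1)])  # passes
```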
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2c2baeac67..d06f8e3296 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -283,7 +283,7 @@ class AdminHandler:
                 start, limit, user_id
             )
             for media in media_ids:
-                writer.write_media_id(media["media_id"], media)
+                writer.write_media_id(media.media_id, attr.asdict(media))

             logger.info(
                 "[%s] Written %d media_ids of %s",
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 3ce96ef3cb..93472d0117 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -328,6 +328,9 @@ class DeviceWorkerHandler:
         return result

     async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
+        if not self.hs.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "User is not hosted on this homeserver")
+
         stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
             user_id
         )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d340d4aebe..d06524495f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -542,6 +542,12 @@ class E2eKeysHandler:
         device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
         )
+        if any(
+            not self.is_mine(UserID.from_string(user_id))
+            for user_id in device_keys_query
+        ):
+            raise SynapseError(400, "User is not hosted on this homeserver")
+
         res = await self.query_local_devices(
             device_keys_query,
             include_displaynames=(
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..2947e154be 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.storage.databases.main.room import LargestRoomStats
 from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.response_cache import ResponseCache
@@ -170,26 +171,24 @@ class RoomListHandler:
             ignore_non_federatable=from_federation,
         )

-        def build_room_entry(room: JsonDict) -> JsonDict:
+        def build_room_entry(room: LargestRoomStats) -> JsonDict:
             entry = {
-                "room_id": room["room_id"],
-                "name": room["name"],
-                "topic": room["topic"],
-                "canonical_alias": room["canonical_alias"],
-                "num_joined_members": room["joined_members"],
-                "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"]
+                "room_id": room.room_id,
+                "name": room.name,
+                "topic": room.topic,
+                "canonical_alias": room.canonical_alias,
+                "num_joined_members": room.joined_members,
+                "avatar_url": room.avatar,
+                "world_readable": room.history_visibility
                 == HistoryVisibility.WORLD_READABLE,
-                "guest_can_join": room["guest_access"] == "can_join",
-                "join_rule": room["join_rules"],
-                "room_type": room["room_type"],
+                "guest_can_join": room.guest_access == "can_join",
+                "join_rule": room.join_rules,
+                "room_type": room.room_type,
             }

             # Filter out Nones – rather omit the field altogether
             return {k: v for k, v in entry.items() if v is not None}

-        results = [build_room_entry(r) for r in results]
-
         response: JsonDict = {}
         num_results = len(results)
         if limit is not None:
@@ -212,33 +211,33 @@ class RoomListHandler:
                 # If there was a token given then we assume that there
                 # must be previous results.
                 response["prev_batch"] = RoomListNextBatch(
-                    last_joined_members=initial_entry["num_joined_members"],
-                    last_room_id=initial_entry["room_id"],
+                    last_joined_members=initial_entry.joined_members,
+                    last_room_id=initial_entry.room_id,
                     direction_is_forward=False,
                 ).to_token()

                 if more_to_come:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
                     ).to_token()
             else:
                 if has_batch_token:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
                     ).to_token()

                 if more_to_come:
                     response["prev_batch"] = RoomListNextBatch(
-                        last_joined_members=initial_entry["num_joined_members"],
-                        last_room_id=initial_entry["room_id"],
+                        last_joined_members=initial_entry.joined_members,
+                        last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
                     ).to_token()

-        response["chunk"] = results
+        response["chunk"] = [build_room_entry(r) for r in results]

         response["total_room_count_estimate"] = await self.store.count_public_rooms(
             network_tuple,
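`build_room_entry` now reads attribute access off a typed row, then drops `None` values so optional fields are omitted from the response instead of being sent as nulls. A reduced sketch of that pattern (`RoomRow` is a trimmed stand-in for `LargestRoomStats`):

```python
from typing import Any, Dict, Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomRow:
    room_id: str
    name: Optional[str]
    topic: Optional[str]
    joined_members: int


def build_room_entry(room: RoomRow) -> Dict[str, Any]:
    entry = {
        "room_id": room.room_id,
        "name": room.name,
        "topic": room.topic,
        "num_joined_members": room.joined_members,
    }
    # Omit unset optional fields entirely rather than sending nulls.
    return {k: v for k, v in entry.items() if v is not None}


assert build_room_entry(RoomRow("!r:example.org", None, None, 3)) == {
    "room_id": "!r:example.org",
    "num_joined_members": 3,
}
```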
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index dd559b4c45..1dfb12e065 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -703,24 +703,24 @@ class RoomSummaryHandler:
         # there should always be an entry
         assert stats is not None, "unable to retrieve stats for %s" % (room_id,)

-        entry = {
-            "room_id": stats["room_id"],
-            "name": stats["name"],
-            "topic": stats["topic"],
-            "canonical_alias": stats["canonical_alias"],
-            "num_joined_members": stats["joined_members"],
-            "avatar_url": stats["avatar"],
-            "join_rule": stats["join_rules"],
+        entry: JsonDict = {
+            "room_id": stats.room_id,
+            "name": stats.name,
+            "topic": stats.topic,
+            "canonical_alias": stats.canonical_alias,
+            "num_joined_members": stats.joined_members,
+            "avatar_url": stats.avatar,
+            "join_rule": stats.join_rules,
             "world_readable": (
-                stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
+                stats.history_visibility == HistoryVisibility.WORLD_READABLE
             ),
-            "guest_can_join": stats["guest_access"] == "can_join",
-            "room_type": stats["room_type"],
+            "guest_can_join": stats.guest_access == "can_join",
+            "room_type": stats.room_type,
         }

         if self._msc3266_enabled:
-            entry["im.nheko.summary.version"] = stats["version"]
-            entry["im.nheko.summary.encryption"] = stats["encryption"]
+            entry["im.nheko.summary.version"] = stats.version
+            entry["im.nheko.summary.encryption"] = stats.encryption

         # Federation requests need to provide additional information so the
         # requested server is able to filter the response appropriately.
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 4454fe29a5..e297fa9c8b 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -1019,11 +1019,14 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     if not opentracing:
         return func

+    # getfullargspec is somewhat expensive, so ensure it is only called a single
+    # time (the function signature shouldn't change anyway).
+    argspec = inspect.getfullargspec(func)
+
     @contextlib.contextmanager
     def _wrapping_logic(
-        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+        _func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
     ) -> Generator[None, None, None]:
-        argspec = inspect.getfullargspec(func)
         # We use `[1:]` to skip the `self` object reference and `start=1` to
         # make the index line up with `argspec.args`.
         #
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index b7637dff0b..8cf5268854 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,6 +17,8 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple

+import attr
+
 from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet):
             start, limit, user_id, order_by, direction
         )

-        ret = {"media": media, "total": total}
+        ret = {"media": [attr.asdict(m) for m in media], "total": total}
         if (start + limit) < total:
             ret["next_token"] = start + len(media)

@@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet):
         )

         deleted_media, total = await self.media_repository.delete_local_media_ids(
-            [row["media_id"] for row in media]
+            [m.media_id for m in media]
        )

         return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index ffce92d45e..f3e06d3da3 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
         valid = parse_boolean(request, "valid")
         token_list = await self.store.get_registration_tokens(valid)
-        return HTTPStatus.OK, {"registration_tokens": token_list}
+        return HTTPStatus.OK, {
+            "registration_tokens": [
+                {
+                    "token": t[0],
+                    "uses_allowed": t[1],
+                    "pending": t[2],
+                    "completed": t[3],
+                    "expiry_time": t[4],
+                }
+                for t in token_list
+            ]
+        }


 class NewRegistrationTokenRestServlet(RestServlet):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0659f22a89..23a034522c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,6 +16,8 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 from urllib import parse as urlparse

+import attr
+
 from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.api.filtering import Filter
@@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet):
             raise NotFoundError("Room not found")

         members = await self.store.get_users_in_room(room_id)
-        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
-        ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
+        result = attr.asdict(ret)
+        result["joined_local_devices"] = await self.store.count_devices_by_users(
+            members
+        )
+        result["forgotten"] = await self.store.is_locally_forgotten_room(room_id)

-        return HTTPStatus.OK, ret
+        return HTTPStatus.OK, result

     async def on_DELETE(
         self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7fe16130e7..73878dd99d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,6 +18,8 @@ import secrets
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

+import attr
+
 from synapse.api.constants import Direction, UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet):
         )

         # If support for MSC3866 is not enabled, don't show the approval flag.
+        filter = None
         if not self._msc3866_enabled:
-            for user in users:
-                del user["approved"]

-        ret = {"users": users, "total": total}
+            def _filter(a: attr.Attribute, _v: object) -> bool:
+                return a.name != "approved"
+
+            filter = _filter
+
+        ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
         if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 12829d3d7d..7426dbcad6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -28,6 +28,7 @@ from typing import (
     Sequence,
     Tuple,
     Type,
+    cast,
 )

 import attr
@@ -488,14 +489,14 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """

-        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
+        def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
                 ORDER BY ordering, update_name
                 """
             )
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, Optional[str]]], txn.fetchall())

         if not self._current_background_update:
             all_pending_updates = await self.db_pool.runInteraction(
@@ -507,14 +508,13 @@ class BackgroundUpdater:
                 return True

             # find the first update which isn't dependent on another one in the queue.
-            pending = {update["update_name"] for update in all_pending_updates}
-            for upd in all_pending_updates:
-                depends_on = upd["depends_on"]
+            pending = {update_name for update_name, depends_on in all_pending_updates}
+            for update_name, depends_on in all_pending_updates:
                 if not depends_on or depends_on not in pending:
                     break
                 logger.info(
                     "Not starting on bg update %s until %s is done",
-                    upd["update_name"],
+                    update_name,
                     depends_on,
                 )
             else:
@@ -524,7 +524,7 @@ class BackgroundUpdater:
                     "another: dependency cycle?"
                 )

-            self._current_background_update = upd["update_name"]
+            self._current_background_update = update_name

         # We have a background update to run, otherwise we would have returned
         # early.
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
         return await res.get_state(self._state_controller, StateFilter.all())

     async def _persist_event_batch(
-        self, _room_id: str, task: _PersistEventsTask
+        self, room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue

         Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.

+        Assumes that we are only persisting events for one room at a time.
+
         Returns:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
                 # We can't easily parallelize these since different chunks
                 # might contain the same event. :(

-                # NB: Assumes that we are only persisting events for one room
-                # at a time.
-
-                # map room_id->set[event_ids] giving the new forward
-                # extremities in each room
-                new_forward_extremities: Dict[str, Set[str]] = {}
-
-                # map room_id->(to_delete, to_insert) where to_delete is a list
-                # of type/state keys to remove from current state, and to_insert
-                # is a map (type,key)->event_id giving the state delta in each
-                # room
-                state_delta_for_room: Dict[str, DeltaState] = {}
+                new_forward_extremities = None
+                state_delta_for_room = None

                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
-                        # Work out the new "current state" for each room.
+                        # Work out the new "current state" for the room.
                         # We do this by working out what the new extremities are and then
                         # calculating the state from that.
-                        events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
-                        for event, context in chunk:
-                            events_by_room.setdefault(event.room_id, []).append(
-                                (event, context)
-                            )
-
-                        for room_id, ev_ctx_rm in events_by_room.items():
-                            latest_event_ids = (
-                                await self.main_store.get_latest_event_ids_in_room(room_id)
-                            )
-                            new_latest_event_ids = await self._calculate_new_extremities(
-                                room_id, ev_ctx_rm, latest_event_ids
-                            )
-
-                            if new_latest_event_ids == latest_event_ids:
-                                # No change in extremities, so no change in state
-                                continue
-
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremities[room_id] = new_latest_event_ids
-
-                            len_1 = (
-                                len(latest_event_ids) == 1
-                                and len(new_latest_event_ids) == 1
-                            )
-                            if len_1:
-                                all_single_prev_not_state = all(
-                                    len(event.prev_event_ids()) == 1
-                                    and not event.is_state()
-                                    for event, ctx in ev_ctx_rm
-                                )
-                                # Don't bother calculating state if they're just
-                                # a long chain of single ancestor non-state events.
-                                if all_single_prev_not_state:
-                                    continue
-
-                            state_delta_counter.inc()
-                            if len(new_latest_event_ids) == 1:
-                                state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                            logger.debug("Calculating state delta for room %s", room_id)
-                            with Measure(
-                                self._clock, "persist_events.get_new_state_after_events"
-                            ):
-                                res = await self._get_new_state_after_events(
-                                    room_id,
-                                    ev_ctx_rm,
-                                    latest_event_ids,
-                                    new_latest_event_ids,
-                                )
-                                current_state, delta_ids, new_latest_event_ids = res
-
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremities[room_id] = new_latest_event_ids
-
-                            # If either are not None then there has been a change,
-                            # and we need to work out the delta (or use that
-                            # given)
-                            delta = None
-                            if delta_ids is not None:
-                                # If there is a delta we know that we've
-                                # only added or replaced state, never
-                                # removed keys entirely.
-                                delta = DeltaState([], delta_ids)
-                            elif current_state is not None:
-                                with Measure(
-                                    self._clock, "persist_events.calculate_state_delta"
-                                ):
-                                    delta = await self._calculate_state_delta(
-                                        room_id, current_state
-                                    )
-
-                            if delta:
-                                # If we have a change of state then lets check
-                                # whether we're actually still a member of the room,
-                                # or if our last user left. If we're no longer in
-                                # the room then we delete the current state and
-                                # extremities.
-                                is_still_joined = await self._is_server_still_joined(
-                                    room_id,
-                                    ev_ctx_rm,
-                                    delta,
-                                )
-                                if not is_still_joined:
-                                    logger.info("Server no longer in room %s", room_id)
-                                    delta.no_longer_in_room = True
-
-                                state_delta_for_room[room_id] = delta
+                        (
+                            new_forward_extremities,
+                            state_delta_for_room,
+                        ) = await self._calculate_new_forward_extremities_and_state_delta(
+                            room_id, chunk
+                        )

                 await self.persist_events_store._persist_events_and_state_updates(
+                    room_id,
                     chunk,
                     state_delta_for_room=state_delta_for_room,
                     new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:

         return replaced_events

+    async def _calculate_new_forward_extremities_and_state_delta(
+        self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+    ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+        """Calculates the new forward extremities and state delta for a room
+        given events to persist.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            A tuple of:
+                A set of str giving the new forward extremities of the room
+
+                The state delta for the room.
+        """
+
+        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+        new_latest_event_ids = await self._calculate_new_extremities(
+            room_id, ev_ctx_rm, latest_event_ids
+        )
+
+        if new_latest_event_ids == latest_event_ids:
+            # No change in extremities, so no change in state
+            return (None, None)
+
+        # there should always be at least one forward extremity.
+        # (except during the initial persistence of the send_join
+        # results, in which case there will be no existing
+        # extremities, so we'll `continue` above and skip this bit.)
+        assert new_latest_event_ids, "No forward extremities left!"
+
+        new_forward_extremities = new_latest_event_ids
+
+        len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+        if len_1:
+            all_single_prev_not_state = all(
+                len(event.prev_event_ids()) == 1 and not event.is_state()
+                for event, ctx in ev_ctx_rm
+            )
+            # Don't bother calculating state if they're just
+            # a long chain of single ancestor non-state events.
+            if all_single_prev_not_state:
+                return (new_forward_extremities, None)
+
+        state_delta_counter.inc()
+        if len(new_latest_event_ids) == 1:
+            state_delta_single_event_counter.inc()
+
+        # This is a fairly handwavey check to see if we could
+        # have guessed what the delta would have been when
+        # processing one of these events.
+        # What we're interested in is if the latest extremities
+        # were the same when we created the event as they are
+        # now. When this server creates a new event (as opposed
+        # to receiving it over federation) it will use the
+        # forward extremities as the prev_events, so we can
+        # guess this by looking at the prev_events and checking
+        # if they match the current forward extremities.
+        for ev, _ in ev_ctx_rm:
+            prev_event_ids = set(ev.prev_event_ids())
+            if latest_event_ids == prev_event_ids:
+                state_delta_reuse_delta_counter.inc()
+                break
+
+        logger.debug("Calculating state delta for room %s", room_id)
+        with Measure(self._clock, "persist_events.get_new_state_after_events"):
+            res = await self._get_new_state_after_events(
+                room_id,
+                ev_ctx_rm,
+                latest_event_ids,
+                new_latest_event_ids,
+            )
+            current_state, delta_ids, new_latest_event_ids = res
+
+        # there should always be at least one forward extremity.
+        # (except during the initial persistence of the send_join
+        # results, in which case there will be no existing
+        # extremities, so we'll `continue` above and skip this bit.)
+        assert new_latest_event_ids, "No forward extremities left!"
+
+        new_forward_extremities = new_latest_event_ids
+
+        # If either are not None then there has been a change,
+        # and we need to work out the delta (or use that
+        # given)
+        delta = None
+        if delta_ids is not None:
+            # If there is a delta we know that we've
+            # only added or replaced state, never
+            # removed keys entirely.
+            delta = DeltaState([], delta_ids)
+        elif current_state is not None:
+            with Measure(self._clock, "persist_events.calculate_state_delta"):
+                delta = await self._calculate_state_delta(room_id, current_state)
+
+        if delta:
+            # If we have a change of state then lets check
+            # whether we're actually still a member of the room,
+            # or if our last user left. If we're no longer in
+            # the room then we delete the current state and
+            # extremities.
+            is_still_joined = await self._is_server_still_joined(
+                room_id,
+                ev_ctx_rm,
+                delta,
+            )
+            if not is_still_joined:
+                logger.info("Server no longer in room %s", room_id)
+                delta.no_longer_in_room = True
+
+        return (new_forward_extremities, delta)
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
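The extracted helper updates forward extremities: the events at the tip of the room DAG that nothing else references yet. A toy model of that set arithmetic (this is a deliberately naive sketch; the real `_calculate_new_extremities` also filters out outliers and rejected or soft-failed events):

```python
from typing import Dict, List, Set


def naive_new_extremities(
    old_extremities: Set[str], new_events: Dict[str, List[str]]
) -> Set[str]:
    """An event stops being a forward extremity once another event lists it
    in prev_events. `new_events` maps event_id -> prev_event_ids."""
    candidates = old_extremities | set(new_events)
    referenced = {prev for prevs in new_events.values() for prev in prevs}
    return candidates - referenced


# $B extends $A, so $B replaces $A at the tip of the room DAG.
assert naive_new_extremities({"$A"}, {"$B": ["$A"]}) == {"$B"}
```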
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a4e7048368..6d54bb0eb2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -18,7 +18,6 @@ import logging
 import time
 import types
 from collections import defaultdict
-from sys import intern
 from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
@@ -1042,20 +1041,6 @@ class DatabasePool:
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )

-    @staticmethod
-    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
-        """Converts a SQL cursor into an list of dicts.
-
-        Args:
-            cursor: The DBAPI cursor which has executed a query.
-        Returns:
-            A list of dicts where the key is the column header.
-        """
-        assert cursor.description is not None, "cursor.description was None"
-        col_headers = [intern(str(column[0])) for column in cursor.description]
-        results = [dict(zip(col_headers, row)) for row in cursor]
-        return results
-
     async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
         """Runs a single query for a result set.
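With `cursor_to_dict` gone, callers name their columns explicitly in the SELECT and build typed rows from the tuples the driver already returns. A runnable sketch of the replacement pattern using sqlite3 (the `TokenRow` class is illustrative):

```python
import sqlite3
from typing import Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class TokenRow:
    token: str
    uses_allowed: Optional[int]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tokens (token TEXT, uses_allowed INTEGER)")
conn.execute("INSERT INTO tokens VALUES ('abc', NULL)")

# Instead of zipping cursor.description into dicts (what the deleted
# cursor_to_dict did), name the columns in SELECT and build typed rows.
cur = conn.execute("SELECT token, uses_allowed FROM tokens")
rows = [TokenRow(token=t, uses_allowed=u) for t, u in cur.fetchall()]
assert rows == [TokenRow(token="abc", uses_allowed=None)]
```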
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,8 @@ import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

+import attr
+
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage._base import make_in_list_sql_clause
@@ -28,7 +30,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id

 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)

+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+    """This is very similar to UserInfo, but not quite the same."""
+
+    name: str
+    user_type: Optional[str]
+    is_guest: bool
+    admin: bool
+    deactivated: bool
+    shadow_banned: bool
+    displayname: Optional[str]
+    avatar_url: Optional[str]
+    creation_ts: Optional[int]
+    approved: bool
+    erased: bool
+    last_seen_ts: int
+    locked: bool
+
+
 class DataStore(
     EventsBackgroundUpdatesStore,
     ExperimentalFeaturesStore,
@@ -156,7 +177,7 @@ class DataStore(
         approved: bool = True,
         not_user_types: Optional[List[str]] = None,
         locked: bool = False,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[UserPaginateResponse], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -182,7 +203,7 @@ class DataStore(

         def get_users_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[UserPaginateResponse], int]:
             filters = []
             args: list = []
@@ -282,13 +303,24 @@ class DataStore(
             """
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
-
-            # some of those boolean values are returned as integers when we're on SQLite
-            columns_to_boolify = ["erased"]
-            for user in users:
-                for column in columns_to_boolify:
-                    user[column] = bool(user[column])
+            users = [
+                UserPaginateResponse(
+                    name=row[0],
+                    user_type=row[1],
+                    is_guest=bool(row[2]),
+                    admin=bool(row[3]),
+                    deactivated=bool(row[4]),
+                    shadow_banned=bool(row[5]),
+                    displayname=row[6],
+                    avatar_url=row[7],
+                    creation_ts=row[8],
+                    approved=bool(row[9]),
+                    erased=bool(row[10]),
+                    last_seen_ts=row[11],
+                    locked=bool(row[12]),
+                )
+                for row in txn
+            ]

             return users, count
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index ae0536fbaf..303ef6ea27 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1622,7 +1622,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         #
         # For each duplicate, we delete all the existing rows and put one back.

-        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
         last_row = progress.get(
             "last_row",
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1630,44 +1629,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):

         def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
-                [(x, last_row[x]) for x in KEY_COLS]
+                [
+                    ("stream_id", last_row["stream_id"]),
+                    ("destination", last_row["destination"]),
+                    ("user_id", last_row["user_id"]),
+                    ("device_id", last_row["device_id"]),
+                ]
             )
-            sql = """
+            sql = f"""
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                 FROM device_lists_outbound_pokes
-                WHERE %s
-                GROUP BY %s
+                WHERE {clause}
+                GROUP BY stream_id, destination, user_id, device_id
                 HAVING count(*) > 1
-                ORDER BY %s
+                ORDER BY stream_id, destination, user_id, device_id
                 LIMIT ?
-                """ % (
-                clause,  # WHERE
-                ",".join(KEY_COLS),  # GROUP BY
-                ",".join(KEY_COLS),  # ORDER BY
-            )
+                """
             txn.execute(sql, args + [batch_size])
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()

-            row = None
-            for row in rows:
+            stream_id, destination, user_id, device_id = None, None, None, None
+            for stream_id, destination, user_id, device_id, _ in rows:
                 self.db_pool.simple_delete_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    {x: row[x] for x in KEY_COLS},
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
                 )

-                row["sent"] = False
                 self.db_pool.simple_insert_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    row,
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "sent": False,
+                    },
                 )

-            if row:
+            if rows:
                 self.db_pool.updates._background_update_progress_txn(
                     txn,
                     BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
-                    {"last_row": row},
+                    {
+                        "last_row": {
+                            "stream_id": stream_id,
+                            "destination": destination,
+                            "user_id": user_id,
+                            "device_id": device_id,
+                        }
+                    },
                 )

             return len(rows)
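`make_tuple_comparison_clause` drives keyset pagination: each batch resumes strictly after the last processed composite key. Judging only from the call site above, it takes `(column, value)` pairs and returns a WHERE fragment plus its arguments; a minimal sketch of that shape (not Synapse's actual implementation, which also has to cope with databases lacking row-value comparisons):

```python
from typing import Any, List, Tuple


def make_tuple_comparison_clause(keys: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
    # Row-value comparison: (c1, c2, ...) > (?, ?, ...) orders rows
    # lexicographically by the composite key, so each batch picks up
    # exactly where the previous one stopped.
    columns = ", ".join(name for name, _ in keys)
    placeholders = ", ".join("?" for _ in keys)
    return f"({columns}) > ({placeholders})", [value for _, value in keys]


clause, args = make_tuple_comparison_clause(
    [("stream_id", 5), ("destination", "example.org")]
)
assert clause == "(stream_id, destination) > (?, ?)"
assert args == [5, "example.org"]
```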
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3c1492e3ad..647ba182f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState:
     Attributes:
         to_delete: List of type/state_keys to delete from current state
         to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
+        no_longer_in_room: The server is no longer in the room, so the room
             should e.g. be removed from `current_state_events` table.
     """
@@ -131,22 +131,25 @@ class PersistEventsStore:
     @trace
     async def _persist_events_and_state_updates(
         self,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
-        forward extremities tables.
+        forward extremities tables.
+
+        Assumes that we are only persisting events for one room at a time.

         Args:
+            room_id:
             events_and_contexts:
-            state_delta_for_room: Map from room_id to the delta to apply to
-                room state
-            new_forward_extremities: Map from room_id to set of event IDs
-                that are the new forward extremities of the room.
+            state_delta_for_room: The delta to apply to the room state
+            new_forward_extremities: A set of event IDs that are the new forward
+                extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
                 for backfilled events because backfilled events get a negative
@@ -196,6 +199,7 @@ class PersistEventsStore:
         await self.db_pool.runInteraction(
             "persist_events",
             self._persist_events_txn,
+            room_id=room_id,
             events_and_contexts=events_and_contexts,
             inhibit_local_membership_updates=inhibit_local_membership_updates,
             state_delta_for_room=state_delta_for_room,
@@ -221,9 +225,9 @@ class PersistEventsStore:
                 event_counter.labels(event.type, origin_type, origin_entity).inc()

-            for room_id, latest_event_ids in new_forward_extremities.items():
+            if new_forward_extremities:
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), frozenset(latest_event_ids)
+                    (room_id,), frozenset(new_forward_extremities)
                 )

     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -336,10 +340,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         *,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
@@ -347,8 +352,11 @@ class PersistEventsStore:
         and the rejections table. Things reading from those table will need
         to check whether the event was rejected.

+        Assumes that we are only persisting events for one room at a time.
+
         Args:
             txn
+            room_id: The room the events are from
             events_and_contexts: events to persist
             inhibit_local_membership_updates: Stop the local_current_membership
                 from being updated by these events. This should be set to True
@@ -357,10 +365,9 @@ class PersistEventsStore:
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room: The current-state delta for each room.
-            new_forward_extremities: The new forward extremities for each room.
-                For each room, a list of the event ids which are the forward
-                extremities.
+            state_delta_for_room: The current-state delta for the room.
+            new_forward_extremities: The new forward extremities for the room:
+                a set of the event ids which are the forward extremities.

         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -376,14 +383,13 @@ class PersistEventsStore:
         #
         # Annoyingly SQLite doesn't support row level locking.
         if isinstance(self.database_engine, PostgresEngine):
-            for room_id in {e.room_id for e, _ in events_and_contexts}:
-                txn.execute(
-                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
-                    (room_id,),
-                )
-                row = txn.fetchone()
-                if row is None:
-                    raise Exception(f"Room does not exist {room_id}")
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                (room_id,),
+            )
+            row = txn.fetchone()
+            if row is None:
+                raise Exception(f"Room does not exist {room_id}")

         # stream orderings should have been assigned by now
         assert min_stream_order
@@ -419,7 +425,9 @@ class PersistEventsStore:
             events_and_contexts
         )

-        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+        self._update_room_depths_txn(
+            txn, room_id, events_and_contexts=events_and_contexts
+        )

         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -432,11 +440,13 @@ class PersistEventsStore:

         self._store_event_txn(txn, events_and_contexts=events_and_contexts)

-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
+        if new_forward_extremities:
+            self._update_forward_extremities_txn(
+                txn,
+                room_id,
+                new_forward_extremities=new_forward_extremities,
+                max_stream_order=max_stream_order,
+            )

         self._persist_transaction_ids_txn(txn, events_and_contexts)
@@ -464,7 +474,10 @@ class PersistEventsStore:
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
         # NB: This function invalidates all state related caches
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+        if state_delta_for_room:
+            self._update_current_state_txn(
+                txn, room_id, state_delta_for_room, min_stream_order
+            )

     def _persist_event_auth_chain_txn(
         self,
@@ -1026,74 +1039,75 @@ class PersistEventsStore:
         await self.db_pool.runInteraction(
             "update_current_state",
             self._update_current_state_txn,
-            state_delta_by_room={room_id: state_delta},
+            room_id,
+            delta_state=state_delta,
             stream_id=stream_ordering,
         )

     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
-        state_delta_by_room: Dict[str, DeltaState],
+        room_id: str,
+        delta_state: DeltaState,
         stream_id: int,
     ) -> None:
-        for room_id, delta_state in state_delta_by_room.items():
-            to_delete = delta_state.to_delete
-            to_insert = delta_state.to_insert
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
+        to_delete = delta_state.to_delete
+        to_insert = delta_state.to_insert
+
+        # Figure out the changes of membership to invalidate the
+        # `get_rooms_for_user` cache.
+        # We find out which membership events we may have deleted
+        # and which we have added, then we invalidate the caches for all
+        # those users.
+        members_changed = {
+            state_key
+            for ev_type, state_key in itertools.chain(to_delete, to_insert)
+            if ev_type == EventTypes.Member
+        }

-            if delta_state.no_longer_in_room:
-                # Server is no longer in the room so we delete the room from
-                # current_state_events, being careful we've already updated the
-                # rooms.room_version column (which gets populated in a
-                # background task).
-                self._upsert_room_version_txn(txn, room_id)
+        if delta_state.no_longer_in_room:
+            # Server is no longer in the room so we delete the room from
+            # current_state_events, being careful we've already updated the
+            # rooms.room_version column (which gets populated in a
+            # background task).
+            self._upsert_room_version_txn(txn, room_id)

-                # Before deleting we populate the current_state_delta_stream
-                # so that async background tasks get told what happened.
-                sql = """
+            # Before deleting we populate the current_state_delta_stream
+            # so that async background tasks get told what happened.
+            sql = """
                     INSERT INTO current_state_delta_stream
                         (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, self._instance_name, room_id))
+            txn.execute(sql, (stream_id, self._instance_name, room_id))

-                # We also want to invalidate the membership caches for users
-                # that were in the room.
-                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
-                members_changed.update(users_in_room)
+            # We also want to invalidate the membership caches for users
+            # that were in the room.
+            users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+            members_changed.update(users_in_room)

-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="current_state_events",
-                    keyvalues={"room_id": room_id},
-                )
-            else:
-                # We're still in the room, so we update the current state as normal.
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+        else:
+            # We're still in the room, so we update the current state as normal.

-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                #
-                # The stream_id for the update is chosen to be the minimum of the stream_ids
-                # for the batch of the events that we are persisting; that means we do not
-                # end up in a situation where workers see events before the
-                # current_state_delta updates.
-                #
-                sql = """
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
+            sql = """
                 INSERT INTO current_state_delta_stream
                 (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                 SELECT ?, ?, ?, ?, ?, ?, (
                     SELECT event_id FROM current_state_events
                     WHERE room_id = ? AND type = ? AND state_key = ?
                 )
             """
@@ -1101,39 +1115,39 @@ class PersistEventsStore:
-                txn.execute_batch(
-                    sql,
+            txn.execute_batch(
+                sql,
+                (
                     (
-                        (
-                            stream_id,
-                            self._instance_name,
-                            room_id,
-                            etype,
-                            state_key,
-                            to_insert.get((etype, state_key)),
-                            room_id,
-                            etype,
-                            state_key,
-                        )
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
-                # Now we actually update the current_state_events table
+                        stream_id,
+                        self._instance_name,
+                        room_id,
+                        etype,
+                        state_key,
+                        to_insert.get((etype, state_key)),
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
+            # Now we actually update the current_state_events table

-                txn.execute_batch(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
+            txn.execute_batch(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )

-                # We include the membership in the current state table, hence we do
-                # a lookup when we insert. This assumes that all events have already
-                # been inserted into room_memberships.
-                txn.execute_batch(
-                    """INSERT INTO current_state_events
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.execute_batch(
+                """INSERT INTO current_state_events
                     (room_id, type, state_key, event_id, membership, event_stream_ordering)
                 VALUES (
                     ?, ?, ?, ?,
@@ -1141,34 +1155,34 @@ class PersistEventsStore:
                     (SELECT membership FROM room_memberships WHERE event_id = ?),
                     (SELECT stream_ordering FROM events WHERE event_id = ?)
                 )
                 """,
-                    [
-                        (room_id, key[0], key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                    ],
-                )
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                ],
+            )

-            # We now update `local_current_membership`. We do this regardless
-            # of whether we're still in the room or not to handle the case where
-            # e.g. we just got banned (where we need to record that fact here).
-
-            # Note: Do we really want to delete rows here (that we do not
-            # subsequently reinsert below)? While technically correct it means
-            # we have no record of the fact the user *was* a member of the
-            # room but got, say, state reset out of it.
-            if to_delete or to_insert:
-                txn.execute_batch(
-                    "DELETE FROM local_current_membership"
-                    " WHERE room_id = ? AND user_id = ?",
-                    (
-                        (room_id, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                        if etype == EventTypes.Member and self.is_mine_id(state_key)
-                    ),
-                )
+        # We now update `local_current_membership`. We do this regardless
+        # of whether we're still in the room or not to handle the case where
+        # e.g. we just got banned (where we need to record that fact here).

-            if to_insert:
-                txn.execute_batch(
-                    """INSERT INTO local_current_membership
+        # Note: Do we really want to delete rows here (that we do not
+        # subsequently reinsert below)? While technically correct it means
+        # we have no record of the fact the user *was* a member of the
+        # room but got, say, state reset out of it.
+        if to_delete or to_insert:
+            txn.execute_batch(
+                "DELETE FROM local_current_membership"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                    if etype == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
+
+        if to_insert:
+            txn.execute_batch(
+                """INSERT INTO local_current_membership
                 (room_id, user_id, event_id, membership, event_stream_ordering)
             VALUES (
                 ?, ?, ?,
@@ -1176,29 +1190,27 @@ class PersistEventsStore:
                 (SELECT membership FROM room_memberships WHERE event_id = ?),
                 (SELECT stream_ordering FROM events WHERE event_id = ?)
             )
             """,
-                    [
-                        (room_id, key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
-                    ],
-                )
-
-            txn.call_after(
-                self.store._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
+                [
+                    (room_id, key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                    if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                ],
             )

-            # Invalidate the various caches
-            self.store._invalidate_state_caches_and_stream(
-                txn, room_id, members_changed
-            )
+        txn.call_after(
+            self.store._curr_state_delta_stream_cache.entity_has_changed,
+            room_id,
+            stream_id,
+        )

-            # Check if any of the remote membership changes requires us to
-            # unsubscribe from their device lists.
-            self.store.handle_potentially_left_users_txn(
-                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
-            )
+        # Invalidate the various caches
+        self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+        # Check if any of the remote membership changes requires us to
+        # unsubscribe from their device lists.
+        self.store.handle_potentially_left_users_txn(
+            txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+        )

     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
@@ -1232,23 +1244,19 @@ class PersistEventsStore:
     def _update_forward_extremities_txn(
         self,
         txn: LoggingTransaction,
-        new_forward_extremities: Dict[str, Set[str]],
+        room_id: str,
+        new_forward_extremities: Set[str],
         max_stream_order: int,
     ) -> None:
-        for room_id in new_forward_extremities.keys():
-            self.db_pool.simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+        )

         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             keys=("event_id", "room_id"),
-            values=[
-                (ev_id, room_id)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for ev_id in new_extrem
-            ],
+            values=[(ev_id, room_id) for ev_id in new_forward_extremities],
         )
         # We now insert into stream_ordering_to_exterm a mapping from room_id,
         # new stream_ordering to new forward extremeties in the room.
@@ -1260,8 +1268,7 @@ class PersistEventsStore:
             keys=("room_id", "event_id", "stream_ordering"),
             values=[
                 (room_id, event_id, max_stream_order)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for event_id in new_extrem
+                for event_id in new_forward_extremities
             ],
         )
@@ -1298,36 +1305,45 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
         txn: LoggingTransaction,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ) -> None:
         """Update min_depth for each room

         Args:
             txn: db connection
+            room_id: The room ID
             events_and_contexts: events we are persisting
         """
-        depth_updates: Dict[str, int] = {}
+        stream_ordering: Optional[int] = None
+        depth_update = 0
         for event, context in events_and_contexts:
-            # Then update the `stream_ordering` position to mark the latest
-            # event as the front of the room. This should not be done for
-            # backfilled events because backfilled events have negative
-            # stream_ordering and happened in the past so we know that we don't
-            # need to update the stream_ordering tip/front for the room.
+            # Don't update the stream ordering for backfilled events because
+            # backfilled events have negative stream_ordering and happened in the
+            # past, so we know that we don't need to update the stream_ordering
+            # tip/front for the room.
             assert event.internal_metadata.stream_ordering is not None
             if event.internal_metadata.stream_ordering >= 0:
-                txn.call_after(
-                    self.store._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
+                if stream_ordering is None:
+                    stream_ordering = event.internal_metadata.stream_ordering
+                else:
+                    stream_ordering = max(
+                        stream_ordering, event.internal_metadata.stream_ordering
+                    )

             if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
+                depth_update = max(event.depth, depth_update)
+
+        # Then update the `stream_ordering` position to mark the latest event as
+        # the front of the room.
+        if stream_ordering is not None:
+            txn.call_after(
+                self.store._events_stream_cache.entity_has_changed,
+                room_id,
+                stream_ordering,
+            )

-        for room_id, depth in depth_updates.items():
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
+        self._update_min_depth_for_room_txn(txn, room_id, depth_update)

     def _update_outliers_txn(
         self,
@@ -1350,13 +1366,19 @@ class PersistEventsStore:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
-        txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
-            % (",".join(["?"] * len(events_and_contexts)),),
-            [event.event_id for event, _ in events_and_contexts],
+        rows = cast(
+            List[Tuple[str, bool]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                "events",
+                "event_id",
+                [event.event_id for event, _ in events_and_contexts],
+                keyvalues={},
+                retcols=("event_id", "outlier"),
+            ),
        )

-        have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn))
+        have_persisted = dict(rows)

         logger.debug(
             "_update_outliers_txn: events=%s have_persisted=%s",
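The `_update_outliers_txn` change swaps a hand-rolled IN clause for the `simple_select_many_txn` helper. The pattern being replaced is still worth seeing, since one "?" per value keeps the query parameterised; a runnable sqlite3 sketch (table and data are illustrative):

```python
import sqlite3
from typing import List, Tuple

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, outlier BOOLEAN)")
conn.executemany("INSERT INTO events VALUES (?, ?)", [("$a", True), ("$b", False)])


def select_outliers(event_ids: List[str]) -> List[Tuple[str, int]]:
    # One placeholder per value, values passed separately: the IN clause is
    # built dynamically but stays fully parameterised.
    placeholders = ",".join("?" for _ in event_ids)
    sql = f"SELECT event_id, outlier FROM events WHERE event_id IN ({placeholders})"
    return conn.execute(sql, event_ids).fetchall()


assert dict(select_outliers(["$a", "$b"])) == {"$a": 1, "$b": 0}
```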
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )

+import attr
+
 from synapse.api.constants import Direction
 from synapse.logging.opentracing import trace
 from synapse.media._base import ThumbnailInfo
@@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 )

+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+    safe_from_quarantine: bool
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id: str,
         order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[LocalMedia], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
@@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
+        ) -> Tuple[List[LocalMedia], int]:
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
@@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

             sql = """
                 SELECT
-                    "media_id",
-                    "media_type",
-                    "media_length",
-                    "upload_name",
-                    "created_ts",
-                    "last_access_ts",
-                    "quarantined_by",
-                    "safe_from_quarantine"
+                    media_id,
+                    media_type,
+                    media_length,
+                    upload_name,
+                    created_ts,
+                    last_access_ts,
+                    quarantined_by,
+                    safe_from_quarantine
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

             args += [limit, start]
             txn.execute(sql, args)
-            media = self.db_pool.cursor_to_dict(txn)
+            media = [
+                LocalMedia(
+                    media_id=row[0],
+                    media_type=row[1],
+                    media_length=row[2],
+                    upload_name=row[3],
+                    created_ts=row[4],
+                    last_access_ts=row[5],
+                    quarantined_by=row[6],
+                    safe_from_quarantine=bool(row[7]),
+                )
+                for row in txn
+            ]
             return media, count

         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):

     async def get_registration_tokens(
         self, valid: Optional[bool] = None
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
         """List all registration tokens. Used by the admin API.

         Args:
@@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 Default is None: return all tokens regardless of validity.

         Returns:
-            A list of dicts, each containing details of a token.
+            A list of tuples containing:
+                * The token
+                * The number of users allowed (or None)
+                * Whether it is pending
+                * Whether it has been completed
+                * An expiry time (or None if no expiry)
         """

         def select_registration_tokens_txn(
             txn: LoggingTransaction, now: int, valid: Optional[bool]
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
             if valid is None:
                 # Return all tokens regardless of validity
-                txn.execute("SELECT * FROM registration_tokens")
+                txn.execute(
+                    """
+                    SELECT token, uses_allowed, pending, completed, expiry_time
+                    FROM registration_tokens
+                    """
+                )
             elif valid:
                 # Select valid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
-                    "AND (expiry_time > ? OR expiry_time IS NULL)"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+                AND (expiry_time > ? OR expiry_time IS NULL)
+                """
                 txn.execute(sql, [now])
             else:
                 # Select invalid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "uses_allowed <= pending + completed OR expiry_time <= ?"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+                """
                 txn.execute(sql, [now])

-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+            )

         return await self.db_pool.runInteraction(
             "select_registration_tokens",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..6d4b9891e7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
     burst_count: int

+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+    room_id: str
+    name: Optional[str]
+    canonical_alias: Optional[str]
+    joined_members: int
+    join_rules: Optional[str]
+    guest_access: Optional[str]
+    history_visibility: Optional[str]
+    state_events: int
+    avatar: Optional[str]
+    topic: Optional[str]
+    room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+    joined_local_members: int
+    version: Optional[str]
+    creator: Optional[str]
+    encryption: Optional[str]
+    federatable: bool
+    public: bool
+
+
 class RoomSortOrder(Enum):
     """
     Enum to define the sorting method used when returning rooms with get_rooms_paginate
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             allow_none=True,
         )

-    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.

         Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

         def get_room_with_stats_txn(
             txn: LoggingTransaction, room_id: str
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[RoomStats]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 WHERE room_id = ?
                 """
             txn.execute(sql, [room_id])
-            # Catch error if sql returns empty result to return "None" instead of an error
-            try:
-                res = self.db_pool.cursor_to_dict(txn)[0]
-            except IndexError:
+            row = txn.fetchone()
+            if not row:
                 return None
-
-            res["federatable"] = bool(res["federatable"])
-            res["public"] = bool(res["public"])
-            return res
+            return RoomStats(
+                room_id=row[0],
+                name=row[1],
+                canonical_alias=row[2],
+                joined_members=row[3],
+                joined_local_members=row[4],
+                version=row[5],
+                creator=row[6],
+                encryption=row[7],
+                federatable=bool(row[8]),
+                public=bool(row[9]),
+                join_rules=row[10],
+                guest_access=row[11],
+                history_visibility=row[12],
+                state_events=row[13],
+                avatar=row[14],
+                topic=row[15],
+                room_type=row[16],
+            )

         return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LargestRoomStats]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

         def _get_largest_public_rooms_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[LargestRoomStats]:
             txn.execute(sql, query_args)
-            results = self.db_pool.cursor_to_dict(txn)
+            results = [
+                LargestRoomStats(
+                    room_id=r[0],
+                    name=r[1],
+                    canonical_alias=r[3],
+                    joined_members=r[4],
+                    join_rules=r[8],
+                    guest_access=r[7],
+                    history_visibility=r[6],
+                    state_events=0,
+                    avatar=r[5],
+                    topic=r[2],
+                    room_type=r[9],
+                )
+                for r in txn
+            ]

             if not forwards:
                 results.reverse()

             return results

-        ret_val = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val

     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9f3b8741c1..8d9df352b2 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -93,7 +93,7 @@ class Clock:

     _reactor: IReactorTime = attr.ib()

-    @defer.inlineCallbacks  # type: ignore[arg-type]  # Issue in Twisted's type annotations
+    @defer.inlineCallbacks
     def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
         d: defer.Deferred[float] = defer.Deferred()
         with context.PreserveLoggingContext():