diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index a75f6f2cc4..36ee9c8b8f 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -115,7 +115,7 @@ class InternalAuth(BaseAuth):
Once get_user_by_req has set up the opentracing span, this does the actual work.
"""
try:
- ip_addr = request.getClientAddress().host
+ ip_addr = request.get_client_ip_if_available()
user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index b78f419994..afef6712e1 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -80,10 +80,6 @@ class UserPresenceState:
def as_dict(self) -> JsonDict:
return attr.asdict(self)
- @staticmethod
- def from_dict(d: JsonDict) -> "UserPresenceState":
- return UserPresenceState(**d)
-
def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return attr.evolve(self, **kwargs)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c8bc46415d..1a7fa175ec 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1402,7 +1402,7 @@ class FederationClient(FederationBase):
The remote homeserver return some state from the room. The response
dictionary is in the form:
- {"knock_state_events": [<state event dict>, ...]}
+ {"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
@@ -1429,7 +1429,7 @@ class FederationClient(FederationBase):
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
- {"knock_state_events": [<state event dict>, ...]}
+ {"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ec8e770430..6ac8d16095 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -850,14 +850,7 @@ class FederationServer(FederationBase):
context, self._room_prejoin_state_types
)
)
- return {
- "knock_room_state": stripped_room_state,
- # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
- # Thus, we also populate a 'knock_state_events' with the same content to
- # support old instances.
- # See https://github.com/matrix-org/synapse/issues/14088.
- "knock_state_events": stripped_room_state,
- }
+ return {"knock_room_state": stripped_room_state}
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 6520795635..525968bcba 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -395,7 +395,7 @@ class PresenceDestinationsRow(BaseFederationRow):
@staticmethod
def from_data(data: JsonDict) -> "PresenceDestinationsRow":
return PresenceDestinationsRow(
- state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
+ state=UserPresenceState(**data["state"]), destinations=data["dests"]
)
def to_data(self) -> JsonDict:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index fb20fd8a10..7b6b1da090 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -67,7 +67,7 @@ The loop continues so long as there is anything to send. At each iteration of th
When the `PerDestinationQueue` has the catch-up flag set, the *Catch-Up Transmission Loop*
(`_catch_up_transmission_loop`) is used in lieu of the regular `_transaction_transmission_loop`.
-(Only once the catch-up mode has been exited can the regular tranaction transmission behaviour
+(Only once the catch-up mode has been exited can the regular transaction transmission behaviour
be resumed.)
*Catch-Up Mode*, entered upon Synapse startup or once a homeserver has fallen behind due to
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 37903a79ec..8b81d8a09a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -431,7 +431,7 @@ class TransportLayerClient:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
- {"knock_state_events": [<state event dict>, ...]}
+ {"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
"""
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index f1a7a05df6..6c2a49a3b9 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -212,8 +212,8 @@ class AccountValidityHandler:
addresses = []
for threepid in threepids:
- if threepid["medium"] == "email":
- addresses.append(threepid["address"])
+ if threepid.medium == "email":
+ addresses.append(threepid.address)
return addresses
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index ba9704a065..2c2baeac67 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,6 +16,8 @@ import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
+import attr
+
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
@@ -93,7 +95,7 @@ class AdminHandler:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
- user_info_dict["threepids"] = threepids
+ user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
@@ -171,8 +173,8 @@ class AdminHandler:
else:
stream_ordering = room.stream_ordering
- from_key = RoomStreamToken(0, 0)
- to_key = RoomStreamToken(None, stream_ordering)
+ from_key = RoomStreamToken(topological=0, stream=0)
+ to_key = RoomStreamToken(stream=stream_ordering)
# Events that we've processed in this room
written_events: Set[str] = set()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 7de7bd3289..c200a45f3a 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -216,7 +216,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
@@ -326,7 +326,7 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: int,
users: Collection[Union[str, UserID]],
) -> None:
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 67adeae6a7..6a8f8f2fd1 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -117,9 +117,9 @@ class DeactivateAccountHandler:
# Remove any local threepid associations for this account.
local_threepids = await self.store.user_get_threepids(user_id)
- for threepid in local_threepids:
+ for local_threepid in local_threepids:
await self._auth_handler.delete_local_threepid(
- user_id, threepid["medium"], threepid["address"]
+ user_id, local_threepid.medium, local_threepid.address
)
# delete any devices belonging to the user, which will also
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 86ad96d030..544bc7c13d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,17 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes
@@ -41,6 +31,7 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
+from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import (
JsonDict,
JsonMapping,
@@ -845,7 +836,6 @@ class DeviceHandler(DeviceWorkerHandler):
else:
assert max_stream_id == stream_id
# Avoid moving `room_id` backwards.
- pass
if self._handle_new_device_update_new_data:
continue
@@ -1009,14 +999,14 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
+ device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
) -> None:
- ip = client_ips.get((device["user_id"], device["device_id"]), {})
+ ip = client_ips.get((device["user_id"], device["device_id"]))
device.update(
{
- "last_seen_user_agent": ip.get("user_agent"),
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
+ "last_seen_user_agent": ip.user_agent if ip else None,
+ "last_seen_ts": ip.last_seen if ip else None,
+ "last_seen_ip": ip.ip if ip else None,
}
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 29cd45550a..9d72794e8b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -868,19 +868,10 @@ class FederationHandler:
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
- stripped_room_state = (
- knock_response.get("knock_room_state")
- # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
- # Thus, we also check for a 'knock_state_events' to support old instances.
- # See https://github.com/matrix-org/synapse/issues/14088.
- or knock_response.get("knock_state_events")
- )
+ stripped_room_state = knock_response.get("knock_room_state")
if stripped_room_state is None:
- raise KeyError(
- "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
- "send_knock response"
- )
+ raise KeyError("Missing 'knock_room_state' field in send_knock response")
event.unsigned["knock_room_state"] = stripped_room_state
@@ -1506,7 +1497,6 @@ class FederationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
@@ -1582,7 +1572,6 @@ class FederationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
async def add_display_name_to_third_party_invite(
self,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5737f8014d..c34bd7db95 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -192,8 +192,7 @@ class InitialSyncHandler:
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
- None,
- event.stream_ordering,
+ stream=event.stream_ordering,
)
deferred_room_state = run_in_background(
self._state_storage_controller.get_state_for_events,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 44dbbf81dd..41a35ce510 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1133,7 +1133,6 @@ class EventCreationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
# we know it was persisted, so must have a stream ordering
assert ev.internal_metadata.stream_ordering
@@ -2038,7 +2037,6 @@ class EventCreationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
return True
except AuthError:
logger.info(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7c7cda3e95..dfc0b9db07 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -110,6 +110,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -1499,9 +1500,9 @@ class PresenceHandler(BasePresenceHandler):
# We may get multiple deltas for different rooms, but we want to
# handle them on a room by room basis, so we batch them up by
# room.
- deltas_by_room: Dict[str, List[JsonDict]] = {}
+ deltas_by_room: Dict[str, List[StateDelta]] = {}
for delta in deltas:
- deltas_by_room.setdefault(delta["room_id"], []).append(delta)
+ deltas_by_room.setdefault(delta.room_id, []).append(delta)
for room_id, deltas_for_room in deltas_by_room.items():
await self._handle_state_delta(room_id, deltas_for_room)
@@ -1513,7 +1514,7 @@ class PresenceHandler(BasePresenceHandler):
max_pos
)
- async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None:
+ async def _handle_state_delta(self, room_id: str, deltas: List[StateDelta]) -> None:
"""Process current state deltas for the room to find new joins that need
to be handled.
"""
@@ -1524,31 +1525,30 @@ class PresenceHandler(BasePresenceHandler):
newly_joined_users = set()
for delta in deltas:
- assert room_id == delta["room_id"]
+ assert room_id == delta.room_id
- typ = delta["type"]
- state_key = delta["state_key"]
- event_id = delta["event_id"]
- prev_event_id = delta["prev_event_id"]
-
- logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+ logger.debug(
+ "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+ )
# Drop any event that isn't a membership join
- if typ != EventTypes.Member:
+ if delta.event_type != EventTypes.Member:
continue
- if event_id is None:
+ if delta.event_id is None:
# state has been deleted, so this is not a join. We only care about
# joins.
continue
- event = await self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(delta.event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
- if prev_event_id:
- prev_event = await self.store.get_event(prev_event_id, allow_none=True)
+ if delta.prev_event_id:
+ prev_event = await self.store.get_event(
+ delta.prev_event_id, allow_none=True
+ )
if (
prev_event
and prev_event.content.get("membership") == Membership.JOIN
@@ -1556,7 +1556,7 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events.
continue
- newly_joined_users.add(state_key)
+ newly_joined_users.add(delta.state_key)
if not newly_joined_users:
# If nobody has joined then there's nothing to do.
diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
index 7ed88a3611..87b428ab1c 100644
--- a/synapse/handlers/push_rules.py
+++ b/synapse/handlers/push_rules.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -114,7 +114,9 @@ class PushRulesHandler:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
- self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
+ self._notifier.on_new_event(
+ StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
+ )
async def push_rules_for_user(
self, user: UserID
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a7a29b758b..69ac468f75 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -130,11 +130,10 @@ class ReceiptsHandler:
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
- min_batch_id: Optional[int] = None
- max_batch_id: Optional[int] = None
+ receipts_persisted: List[ReadReceipt] = []
for receipt in receipts:
- res = await self.store.insert_receipt(
+ stream_id = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -143,30 +142,26 @@ class ReceiptsHandler:
receipt.data,
)
- if not res:
- # res will be None if this receipt is 'old'
+ if stream_id is None:
+ # stream_id will be None if this receipt is 'old'
continue
- stream_id, max_persisted_id = res
+ receipts_persisted.append(receipt)
- if min_batch_id is None or stream_id < min_batch_id:
- min_batch_id = stream_id
- if max_batch_id is None or max_persisted_id > max_batch_id:
- max_batch_id = max_persisted_id
-
- # Either both of these should be None or neither.
- if min_batch_id is None or max_batch_id is None:
+ if not receipts_persisted:
# no new receipts
return False
- affected_room_ids = list({r.room_id for r in receipts})
+ max_batch_id = self.store.get_max_receipt_stream_id()
+
+ affected_room_ids = list({r.room_id for r in receipts_persisted})
self.notifier.on_new_event(
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
+ {r.user_id for r in receipts_persisted}
)
return True
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a0c3b16819..97c9f01245 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -261,7 +261,6 @@ class RoomCreationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
# This is to satisfy mypy and should never happen
raise PartialStateConflictError()
@@ -1708,7 +1707,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
- from_key = RoomStreamToken(None, from_key.stream)
+ from_key = RoomStreamToken(stream=from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4bd7efc738..9a6b02a16b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from synapse import types
from synapse.api.constants import (
@@ -44,6 +44,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
JsonDict,
Requester,
@@ -382,8 +383,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
and persist a new event for the new membership change.
Args:
- requester:
- target:
+ requester: User requesting the membership change, i.e. the sender of the
+ desired membership event.
+ target: Use whose membership should change, i.e. the state_key of the
+ desired membership event.
room_id:
membership:
@@ -415,7 +418,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Returns:
Tuple of event ID and stream ordering position
"""
-
user_id = target.to_string()
if content is None:
@@ -475,21 +477,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(EventTypes.Member, user_id), None
)
- if event.membership == Membership.JOIN:
- newly_joined = True
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(
- prev_member_event_id
- )
- newly_joined = prev_member_event.membership != Membership.JOIN
-
- # Only rate-limit if the user actually joined the room, otherwise we'll end
- # up blocking profile updates.
- if newly_joined and ratelimit:
- await self._join_rate_limiter_local.ratelimit(requester)
- await self._join_rate_per_room_limiter.ratelimit(
- requester, key=room_id, update=False
- )
with opentracing.start_active_span("handle_new_client_event"):
result_event = (
await self.event_creation_handler.handle_new_client_event(
@@ -514,7 +501,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
# we know it was persisted, so should have a stream ordering
assert result_event.internal_metadata.stream_ordering
@@ -618,6 +604,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Raises:
ShadowBanError if a shadow-banned requester attempts to send an invite.
"""
+ if ratelimit:
+ if action == Membership.JOIN:
+ # Only rate-limit if the user isn't already joined to the room, otherwise
+ # we'll end up blocking profile updates.
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ requester.user.to_string(),
+ room_id,
+ )
+ if current_membership != Membership.JOIN:
+ await self._join_rate_limiter_local.ratelimit(requester)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
+ elif action == Membership.INVITE:
+ await self.ratelimit_invite(requester, room_id, target.to_string())
+
if action == Membership.INVITE and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
@@ -808,8 +813,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if effective_membership_state == Membership.INVITE:
target_id = target.to_string()
- if ratelimit:
- await self.ratelimit_invite(requester, room_id, target_id)
# block any attempts to invite the server notices mxid
if target_id == self._server_notices_mxid:
@@ -2016,7 +2019,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
- pass
# we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering
@@ -2159,24 +2161,18 @@ class RoomForgetterHandler(StateDeltasHandler):
await self._store.update_room_forgetter_stream_pos(max_pos)
- async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+ async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
"""Called with the state deltas to process"""
for delta in deltas:
- typ = delta["type"]
- state_key = delta["state_key"]
- room_id = delta["room_id"]
- event_id = delta["event_id"]
- prev_event_id = delta["prev_event_id"]
-
- if typ != EventTypes.Member:
+ if delta.event_type != EventTypes.Member:
continue
- if not self._hs.is_mine_id(state_key):
+ if not self._hs.is_mine_id(delta.state_key):
continue
change = await self._get_key_change(
- prev_event_id,
- event_id,
+ delta.prev_event_id,
+ delta.event_id,
key_name="membership",
public_value=Membership.JOIN,
)
@@ -2185,7 +2181,7 @@ class RoomForgetterHandler(StateDeltasHandler):
if is_leave:
try:
await self._room_member_handler.forget(
- UserID.from_string(state_key), room_id
+ UserID.from_string(delta.state_key), delta.room_id
)
except SynapseError as e:
if e.code == 400:
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3dde19fc81..817b41aa37 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -27,6 +27,7 @@ from typing import (
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -142,7 +143,7 @@ class StatsHandler:
self.pos = max_pos
async def _handle_deltas(
- self, deltas: Iterable[JsonDict]
+ self, deltas: Iterable[StateDelta]
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
@@ -157,51 +158,50 @@ class StatsHandler:
room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas:
- typ = delta["type"]
- state_key = delta["state_key"]
- room_id = delta["room_id"]
- event_id = delta["event_id"]
- stream_id = delta["stream_id"]
- prev_event_id = delta["prev_event_id"]
-
- logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)
+ logger.debug(
+ "Handling: %r, %r %r, %s",
+ delta.room_id,
+ delta.event_type,
+ delta.state_key,
+ delta.event_id,
+ )
- token = await self.store.get_earliest_token_for_stats("room", room_id)
+ token = await self.store.get_earliest_token_for_stats("room", delta.room_id)
# If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta.
- if token is not None and token >= stream_id:
+ if token is not None and token >= delta.stream_id:
logger.debug(
"Ignoring: %s as earlier than this room's initial ingestion event",
- event_id,
+ delta.event_id,
)
continue
- if event_id is None and prev_event_id is None:
+ if delta.event_id is None and delta.prev_event_id is None:
logger.error(
"event ID is None and so is the previous event ID. stream_id: %s",
- stream_id,
+ delta.stream_id,
)
continue
event_content: JsonDict = {}
- if event_id is not None:
- event = await self.store.get_event(event_id, allow_none=True)
+ if delta.event_id is not None:
+ event = await self.store.get_event(delta.event_id, allow_none=True)
if event:
event_content = event.content or {}
# All the values in this dict are deltas (RELATIVE changes)
- room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
+ room_stats_delta = room_to_stats_deltas.setdefault(delta.room_id, Counter())
- room_state = room_to_state_updates.setdefault(room_id, {})
+ room_state = room_to_state_updates.setdefault(delta.room_id, {})
- if prev_event_id is None:
+ if delta.prev_event_id is None:
# this state event doesn't overwrite another,
# so it is a new effective/current state event
room_stats_delta["current_state_events"] += 1
- if typ == EventTypes.Member:
+ if delta.event_type == EventTypes.Member:
# we could use StateDeltasHandler._get_key_change here but it's
# a bit inefficient given we're not testing for a specific
# result; might as well just grab the prev_membership and
@@ -210,9 +210,9 @@ class StatsHandler:
# in the absence of a previous event because we do not want to
# reduce the leave count when a new-to-the-room user joins.
prev_membership = None
- if prev_event_id is not None:
+ if delta.prev_event_id is not None:
prev_event = await self.store.get_event(
- prev_event_id, allow_none=True
+ delta.prev_event_id, allow_none=True
)
if prev_event:
prev_event_content = prev_event.content
@@ -256,7 +256,7 @@ class StatsHandler:
else:
raise ValueError("%r is not a valid membership" % (membership,))
- user_id = state_key
+ user_id = delta.state_key
if self.is_mine_id(user_id):
# this accounts for transitions like leave → ban and so on.
has_changed_joinedness = (prev_membership == Membership.JOIN) != (
@@ -272,30 +272,30 @@ class StatsHandler:
room_stats_delta["local_users_in_room"] += membership_delta
- elif typ == EventTypes.Create:
+ elif delta.event_type == EventTypes.Create:
room_state["is_federatable"] = (
event_content.get(EventContentFields.FEDERATE, True) is True
)
room_type = event_content.get(EventContentFields.ROOM_TYPE)
if isinstance(room_type, str):
room_state["room_type"] = room_type
- elif typ == EventTypes.JoinRules:
+ elif delta.event_type == EventTypes.JoinRules:
room_state["join_rules"] = event_content.get("join_rule")
- elif typ == EventTypes.RoomHistoryVisibility:
+ elif delta.event_type == EventTypes.RoomHistoryVisibility:
room_state["history_visibility"] = event_content.get(
"history_visibility"
)
- elif typ == EventTypes.RoomEncryption:
+ elif delta.event_type == EventTypes.RoomEncryption:
room_state["encryption"] = event_content.get("algorithm")
- elif typ == EventTypes.Name:
+ elif delta.event_type == EventTypes.Name:
room_state["name"] = event_content.get("name")
- elif typ == EventTypes.Topic:
+ elif delta.event_type == EventTypes.Topic:
room_state["topic"] = event_content.get("topic")
- elif typ == EventTypes.RoomAvatar:
+ elif delta.event_type == EventTypes.RoomAvatar:
room_state["avatar"] = event_content.get("url")
- elif typ == EventTypes.CanonicalAlias:
+ elif delta.event_type == EventTypes.CanonicalAlias:
room_state["canonical_alias"] = event_content.get("alias")
- elif typ == EventTypes.GuestAccess:
+ elif delta.event_type == EventTypes.GuestAccess:
room_state["guest_access"] = event_content.get(
EventContentFields.GUEST_ACCESS
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7bd42f635f..60b4d95cd7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -40,7 +40,6 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
from synapse.handlers.relations import BundledAggregations
from synapse.logging import issue9533_logger
from synapse.logging.context import current_context
@@ -363,36 +362,15 @@ class SyncHandler:
# (since we now know that the device has received them)
if since_token is not None:
since_stream_id = since_token.to_device_key
- # Fast path: delete a limited number of to-device messages up front.
- # We do this to avoid the overhead of scheduling a task for every
- # sync.
- device_deletion_limit = 100
deleted = await self.store.delete_messages_for_device(
sync_config.user.to_string(),
sync_config.device_id,
since_stream_id,
- limit=device_deletion_limit,
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
)
- # If we hit the limit, schedule a background task to delete the rest.
- if deleted >= device_deletion_limit:
- await self._task_scheduler.schedule_task(
- DELETE_DEVICE_MSGS_TASK_NAME,
- resource_id=sync_config.device_id,
- params={
- "user_id": sync_config.user.to_string(),
- "device_id": sync_config.device_id,
- "up_to_stream_id": since_stream_id,
- },
- )
- logger.debug(
- "Deletion of to-device messages up to %d scheduled",
- since_stream_id,
- )
-
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
@@ -2333,7 +2311,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
+ StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a0f5568000..75717ba4f9 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,7 +14,7 @@
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from twisted.internet.interfaces import IDelayedCall
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Memb
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
from synapse.types import UserID
@@ -247,32 +248,31 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
- async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+ async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
"""Called with the state deltas to process"""
for delta in deltas:
- typ = delta["type"]
- state_key = delta["state_key"]
- room_id = delta["room_id"]
- event_id: Optional[str] = delta["event_id"]
- prev_event_id: Optional[str] = delta["prev_event_id"]
-
- logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+ logger.debug(
+ "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+ )
# For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room
- if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
+ if delta.event_type in (
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.JoinRules,
+ ):
await self._handle_room_publicity_change(
- room_id, prev_event_id, event_id, typ
+ delta.room_id, delta.prev_event_id, delta.event_id, delta.event_type
)
- elif typ == EventTypes.Member:
+ elif delta.event_type == EventTypes.Member:
await self._handle_room_membership_event(
- room_id,
- prev_event_id,
- event_id,
- state_key,
+ delta.room_id,
+ delta.prev_event_id,
+ delta.event_id,
+ delta.state_key,
)
else:
- logger.debug("Ignoring irrelevant type: %r", typ)
+ logger.debug("Ignoring irrelevant type: %r", delta.event_type)
async def _handle_room_publicity_change(
self,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 3bbf91298e..1e4e56f36b 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -266,7 +266,7 @@ class HttpServer(Protocol):
def register_paths(
self,
method: str,
- path_patterns: Iterable[Pattern],
+ path_patterns: Iterable[Pattern[str]],
callback: ServletCallback,
servlet_classname: str,
) -> None:
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 80c448de2b..860e5ddca2 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -26,11 +26,11 @@ from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.server import Request
-from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
+from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__)
@@ -84,52 +84,12 @@ INLINE_CONTENT_TYPES = [
]
-def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
- """Parses the server name, media ID and optional file name from the request URI
-
- Also performs some rough validation on the server name.
-
- Args:
- request: The `Request`.
-
- Returns:
- A tuple containing the parsed server name, media ID and optional file name.
-
- Raises:
- SynapseError(404): if parsing or validation fail for any reason
- """
- try:
- # The type on postpath seems incorrect in Twisted 21.2.0.
- postpath: List[bytes] = request.postpath # type: ignore
- assert postpath
-
- # This allows users to append e.g. /test.png to the URL. Useful for
- # clients that parse the URL to see content type.
- server_name_bytes, media_id_bytes = postpath[:2]
- server_name = server_name_bytes.decode("utf-8")
- media_id = media_id_bytes.decode("utf8")
-
- # Validate the server name, raising if invalid
- parse_and_validate_server_name(server_name)
-
- file_name = None
- if len(postpath) > 2:
- try:
- file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
- except UnicodeDecodeError:
- pass
- return server_name, media_id, file_name
- except Exception:
- raise SynapseError(
- 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
- )
-
-
def respond_404(request: SynapseRequest) -> None:
+ assert request.path is not None
respond_with_json(
request,
404,
- cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND),
+ cs_error("Not found '%s'" % (request.path.decode(),), code=Codes.NOT_FOUND),
send_cors=True,
)
@@ -188,7 +148,9 @@ def add_file_headers(
# A strict subset of content types is allowed to be inlined so that they may
# be viewed directly in a browser. Other file types are forced to be downloads.
- if media_type.lower() in INLINE_CONTENT_TYPES:
+ #
+ # Only the type & subtype are important, parameters can be ignored.
+ if media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
disposition = "inline"
else:
disposition = "attachment"
@@ -372,7 +334,7 @@ class ThumbnailInfo:
# Content type of thumbnail, e.g. image/png
type: str
# The size of the media file, in bytes.
- length: Optional[int] = None
+ length: int
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1b7b014f9a..7fd46901f7 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -48,6 +48,7 @@ from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage
from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
+from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
@@ -114,7 +115,7 @@ class MediaRepository:
)
storage_providers.append(provider)
- self.media_storage = MediaStorage(
+ self.media_storage: MediaStorage = MediaStorage(
self.hs, self.primary_base_path, self.filepaths, storage_providers
)
@@ -142,6 +143,13 @@ class MediaRepository:
MEDIA_RETENTION_CHECK_PERIOD_MS,
)
+ if hs.config.media.url_preview_enabled:
+ self.url_previewer: Optional[UrlPreviewer] = UrlPreviewer(
+ hs, self, self.media_storage
+ )
+ else:
+ self.url_previewer = None
+
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
@@ -616,6 +624,7 @@ class MediaRepository:
height=t_height,
method=t_method,
type=t_type,
+ length=t_byte_source.tell(),
),
)
@@ -686,6 +695,7 @@ class MediaRepository:
height=t_height,
method=t_method,
type=t_type,
+ length=t_byte_source.tell(),
),
)
@@ -831,6 +841,7 @@ class MediaRepository:
height=t_height,
method=t_method,
type=t_type,
+ length=t_byte_source.tell(),
),
)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 65e2aca456..0786d20635 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -678,7 +678,7 @@ class ModuleApi:
"msisdn" for phone numbers, and an "address" key which value is the
threepid's address.
"""
- return await self._store.user_get_threepids(user_id)
+ return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]
def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index fc39e5c963..99e7715896 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -126,7 +126,7 @@ class _NotifierUserStream:
def notify(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
) -> None:
@@ -454,7 +454,7 @@ class Notifier:
def on_new_event(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
@@ -655,30 +655,29 @@ class Notifier:
events: List[Union[JsonDict, EventBase]] = []
end_token = from_token
- for name, source in self.event_sources.sources.get_sources():
- keyname = "%s_key" % name
- before_id = getattr(before_token, keyname)
- after_id = getattr(after_token, keyname)
+ for keyname, source in self.event_sources.sources.get_sources():
+ before_id = before_token.get_field(keyname)
+ after_id = after_token.get_field(keyname)
if before_id == after_id:
continue
new_events, new_key = await source.get_new_events(
user=user,
- from_key=getattr(from_token, keyname),
+ from_key=from_token.get_field(keyname),
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)
- if name == "room":
+ if keyname == StreamKeyType.ROOM:
new_events = await filter_events_for_client(
self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
)
- elif name == "presence":
+ elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec()
new_events[:] = [
{
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9e3a98741a..4d405f2a0c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -101,7 +101,7 @@ if TYPE_CHECKING:
class PusherConfig:
"""Parameters necessary to configure a pusher."""
- id: Optional[str]
+ id: Optional[int]
user_name: str
profile_tag: str
@@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 1710dd51b9..cf45fd09a8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -99,7 +99,7 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 0cb5dc2cc9..12b971f239 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -165,7 +165,7 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 6517e3566f..15a2cc932f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -292,20 +292,12 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(
- self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
- ) -> None:
+ async def on_new_receipts(self, users_affected: StrCollection) -> None:
if not self.pushers:
# nothing to do here.
return
try:
- # Need to subtract 1 from the minimum because the lower bound here
- # is not inclusive
- users_affected = await self.store.get_users_sent_receipts_between(
- min_stream_id - 1, max_stream_id
- )
-
for u in users_affected:
# Don't push if the user account has expired
expired = await self._account_validity_handler.is_user_expired(u)
@@ -314,7 +306,7 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
- p.on_new_receipts(min_stream_id, max_stream_id)
+ p.on_new_receipts()
except Exception:
logger.exception("Exception in pusher on_new_receipts")
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 53ad327030..e728297dce 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -138,7 +138,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts.append((event, context))
- logger.info("Got %d events from federation", len(event_and_contexts))
+ logger.info(
+ "Got batch of %i events to persist to room %s",
+ len(event_and_contexts),
+ room_id,
+ )
max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py
index 4f82c9f96d..8eea256063 100644
--- a/synapse/replication/http/send_events.py
+++ b/synapse/replication/http/send_events.py
@@ -118,6 +118,7 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_send_events_parse"):
events_and_context = []
events = payload["events"]
+ rooms = set()
for event_payload in events:
event_dict = event_payload["event"]
@@ -144,11 +145,13 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
UserID.from_string(u) for u in event_payload["extra_users"]
]
- logger.info(
- "Got batch of events to send, last ID of batch is: %s, sending into room: %s",
- event.event_id,
- event.room_id,
- )
+ # all the rooms *should* be the same, but we'll log separately to be
+ # sure.
+ rooms.add(event.room_id)
+
+ logger.info(
+ "Got batch of %i events to persist to rooms %s", len(events), rooms
+ )
last_event = (
await self.event_creation_handler.persist_and_notify_client_events(
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index f4f2b29e96..d5337fe588 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -129,9 +129,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
)
- await self._pusher_pool.on_new_receipts(
- token, token, {row.room_id for row in rows}
- )
+ await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index e616b5e1c8..0f0f851b79 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,7 +18,7 @@ allowed to be sent by which side.
"""
import abc
import logging
-from typing import Optional, Tuple, Type, TypeVar
+from typing import List, Optional, Tuple, Type, TypeVar
from synapse.replication.tcp.streams._base import StreamRow
from synapse.util import json_decoder, json_encoder
@@ -74,6 +74,8 @@ SC = TypeVar("SC", bound="_SimpleCommand")
class _SimpleCommand(Command):
"""An implementation of Command whose argument is just a 'data' string."""
+ __slots__ = ["data"]
+
def __init__(self, data: str):
self.data = data
@@ -122,6 +124,8 @@ class RdataCommand(Command):
RDATA presence master 59 ["@baz:example.com", "online", ...]
"""
+ __slots__ = ["stream_name", "instance_name", "token", "row"]
+
NAME = "RDATA"
def __init__(
@@ -179,6 +183,8 @@ class PositionCommand(Command):
of the stream.
"""
+ __slots__ = ["stream_name", "instance_name", "prev_token", "new_token"]
+
NAME = "POSITION"
def __init__(
@@ -235,6 +241,8 @@ class ReplicateCommand(Command):
REPLICATE
"""
+ __slots__: List[str] = []
+
NAME = "REPLICATE"
def __init__(self) -> None:
@@ -264,6 +272,8 @@ class UserSyncCommand(Command):
Where <state> is either "start" or "end"
"""
+ __slots__ = ["instance_id", "user_id", "device_id", "is_syncing", "last_sync_ms"]
+
NAME = "USER_SYNC"
def __init__(
@@ -316,6 +326,8 @@ class ClearUserSyncsCommand(Command):
CLEAR_USER_SYNC <instance_id>
"""
+ __slots__ = ["instance_id"]
+
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id: str):
@@ -343,6 +355,8 @@ class FederationAckCommand(Command):
FEDERATION_ACK <instance_name> <token>
"""
+ __slots__ = ["instance_name", "token"]
+
NAME = "FEDERATION_ACK"
def __init__(self, instance_name: str, token: int):
@@ -368,6 +382,15 @@ class UserIpCommand(Command):
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
"""
+ __slots__ = [
+ "user_id",
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ]
+
NAME = "USER_IP"
def __init__(
@@ -423,8 +446,6 @@ class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer
"down" and retry timings should be reset.
- If sent from a client the server will relay to all other workers.
-
Format::
REMOTE_SERVER_UP <server>
@@ -441,6 +462,8 @@ class LockReleasedCommand(Command):
LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
"""
+ __slots__ = ["instance_name", "lock_name", "lock_key"]
+
NAME = "LOCK_RELEASED"
def __init__(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e42dade246..9bd0d764f8 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -146,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
room_token = RoomStreamToken(
- event.depth, event.internal_metadata.stream_ordering
+ topological=event.depth, stream=event.internal_metadata.stream_ordering
)
token = await room_token.to_string(self.store)
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index e0ee55bd0e..8a617af599 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -198,7 +198,13 @@ class DestinationMembershipRestServlet(RestServlet):
rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction
)
- response = {"rooms": rooms, "total": total}
+ response = {
+ "rooms": [
+ {"room_id": room_id, "stream_ordering": stream_ordering}
+ for room_id, stream_ordering in rooms
+ ],
+ "total": total,
+ }
if (start + limit) < total:
response["next_token"] = str(start + len(rooms))
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 5b743a1d03..7fe16130e7 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -329,9 +329,8 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
# get changed threepids (added and removed)
- # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = {
- (threepid["medium"], threepid["address"])
+ (threepid.medium, threepid.address)
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
@@ -842,7 +841,18 @@ class SearchUsersRestServlet(RestServlet):
logger.info("term: %s ", term)
ret = await self.store.search_users(term)
- return HTTPStatus.OK, ret
+ results = [
+ {
+ "name": name,
+ "password_hash": password_hash,
+ "is_guest": bool(is_guest),
+ "admin": bool(admin),
+ "user_type": user_type,
+ }
+ for name, password_hash, is_guest, admin, user_type in ret
+ ]
+
+ return HTTPStatus.OK, results
class UserAdminServlet(RestServlet):
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index e74a87af4d..641390cb30 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -24,6 +24,8 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictBool, StrictStr, constr
else:
from pydantic import StrictBool, StrictStr, constr
+
+import attr
from typing_extensions import Literal
from twisted.web.server import Request
@@ -595,7 +597,7 @@ class ThreepidRestServlet(RestServlet):
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
- return 200, {"threepids": threepids}
+ return 200, {"threepids": [attr.asdict(t) for t in threepids]}
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
# the endpoint is deprecated. (If you really want to, you could do this by reusing
diff --git a/synapse/rest/media/config_resource.py b/synapse/rest/media/config_resource.py
index a95804d327..dbf5133c72 100644
--- a/synapse/rest/media/config_resource.py
+++ b/synapse/rest/media/config_resource.py
@@ -14,17 +14,19 @@
# limitations under the License.
#
+import re
from typing import TYPE_CHECKING
-from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.server import respond_with_json
+from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
-class MediaConfigResource(DirectServeJsonResource):
- isLeaf = True
+class MediaConfigResource(RestServlet):
+ PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/config$")]
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -33,9 +35,6 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
- async def _async_render_GET(self, request: SynapseRequest) -> None:
+ async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
-
- async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
- respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 3c618ef60a..65b9ff52fa 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -13,16 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+import re
+from typing import TYPE_CHECKING, Optional
-from synapse.http.server import (
- DirectServeJsonResource,
- set_corp_headers,
- set_cors_headers,
-)
-from synapse.http.servlet import parse_boolean
+from synapse.http.server import set_corp_headers, set_cors_headers
+from synapse.http.servlet import RestServlet, parse_boolean
from synapse.http.site import SynapseRequest
-from synapse.media._base import parse_media_id, respond_404
+from synapse.media._base import respond_404
+from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
@@ -31,15 +29,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class DownloadResource(DirectServeJsonResource):
- isLeaf = True
+class DownloadResource(RestServlet):
+ PATTERNS = [
+ re.compile(
+ "/_matrix/media/(r0|v3|v1)/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
+ )
+ ]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self._is_mine_server_name = hs.is_mine_server_name
- async def _async_render_GET(self, request: SynapseRequest) -> None:
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ file_name: Optional[str] = None,
+ ) -> None:
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+
set_cors_headers(request)
set_corp_headers(request)
request.setHeader(
@@ -58,9 +69,8 @@ class DownloadResource(DirectServeJsonResource):
b"Referrer-Policy",
b"no-referrer",
)
- server_name, media_id, name = parse_media_id(request)
if self._is_mine_server_name(server_name):
- await self.media_repo.get_local_media(request, media_id, name)
+ await self.media_repo.get_local_media(request, media_id, file_name)
else:
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
@@ -72,4 +82,6 @@ class DownloadResource(DirectServeJsonResource):
respond_404(request)
return
- await self.media_repo.get_remote_media(request, server_name, media_id, name)
+ await self.media_repo.get_remote_media(
+ request, server_name, media_id, file_name
+ )
diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
index 5ebaa3b032..2089bb1029 100644
--- a/synapse/rest/media/media_repository_resource.py
+++ b/synapse/rest/media/media_repository_resource.py
@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING
from synapse.config._base import ConfigError
-from synapse.http.server import UnrecognizedRequestResource
+from synapse.http.server import HttpServer, JsonResource
from .config_resource import MediaConfigResource
from .download_resource import DownloadResource
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-class MediaRepositoryResource(UnrecognizedRequestResource):
+class MediaRepositoryResource(JsonResource):
"""File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET
@@ -70,6 +70,11 @@ class MediaRepositoryResource(UnrecognizedRequestResource):
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
+
+ This gets mounted at various points under /_matrix/media, including:
+ * /_matrix/media/r0
+ * /_matrix/media/v1
+ * /_matrix/media/v3
"""
def __init__(self, hs: "HomeServer"):
@@ -77,17 +82,23 @@ class MediaRepositoryResource(UnrecognizedRequestResource):
if not hs.config.media.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
- super().__init__()
+ JsonResource.__init__(self, hs, canonical_json=False)
+ self.register_servlets(self, hs)
+
+ @staticmethod
+ def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
media_repo = hs.get_media_repository()
- self.putChild(b"upload", UploadResource(hs, media_repo))
- self.putChild(b"download", DownloadResource(hs, media_repo))
- self.putChild(
- b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
+ # Note that many of these should not exist as v1 endpoints, but empirically
+ # a lot of traffic still goes to them.
+
+ UploadResource(hs, media_repo).register(http_server)
+ DownloadResource(hs, media_repo).register(http_server)
+ ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+ http_server
)
if hs.config.media.url_preview_enabled:
- self.putChild(
- b"preview_url",
- PreviewUrlResource(hs, media_repo, media_repo.media_storage),
+ PreviewUrlResource(hs, media_repo, media_repo.media_storage).register(
+ http_server
)
- self.putChild(b"config", MediaConfigResource(hs))
+ MediaConfigResource(hs).register(http_server)
diff --git a/synapse/rest/media/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py
index 58513c4be4..c8acb65dca 100644
--- a/synapse/rest/media/preview_url_resource.py
+++ b/synapse/rest/media/preview_url_resource.py
@@ -13,24 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from typing import TYPE_CHECKING
-from synapse.http.server import (
- DirectServeJsonResource,
- respond_with_json,
- respond_with_json_bytes,
-)
-from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.server import respond_with_json_bytes
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media.media_storage import MediaStorage
-from synapse.media.url_previewer import UrlPreviewer
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
-class PreviewUrlResource(DirectServeJsonResource):
+class PreviewUrlResource(RestServlet):
"""
The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
@@ -48,7 +44,7 @@ class PreviewUrlResource(DirectServeJsonResource):
* Matrix cannot be used to distribute the metadata between homeservers.
"""
- isLeaf = True
+ PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/preview_url$")]
def __init__(
self,
@@ -62,14 +58,10 @@ class PreviewUrlResource(DirectServeJsonResource):
self.clock = hs.get_clock()
self.media_repo = media_repo
self.media_storage = media_storage
+ assert self.media_repo.url_previewer is not None
+ self.url_previewer = self.media_repo.url_previewer
- self._url_previewer = UrlPreviewer(hs, media_repo, media_storage)
-
- async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
- request.setHeader(b"Allow", b"OPTIONS, GET")
- respond_with_json(request, 200, {}, send_cors=True)
-
- async def _async_render_GET(self, request: SynapseRequest) -> None:
+ async def on_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True)
@@ -77,5 +69,5 @@ class PreviewUrlResource(DirectServeJsonResource):
if ts is None:
ts = self.clock.time_msec()
- og = await self._url_previewer.preview(url, requester.user, ts)
+ og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 661e604b85..85b6bdbe72 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -13,29 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
-from synapse.http.server import (
- DirectServeJsonResource,
- respond_with_json,
- set_corp_headers,
- set_cors_headers,
-)
-from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
FileInfo,
ThumbnailInfo,
- parse_media_id,
respond_404,
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
+from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
@@ -44,8 +39,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class ThumbnailResource(DirectServeJsonResource):
- isLeaf = True
+class ThumbnailResource(RestServlet):
+ PATTERNS = [
+ re.compile(
+ "/_matrix/media/(r0|v3|v1)/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+ )
+ ]
def __init__(
self,
@@ -60,12 +59,17 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
+ self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
- async def _async_render_GET(self, request: SynapseRequest) -> None:
+ async def on_GET(
+ self, request: SynapseRequest, server_name: str, media_id: str
+ ) -> None:
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+
set_cors_headers(request)
set_corp_headers(request)
- server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
@@ -155,30 +159,24 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
- t_w = info["thumbnail_width"] == desired_width
- t_h = info["thumbnail_height"] == desired_height
- t_method = info["thumbnail_method"] == desired_method
- t_type = info["thumbnail_type"] == desired_type
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
- thumbnail=ThumbnailInfo(
- width=info["thumbnail_width"],
- height=info["thumbnail_height"],
- type=info["thumbnail_type"],
- method=info["thumbnail_method"],
- ),
+ thumbnail=info,
)
- t_type = file_info.thumbnail_type
- t_length = info["thumbnail_length"]
-
responder = await self.media_storage.fetch_media(file_info)
if responder:
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
return
logger.debug("We don't have a thumbnail of that size. Generating")
@@ -218,29 +216,23 @@ class ThumbnailResource(DirectServeJsonResource):
file_id = media_info["filesystem_id"]
for info in thumbnail_infos:
- t_w = info["thumbnail_width"] == desired_width
- t_h = info["thumbnail_height"] == desired_height
- t_method = info["thumbnail_method"] == desired_method
- t_type = info["thumbnail_type"] == desired_type
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
- thumbnail=ThumbnailInfo(
- width=info["thumbnail_width"],
- height=info["thumbnail_height"],
- type=info["thumbnail_type"],
- method=info["thumbnail_method"],
- ),
+ thumbnail=info,
)
- t_type = file_info.thumbnail_type
- t_length = info["thumbnail_length"]
-
responder = await self.media_storage.fetch_media(file_info)
if responder:
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
return
logger.debug("We don't have a thumbnail of that size. Generating")
@@ -300,7 +292,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos: List[Dict[str, Any]],
+ thumbnail_infos: List[ThumbnailInfo],
media_id: str,
file_id: str,
url_cache: bool,
@@ -315,7 +307,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ thumbnail_infos: A list of thumbnail info of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
@@ -418,13 +410,14 @@ class ThumbnailResource(DirectServeJsonResource):
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
+ assert request.path is not None
respond_with_json(
request,
400,
cs_error(
- "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+ "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
- request.postpath,
+ request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
@@ -438,7 +431,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos: List[Dict[str, Any]],
+ thumbnail_infos: List[ThumbnailInfo],
file_id: str,
url_cache: bool,
server_name: Optional[str],
@@ -451,7 +444,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
@@ -469,21 +462,25 @@ class ThumbnailResource(DirectServeJsonResource):
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
+ crop_info_list: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
# Other thumbnails.
- crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
+ crop_info_list2: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
- if info["thumbnail_method"] != "crop":
+ if info.method != "crop":
continue
- t_w = info["thumbnail_width"]
- t_h = info["thumbnail_height"]
+ t_w = info.width
+ t_h = info.height
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
+ type_quality = desired_type != info.type
+ length_quality = info.length
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
@@ -508,7 +505,7 @@ class ThumbnailResource(DirectServeJsonResource):
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
@@ -516,20 +513,20 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
- info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
+ info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
# Other thumbnails.
- info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
+ info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
- if info["thumbnail_method"] != "scale":
+ if info.method != "scale":
continue
- t_w = info["thumbnail_width"]
- t_h = info["thumbnail_height"]
+ t_w = info.width
+ t_h = info.height
size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
+ type_quality = desired_type != info.type
+ length_quality = info.length
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
else:
@@ -538,7 +535,7 @@ class ThumbnailResource(DirectServeJsonResource):
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
@@ -550,13 +547,7 @@ class ThumbnailResource(DirectServeJsonResource):
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
- thumbnail=ThumbnailInfo(
- width=thumbnail_info["thumbnail_width"],
- height=thumbnail_info["thumbnail_height"],
- type=thumbnail_info["thumbnail_type"],
- method=thumbnail_info["thumbnail_method"],
- length=thumbnail_info["thumbnail_length"],
- ),
+ thumbnail=thumbnail_info,
)
# No matching thumbnail was found.
diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py
index 043e8d6077..949326d85d 100644
--- a/synapse/rest/media/upload_resource.py
+++ b/synapse/rest/media/upload_resource.py
@@ -14,11 +14,12 @@
# limitations under the License.
import logging
+import re
from typing import IO, TYPE_CHECKING, Dict, List, Optional
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.servlet import parse_bytes_from_args
+from synapse.http.server import respond_with_json
+from synapse.http.servlet import RestServlet, parse_bytes_from_args
from synapse.http.site import SynapseRequest
from synapse.media.media_storage import SpamMediaException
@@ -29,8 +30,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UploadResource(DirectServeJsonResource):
- isLeaf = True
+class UploadResource(RestServlet):
+ PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
@@ -43,10 +44,7 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
- respond_with_json(request, 200, {}, send_cors=True)
-
- async def _async_render_POST(self, request: SynapseRequest) -> None:
+ async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
raw_content_length = request.getHeader("Content-Length")
if raw_content_length is None:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 46957723a1..9f7959c45d 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -16,7 +16,6 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
- Any,
Callable,
Collection,
Dict,
@@ -32,6 +31,7 @@ from typing import (
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.roommember import ProfileInfo
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -531,19 +531,9 @@ class StateStorageController:
@tag_args
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ca894edd5a..81f661160c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1874,9 +1874,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows.
Filters rows by whether the value of `column` is in `iterable`.
@@ -1888,10 +1888,13 @@ class DatabasePool:
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query
+
+ Returns:
+ The results as a list of tuples.
"""
keyvalues = keyvalues or {}
- results: List[Dict[str, Any]] = []
+ results: List[Tuple[Any, ...]] = []
for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
@@ -1918,9 +1921,9 @@ class DatabasePool:
iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows.
Filters rows by whether the value of `column` is in `iterable`.
@@ -1931,6 +1934,9 @@ class DatabasePool:
iterable: list
keyvalues: dict of column names and values to select the rows with
retcols: list of strings giving the names of the columns to return
+
+ Returns:
+ The results as a list of tuples.
"""
if not iterable:
return []
@@ -1949,7 +1955,7 @@ class DatabasePool:
)
txn.execute(sql, values)
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
async def simple_update(
self,
@@ -2418,7 +2424,7 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
exclude_keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -2447,7 +2453,7 @@ class DatabasePool:
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
- The result as a list of dictionaries.
+ The result as a list of tuples.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -2474,69 +2480,7 @@ class DatabasePool:
)
txn.execute(sql, arg_list + [limit, start])
- return cls.cursor_to_dict(txn)
-
- async def simple_search_list(
- self,
- table: str,
- term: Optional[str],
- col: str,
- retcols: Collection[str],
- desc: str = "simple_search_list",
- ) -> Optional[List[Dict[str, Any]]]:
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table: the table name
- term: term for searching the table matched to a column.
- col: column to query term should be matched to
- retcols: the names of the columns to return
-
- Returns:
- A list of dictionaries or None.
- """
-
- return await self.runInteraction(
- desc,
- self.simple_search_list_txn,
- table,
- term,
- col,
- retcols,
- db_autocommit=True,
- )
-
- @classmethod
- def simple_search_list_txn(
- cls,
- txn: LoggingTransaction,
- table: str,
- term: Optional[str],
- col: str,
- retcols: Iterable[str],
- ) -> Optional[List[Dict[str, Any]]]:
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn: Transaction object
- table: the table name
- term: term for searching the table matched to a column.
- col: column to query term should be matched to
- retcols: the names of the columns to return
-
- Returns:
- None if no term is given, otherwise a list of dictionaries.
- """
- if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
- termvalues = ["%%" + term + "%%"]
- txn.execute(sql, termvalues)
- else:
- return None
-
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
def make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 101403578c..840d725114 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
@@ -142,26 +142,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- async def get_users(self) -> List[JsonDict]:
- """Function to retrieve a list of users in users table.
-
- Returns:
- A list of dictionaries representing users.
- """
- return await self.db_pool.simple_select_list(
- table="users",
- keyvalues={},
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "user_type",
- "deactivated",
- ],
- desc="get_users",
- )
-
async def get_users_paginate(
self,
start: int,
@@ -316,7 +296,11 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- async def search_users(self, term: str) -> Optional[List[JsonDict]]:
+ async def search_users(
+ self, term: str
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
"""Function to search users list for one or more users with
the matched term.
@@ -324,15 +308,37 @@ class DataStore(
term: search term
Returns:
- A list of dictionaries or None.
+ A list of tuples of name, password_hash, is_guest, admin, user_type or None.
"""
- return await self.db_pool.simple_search_list(
- table="users",
- term=term,
- col="name",
- retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
- desc="search_users",
- )
+
+ def search_users(
+ txn: LoggingTransaction,
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
+ search_term = "%%" + term + "%%"
+
+ sql = """
+ SELECT name, password_hash, is_guest, admin, user_type
+ FROM users
+ WHERE name LIKE ?
+ """
+ txn.execute(sql, (search_term,))
+
+ return cast(
+ List[
+ Tuple[
+ str,
+ Optional[str],
+ Union[int, bool],
+ Union[int, bool],
+ Optional[str],
+ ]
+ ],
+ txn.fetchall(),
+ )
+
+ return await self.db_pool.runInteraction("search_users", search_users)
def check_database_before_upgrade(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 80f146dd53..39498d52c6 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -103,6 +103,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_index_update(
+ update_name="room_account_data_index_room_id",
+ index_name="room_account_data_room_id",
+ table="room_account_data",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_account_data_for_deactivated_users",
self._delete_account_data_for_deactivated_users,
@@ -151,10 +158,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in txn
}
return await self.db_pool.runInteraction(
@@ -196,13 +203,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_data = by_room.setdefault(row["room_id"], {})
+ for room_id, account_data_type, content in txn:
+ room_data = by_room.setdefault(room_id, {})
- room_data[row["account_data_type"]] = db_to_json(row["content"])
+ room_data[account_data_type] = db_to_json(content)
return by_room
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 0553a0621a..073a99cd84 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- List,
- Optional,
- Pattern,
- Sequence,
- Tuple,
- cast,
-)
+from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
from synapse.appservice import (
ApplicationService,
@@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
- "SELECT * FROM application_services_txns WHERE as_id=?"
+ "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
- return None
-
- entry = rows[0]
-
- return entry
+ return cast(Optional[Tuple[int, str]], txn.fetchone())
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = db_to_json(entry["event_ids"])
+ txn_id, event_ids_str = entry
+ event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
- id=entry["txn_id"],
+ id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 16170e0436..bf5b8c804b 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+import attr
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 10 * 60 * 1000
-class DeviceLastConnectionInfo(TypedDict):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table
@@ -499,24 +501,29 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
- res = cast(
- List[DeviceLastConnectionInfo],
- await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
- ),
+ res = await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
- return {(d["user_id"], d["device_id"]): d for d in res}
+ return {
+ (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
+ user_id=d["user_id"],
+ device_id=d["device_id"],
+ ip=d["ip"],
+ user_agent=d["user_agent"],
+ last_seen=d["last_seen"],
+ )
+ for d in res
+ }
async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
@@ -683,8 +690,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
@@ -705,13 +711,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
continue
if not device_id or did == device_id:
- ret[(user_id, did)] = {
- "user_id": user_id,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": did,
- "last_seen": last_seen,
- }
+ ret[(user_id, did)] = DeviceLastConnectionInfo(
+ user_id=user_id,
+ ip=ip,
+ user_agent=user_agent,
+ device_id=did,
+ last_seen=last_seen,
+ )
return ret
async def get_user_ip_and_agents(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0be12f0e06..72dc4f54dc 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
- user_device_dicts = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
)
- device_ids_to_query.update(
- {row["device_id"] for row in user_device_dicts}
- )
+ device_ids_to_query.update({row[0] for row in user_device_dicts})
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -449,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id: str,
device_id: Optional[str],
up_to_stream_id: int,
- limit: int,
+ limit: Optional[int] = None,
) -> int:
"""
Args:
@@ -480,11 +481,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
ROW_ID_NAME = self.database_engine.row_id_name
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
+ limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
SELECT {ROW_ID_NAME} FROM device_inbox
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
- LIMIT {limit}
+ {limit_statement}
)
"""
txn.execute(sql, (user_id, device_id, up_to_stream_id))
@@ -849,20 +851,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- column="device_id",
- iterable=devices,
- retcols=("device_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ column="device_id",
+ iterable=devices,
+ retcols=("device_id",),
+ ),
)
- for row in rows:
+ for (device_id,) in rows:
# Only insert into the local inbox if the device exists on
# this server
- device_id = row["device_id"]
-
with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a8206c6afe..a07086149c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1054,16 +1054,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id"),
- desc="get_device_list_last_stream_id_for_remotes",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id"),
+ desc="get_device_list_last_stream_id_for_remotes",
+ ),
)
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
- results.update({row["user_id"]: row["stream_id"] for row in rows})
+ results.update(rows)
return results
@@ -1079,22 +1082,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ row_tuples = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ ),
)
+
+ return {row[0] for row in row_tuples}
else:
- rows = await self.db_pool.simple_select_list(
- table="device_lists_remote_resync",
- keyvalues=None,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
+ rows = cast(
+ List[Dict[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ ),
)
- return {row["user_id"] for row in rows}
+ return {row["user_id"] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
@@ -1415,13 +1426,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@@ -1429,11 +1440,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
devices: Dict[str, List[str]] = {}
- for row in rows:
+ for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
- if self.hs.is_mine_id(row["user_id"]):
- user_devices = devices.setdefault(row["user_id"], [])
- user_devices.append(row["device_id"])
+ if self.hs.is_mine_id(user_id):
+ user_devices = devices.setdefault(user_id, [])
+ user_devices.append(device_id)
return devices
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index d01f28cc80..aac4cfb054 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -53,6 +53,13 @@ class EndToEndRoomKeyBackgroundStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ self.db_pool.updates.register_background_index_update(
+ update_name="e2e_room_keys_index_room_id",
+ index_name="e2e_room_keys_room_id",
+ table="e2e_room_keys",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_e2e_backup_keys_for_deactivated_users",
self._delete_e2e_backup_keys_for_deactivated_users,
@@ -208,7 +215,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- StreamKeyType.ROOM: room_key,
+ StreamKeyType.ROOM.value: room_key,
}
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 89fac23f93..f13d776b0d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A map from (algorithm, key_id) to json string for key
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="e2e_one_time_keys_json",
- column="key_id",
- iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json"),
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="add_e2e_one_time_keys_check",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="add_e2e_one_time_keys_check",
+ ),
)
- result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
@@ -921,14 +924,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}
txn.execute(sql, params)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- user_id = row["user_id"]
- key_type = row["keytype"]
- key = db_to_json(row["keydata"])
+ for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
- user_keys[key_type] = key
+ user_keys[key_type] = db_to_json(key_data)
return result
@@ -988,13 +987,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
- for row in rows:
- key_id: str = row["key_id"]
- target_user_id: str = row["target_user_id"]
- target_device_id: str = row["target_device_id"]
+ for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1007,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
- user_sigs[key_id] = row["signature"]
+ user_sigs[key_id] = signature
else:
- signatures[from_user_id] = {key_id: row["signature"]}
+ signatures[from_user_id] = {key_id: signature}
else:
- target_user_key["signatures"] = {
- from_user_id: {key_id: row["signature"]}
- }
+ target_user_key["signatures"] = {from_user_id: {key_id: signature}}
return keys
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index afffa54985..4f80ce75cc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1049,15 +1049,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_max_depth_of",
),
- desc="get_max_depth_of",
)
if not rows:
@@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
max_depth_event_id = ""
current_max_depth = 0
- for row in rows:
- if row["depth"] > current_max_depth:
- max_depth_event_id = row["event_id"]
- current_max_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth > current_max_depth:
+ max_depth_event_id = event_id
+ current_max_depth = depth
return max_depth_event_id, current_max_depth
@@ -1078,15 +1081,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_min_depth_of",
),
- desc="get_min_depth_of",
)
if not rows:
@@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
min_depth_event_id = ""
current_min_depth = MAX_DEPTH
- for row in rows:
- if row["depth"] < current_min_depth:
- min_depth_event_id = row["event_id"]
- current_min_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth < current_min_depth:
+ min_depth_event_id = event_id
+ current_min_depth = depth
return min_depth_event_id, current_min_depth
@@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A filtered down list of `event_ids` that have previous failed pull attempts.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id",),
- desc="get_event_ids_with_failed_pull_attempts",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id",),
+ desc="get_event_ids_with_failed_pull_attempts",
+ ),
)
- event_ids_with_failed_pull_attempts: Set[str] = {
- row["event_id"] for row in rows
- }
-
- return event_ids_with_failed_pull_attempts
+ return {row[0] for row in rows}
@trace
async def get_event_ids_to_not_pull_from_backoff(
@@ -1585,32 +1590,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
- event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=(
- "event_id",
- "last_attempt_ts",
- "num_attempts",
+ event_failed_pull_attempts = cast(
+ List[Tuple[str, int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=(
+ "event_id",
+ "last_attempt_ts",
+ "num_attempts",
+ ),
+ desc="get_event_ids_to_not_pull_from_backoff",
),
- desc="get_event_ids_to_not_pull_from_backoff",
)
current_time = self._clock.time_msec()
event_ids_with_backoff = {}
- for event_failed_pull_attempt in event_failed_pull_attempts:
- event_id = event_failed_pull_attempt["event_id"]
+ for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = (
- event_failed_pull_attempt["last_attempt_ts"]
+ last_attempt_ts
+ (
2
** min(
- event_failed_pull_attempt["num_attempts"],
+ num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
)
)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 790d058c43..ef6766b5e0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -27,6 +27,7 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
cast,
)
@@ -501,16 +502,19 @@ class PersistEventsStore:
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="rooms",
- column="room_id",
- iterable={event.room_id for event in events if event.is_state()},
- keyvalues={},
- retcols=("room_id", "has_auth_chain_index"),
+ rows = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
+ ),
)
rooms_using_chain_index = {
- row["room_id"] for row in rows if row["has_auth_chain_index"]
+ room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
@@ -571,19 +575,18 @@ class PersistEventsStore:
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="room_id",
- iterable=set(event_to_room_id.values()),
- retcols=("event_id", "type", "state_key"),
+ auth_chain_to_calc_rows = cast(
+ List[Tuple[str, str, str]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable=set(event_to_room_id.values()),
+ retcols=("event_id", "type", "state_key"),
+ ),
)
- for row in rows:
- event_id = row["event_id"]
- event_type = row["type"]
- state_key = row["state_key"]
-
+ for event_id, event_type, state_key in auth_chain_to_calc_rows:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
@@ -753,23 +756,31 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable={chain_id for chain_id, _ in chain_map.values()},
- keyvalues={},
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
+ auth_chain_rows = cast(
+ List[Tuple[int, int, int, int]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
),
)
- for row in rows:
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in auth_chain_rows:
chain_links.add_link(
- (row["origin_chain_id"], row["origin_sequence_number"]),
- (row["target_chain_id"], row["target_sequence_number"]),
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
new=False,
)
@@ -1654,8 +1665,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []
- rows = []
-
ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@@ -1676,10 +1685,9 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- event = ev_map[row["event_id"]]
- if not row["rejects"] and not row["redacts"]:
+ for event_id, redacts, rejects in txn:
+ event = ev_map[event_id]
+ if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None:
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index daef3685b0..c5fce1c82b 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="event_json",
- column="event_id",
- iterable=chunk,
- retcols=["event_id", "json"],
- keyvalues={},
+ ev_rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ ),
)
- for row in ev_rows:
- event_id = row["event_id"]
- event_json = db_to_json(row["json"])
+ for event_id, json in ev_rows:
+ event_json = db_to_json(json)
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="events",
- column="event_id",
- iterable=to_delete,
- keyvalues={},
- retcols=("room_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ retcols=("room_id",),
+ ),
)
- room_ids = {row["room_id"] for row in rows}
+ room_ids = {row[0] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
@@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
count = len(rows)
# We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
+ auth_events = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ ),
)
event_to_auth_chain: Dict[str, List[str]] = {}
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+ for event_id, auth_id in auth_events:
+ event_to_auth_chain.setdefault(event_id, []).append(auth_id)
# Calculate and persist the chain cover index for this set of events.
#
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8737a1370e..89757eabed 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ ),
)
- return {r["event_id"] for r in rows}
+ return {r[0] for r in rows}
@trace
@tag_args
@@ -2340,15 +2343,18 @@ class EventsWorkerStore(SQLBaseStore):
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
- result = await self.db_pool.simple_select_many_batch(
- table="partial_state_events",
- column="event_id",
- iterable=event_ids,
- retcols=["event_id"],
- desc="get_partial_state_events",
+ result = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=["event_id"],
+ desc="get_partial_state_events",
+ ),
)
# convert the result to a dict, to make @cachedList work
- partial = {r["event_id"] for r in result}
+ partial = {r[0] for r in result}
return {e_id: e_id in partial for e_id in event_ids}
@cached()
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 889c578b9c..ea797864b9 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
import itertools
import json
import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="server_keys_json",
- column="key_id",
- iterable=key_ids,
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
- # We sort the rows so that the most recently added entry is picked up.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
async def get_all_server_keys_json_for_remote(
@@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
if not rows:
return {}
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"])
return {
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8cebeb5189..2e6b176bd2 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -28,6 +28,7 @@ from typing import (
from synapse.api.constants import Direction
from synapse.logging.opentracing import trace
+from synapse.media._base import ThumbnailInfo
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -435,8 +436,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
+ async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
+ rows = await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -448,6 +449,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
),
desc="get_local_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row["thumbnail_width"],
+ height=row["thumbnail_height"],
+ method=row["thumbnail_method"],
+ type=row["thumbnail_type"],
+ length=row["thumbnail_length"],
+ )
+ for row in rows
+ ]
@trace
async def store_local_thumbnail(
@@ -556,8 +567,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
- ) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
+ ) -> List[ThumbnailInfo]:
+ rows = await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -566,10 +577,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
- "filesystem_id",
),
desc="get_remote_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row["thumbnail_width"],
+ height=row["thumbnail_height"],
+ method=row["thumbnail_method"],
+ type=row["thumbnail_type"],
+ length=row["thumbnail_length"],
+ )
+ for row in rows
+ ]
@trace
async def get_remote_media_thumbnail(
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 194b4e031f..3b444d2d07 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -20,6 +20,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Union,
cast,
)
@@ -260,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Mapping[str, UserPresenceState]:
- rows = await self.db_pool.simple_select_many_batch(
- table="presence_stream",
- column="user_id",
- iterable=user_ids,
- keyvalues={},
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ desc="get_presence_for_users",
),
- desc="get_presence_for_users",
)
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ return {
+ user_id: UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ }
async def should_user_receive_full_presence_with_token(
self,
@@ -385,28 +399,49 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100
offset = 0
while True:
- rows = await self.db_pool.runInteraction(
- "get_presence_for_all_users",
- self.db_pool.simple_select_list_paginate_txn,
- "presence_stream",
- orderby="stream_id",
- start=offset,
- limit=limit,
- exclude_keyvalues=exclude_keyvalues,
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.runInteraction(
+ "get_presence_for_all_users",
+ self.db_pool.simple_select_list_paginate_txn,
+ "presence_stream",
+ orderby="stream_id",
+ start=offset,
+ limit=limit,
+ exclude_keyvalues=exclude_keyvalues,
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ order_direction="ASC",
),
- order_direction="ASC",
)
- for row in rows:
- users_to_state[row["user_id"]] = UserPresenceState(**row)
+ for (
+ user_id,
+ state,
+ last_active_ts,
+ last_federation_update_ts,
+ last_user_sync_ts,
+ status_msg,
+ currently_active,
+ ) in rows:
+ users_to_state[user_id] = UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
# We've run out of updates to query
if len(rows) < limit:
@@ -434,13 +469,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
txn.close()
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
+ return [
+ UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ ]
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index dea0e0458c..1e11bf2706 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -89,6 +89,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# furthermore, we might already have the table from a previous (failed)
# purge attempt, so let's drop the table first.
+ if isinstance(self.database_engine, PostgresEngine):
+ # Disable statement timeouts for this transaction; purging rooms can
+ # take a while!
+ txn.execute("SET LOCAL statement_timeout = 0")
+
txn.execute("DROP TABLE IF EXISTS events_to_purge")
txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 923166974c..f5356e7f80 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
+ rawrules: List[Tuple[str, int, str, str]],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
+
+ Args:
+ rawrules: List of tuples of:
+ * rule ID
+ * Priority lass
+ * Conditions (as serialized JSON)
+ * Actions (as serialized JSON)
+ enabled_map: A dictionary of rule ID to a boolean of whether the rule is
+ enabled. This might not include all rule IDs from rawrules.
+ experimental_config: The `experimental_features` section of the Synapse
+ config. (Used to check if various features are enabled.)
+
+ Returns:
+ A new FilteredPushRules object.
"""
ruleslist = [
PushRule.from_db(
- rule_id=rawrule["rule_id"],
- priority_class=rawrule["priority_class"],
- conditions=rawrule["conditions"],
- actions=rawrule["actions"],
+ rule_id=rawrule[0],
+ priority_class=rawrule[1],
+ conditions=rawrule[2],
+ actions=rawrule[3],
)
for rawrule in rawrules
]
@@ -183,7 +197,19 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(
+ [
+ (
+ row["rule_id"],
+ row["priority_class"],
+ row["conditions"],
+ row["actions"],
+ )
+ for row in rows
+ ],
+ enabled_map,
+ self.hs.config.experimental,
+ )
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -221,21 +247,36 @@ class PushRulesWorkerStore(
if not user_ids:
return {}
- raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
+ user_id: [] for user_id in user_ids
+ }
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules",
- column="user_name",
- iterable=user_ids,
- retcols=("*",),
- desc="bulk_get_push_rules",
- batch_size=1000,
+ rows = cast(
+ List[Tuple[str, str, int, int, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=(
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="bulk_get_push_rules",
+ batch_size=1000,
+ ),
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
- for row in rows:
- raw_rules.setdefault(row["user_name"], []).append(row)
+ for user_name, rule_id, priority_class, _, conditions, actions in rows:
+ raw_rules.setdefault(user_name, []).append(
+ (rule_id, priority_class, conditions, actions)
+ )
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
@@ -256,17 +297,19 @@ class PushRulesWorkerStore(
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules_enable",
- column="user_name",
- iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled"),
- desc="bulk_get_push_rules_enabled",
- batch_size=1000,
+ rows = cast(
+ List[Tuple[str, str, Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
+ ),
)
- for row in rows:
- enabled = bool(row["enabled"])
- results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+ for user_name, rule_id, enabled in rows:
+ results.setdefault(user_name, {})[rule_id] = bool(enabled)
return results
async def get_all_push_rule_updates(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..c7eb7fc478 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+ int, # id
+ str, # user_name
+ Optional[int], # access_token
+ str, # profile_tag
+ str, # kind
+ str, # app_id
+ str, # app_display_name
+ str, # device_display_name
+ str, # pushkey
+ int, # ts
+ str, # lang
+ str, # data
+ int, # last_stream_ordering
+ int, # last_success
+ int, # failing_since
+ bool, # enabled
+ str, # device_id
+]
+
class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+ def _decode_pushers_rows(
+ self,
+ rows: Iterable[PusherRow],
+ ) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
- for r in rows:
- data_json = r["data"]
+ for (
+ id,
+ user_name,
+ access_token,
+ profile_tag,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ ts,
+ lang,
+ data,
+ last_stream_ordering,
+ last_success,
+ failing_since,
+ enabled,
+ device_id,
+ ) in rows:
try:
- r["data"] = db_to_json(data_json)
+ data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r["id"],
- data_json,
+ id,
+ data,
e.args[0],
)
continue
- # If we're using SQLite, then boolean values are integers. This is
- # troublesome since some code using the return value of this method might
- # expect it to be a boolean, or will expose it to clients (in responses).
- r["enabled"] = bool(r["enabled"])
-
- yield PusherConfig(**r)
+ yield PusherConfig(
+ id=id,
+ user_name=user_name,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=ts,
+ lang=lang,
+ data=data_json,
+ last_stream_ordering=last_stream_ordering,
+ last_success=last_success,
+ failing_since=failing_since,
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ enabled=bool(enabled),
+ device_id=device_id,
+ access_token=access_token,
+ )
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
- def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
- def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
- rows = self.db_pool.cursor_to_dict(txn)
-
- return self._decode_pushers_rows(rows)
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+ txn.execute(
+ """
+ SELECT id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ enabled, device_id
+ FROM pushers WHERE COALESCE(enabled, TRUE)
+ """
+ )
+ return cast(List[PusherRow], txn.fetchall())
- return await self.db_pool.runInteraction(
- "get_enabled_pushers", get_enabled_pushers_txn
+ return self._decode_pushers_rows(
+ await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
)
async def get_all_updated_pushers_rows(
@@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
- self, pusher_id: str
+ self, pusher_id: int
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
@@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
async def set_throttle_params(
- self, pusher_id: str, room_id: str, params: ThrottleParams
+ self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if len(rows) == 0:
return 0
@@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
- key_values=[(row["pusher_id"],) for row in rows],
+ key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
- (row["pusher_device_id"] or row["token_device_id"], None)
- for row in rows
+ (pusher_device_id or token_device_id, None)
+ for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
- txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..b2645ab43c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
sql = (
- "SELECT * FROM receipts_linearized WHERE"
+ "SELECT receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
- "SELECT * FROM receipts_linearized WHERE"
+ "SELECT receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, to_key))
- rows = self.db_pool.cursor_to_dict(txn)
-
- return rows
+ return cast(List[Tuple[str, str, str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
- for row in rows:
- content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
- row["user_id"]
- ] = db_to_json(row["data"])
+ for receipt_type, user_id, event_id, data in rows:
+ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+ user_id
+ ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not room_ids:
return {}
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
@@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
+ )
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
- if row["thread_id"]:
- receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+ receipt_type_dict[user_id] = db_to_json(data)
+ if thread_id:
+ receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [from_key, to_key])
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [to_key])
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
+ receipt_type_dict[user_id] = db_to_json(data)
return results
@@ -742,7 +750,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[Tuple[int, int]]:
+ ) -> Optional[int]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -804,9 +812,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
+ return stream_id
async def _insert_graph_receipt(
self,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..9e8643ae4d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -143,6 +143,14 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidResult:
+ medium: str
+ address: str
+ validated_at: int
+ added_at: int
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -195,7 +203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
- def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +221,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
-
- if len(rows) == 0:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ (
+ name,
+ is_guest,
+ admin,
+ consent_version,
+ consent_ts,
+ consent_server_notice_sent,
+ appservice_id,
+ creation_ts,
+ user_type,
+ deactivated,
+ shadow_banned,
+ approved,
+ locked,
+ ) = row
+
+ return UserInfo(
+ appservice_id=appservice_id,
+ consent_server_notice_sent=consent_server_notice_sent,
+ consent_version=consent_version,
+ consent_ts=consent_ts,
+ creation_ts=creation_ts,
+ is_admin=bool(admin),
+ is_deactivated=bool(deactivated),
+ is_guest=bool(is_guest),
+ is_shadow_banned=bool(shadow_banned),
+ user_id=UserID.from_string(name),
+ user_type=user_type,
+ approved=bool(approved),
+ locked=bool(locked),
+ )
- row = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
- if row is None:
- return None
-
- return UserInfo(
- appservice_id=row["appservice_id"],
- consent_server_notice_sent=row["consent_server_notice_sent"],
- consent_version=row["consent_version"],
- consent_ts=row["consent_ts"],
- creation_ts=row["creation_ts"],
- is_admin=bool(row["admin"]),
- is_deactivated=bool(row["deactivated"]),
- is_guest=bool(row["is_guest"]),
- is_shadow_banned=bool(row["shadow_banned"]),
- user_id=UserID.from_string(row["name"]),
- user_type=row["user_type"],
- approved=bool(row["approved"]),
- locked=bool(row["locked"]),
- )
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +598,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
- rows = self.db_pool.cursor_to_dict(txn)
-
- if rows:
- row = rows[0]
+ row = txn.fetchone()
- # This field is nullable, ensure it comes out as a boolean
- if row["token_used"] is None:
- row["token_used"] = False
-
- return TokenLookupResult(**row)
+ if row:
+ (
+ user_id,
+ is_guest,
+ shadow_banned,
+ token_id,
+ device_id,
+ valid_until_ms,
+ token_owner,
+ token_used,
+ ) = row
+
+ return TokenLookupResult(
+ user_id=user_id,
+ is_guest=is_guest,
+ shadow_banned=shadow_banned,
+ token_id=token_id,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ token_owner=token_owner,
+ # This field is nullable, ensure it comes out as a boolean
+ token_used=bool(token_used),
+ )
return None
@@ -833,11 +867,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -891,11 +924,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@@ -964,13 +996,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
+ async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
+ results = await self.db_pool.simple_select_list(
"user_threepids",
- {"user_id": user_id},
- ["medium", "address", "validated_at", "added_at"],
- "user_get_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
)
+ return [ThreepidResult(**r) for r in results]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@@ -1252,12 +1285,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
+ for (name,) in txn.fetchall():
+ self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1963,11 +1992,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ row = txn.fetchone()
+ assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
- return bool(rows[0]["approved"])
+ return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2075,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
- for user in rows:
- if not user["count_tokens"] and not user["count_threepids"]:
- self.set_user_deactivated_status_txn(txn, user["name"], True)
+ for name, count_tokens, count_threepids in rows:
+ if not count_tokens and not count_threepids:
+ self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
- txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 9246b418f5..7f40e2c446 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_many_txn(
- txn=txn,
- table="event_relations",
- column="relation_type",
- iterable=relation_types,
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn=txn,
+ table="event_relations",
+ column="relation_type",
+ iterable=relation_types,
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 719e11aea6..9d24d2c347 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Optional[int]]]:
+ ) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)
- min_lifetime = ret[0]["min_lifetime"]
- max_lifetime = ret[0]["max_lifetime"]
+ min_lifetime, max_lifetime = ret
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
@@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args)
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = RetentionPolicy(
- min_lifetime=row["min_lifetime"],
- max_lifetime=row["max_lifetime"],
+ rooms_dict = {
+ room_id: RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
)
+ for room_id, min_lifetime, max_lifetime in txn
+ }
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql)
- rows = self.db_pool.cursor_to_dict(txn)
-
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = RetentionPolicy()
+ for (room_id,) in txn:
+ if room_id not in rooms_dict:
+ rooms_dict[room_id] = RetentionPolicy()
return rooms_dict
@@ -1300,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
complete.
"""
- rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
- table="partial_state_rooms",
- column="room_id",
- iterable=room_ids,
- retcols=("room_id",),
- desc="is_partial_state_room_batched",
- )
- partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_rooms",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id",),
+ desc="is_partial_state_room_batched",
+ ),
+ )
+ partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
@@ -1703,24 +1702,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True
- for row in rows:
- if not row["json"]:
+ for room_id, event_id, json in rows:
+ if not json:
retention_policy = {}
else:
- ev = db_to_json(row["json"])
+ ev = db_to_json(json)
retention_policy = ev["content"]
self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
+ "room_id": room_id,
+ "event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@@ -1729,7 +1728,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn(
- txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ txn, "insert_room_retention", {"room_id": rows[-1][0]}
)
if batch_size > len(rows):
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e93573f315..3a87eba430 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="current_state_events",
- column="state_key",
- iterable=user_ids,
- retcols=(
- "state_key",
- "room_id",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
),
- keyvalues={
- "type": EventTypes.Member,
- "membership": Membership.JOIN,
- },
- desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
- for row in rows:
- user_rooms[row["state_key"]].add(row["room_id"])
+ for state_key, room_id in rows:
+ user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=event_ids,
- retcols=("user_id", "event_id"),
- keyvalues={"membership": Membership.JOIN},
- batch_size=1000,
- desc="_get_user_ids_from_membership_event_ids",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id", "user_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=1000,
+ desc="_get_user_ids_from_membership_event_ids",
+ ),
)
- return {row["event_id"]: row["user_id"] for row in rows}
+ return dict(rows)
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids,
- retcols=("user_id", "membership", "event_id"),
- keyvalues={},
- batch_size=500,
- desc="get_membership_from_event_ids",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ ),
)
return {
- row["event_id"]: EventIdMembership(
- membership=row["membership"], user_id=row["user_id"]
- )
- for row in rows
+ event_id: EventIdMembership(membership=membership, user_id=user_id)
+ for user_id, membership, event_id in rows
}
async def is_local_host_in_room_ignoring_users(
@@ -1349,18 +1357,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
+ for _, event_id, room_id, json in rows:
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index a7aae661d8..1d69c4a5f0 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
event_search_rows = []
- for row in rows:
+ for (
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ json,
+ origin_server_ts,
+ ) in rows:
try:
- event_id = row["event_id"]
- room_id = row["room_id"]
- etype = row["type"]
- stream_ordering = row["stream_ordering"]
- origin_server_ts = row["origin_server_ts"]
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5eaaff5b68..598025dd91 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -20,10 +20,12 @@ from typing import (
Collection,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
Tuple,
+ cast,
)
import attr
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
RuntimeError if the state is unknown at any of the given events
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id", "state_group"),
- desc="_get_state_group_for_events",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group"),
+ desc="_get_state_group_for_events",
+ ),
)
- res = {row["event_id"]: row["state_group"] for row in rows}
+ res = dict(rows)
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("DISTINCT state_group",),
- desc="get_referenced_state_groups",
+ rows = cast(
+ List[Tuple[int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ ),
)
- return {row["state_group"] for row in rows}
+ return {row[0] for row in rows}
async def update_state_for_partial_state_event(
self,
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="room_id",
- iterable=to_delete,
- keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
- retcols=("state_key",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=to_delete,
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ retcols=("state_key",),
+ ),
)
- potentially_left_users = {row["state_key"] for row in rows}
+ potentially_left_users = {row[0] for row in rows}
# Now lets actually delete the rooms from the DB.
self.db_pool.simple_delete_many_txn(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+ stream_id: int
+ room_id: str
+ event_type: str
+ state_key: str
+
+ event_id: Optional[str]
+ """new event_id for this state key. None if the state has been deleted."""
+
+ prev_event_id: Optional[str]
+ """previous event_id for this state key. None if it's new state."""
+
+
class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
This may be the partial state if we're lazy joining the room.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- - ie, an upper limit to return changes from.
+ - ie, an upper limit to return changes from.
Returns:
A tuple consisting of:
- - the stream id which these results go up to
- - list of current_state_delta_stream rows. If it is empty, we are
- up to date.
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+ return clipped_stream_id, [
+ StateDelta(
+ stream_id=row[0],
+ room_id=row[1],
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
+ )
+ for row in txn.fetchall()
+ ]
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9d403919e4..5b2d0ba870 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore):
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="type",
- iterable=[
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.RoomHistoryVisibility,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.RoomAvatar,
- EventTypes.CanonicalAlias,
- ],
- keyvalues={"room_id": room_id, "state_key": ""},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="type",
+ iterable=[
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.RoomAvatar,
+ EventTypes.CanonicalAlias,
+ ],
+ keyvalues={"room_id": room_id, "state_key": ""},
+ retcols=["event_id"],
+ ),
)
- event_ids = cast(List[str], [row["event_id"] for row in rows])
+ event_ids = [row[0] for row in rows]
txn.execute(
"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 5a3611c415..ea06e4eee0 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
# when we are going backwards so we subtract one from the
# stream part.
last_stream_ordering -= 1
- return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+ return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, immutabledict(positions))
+ return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+ key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return RoomStreamToken(topo, stream_ordering)
+ return RoomStreamToken(topological=topo, stream=stream_ordering)
@overload
def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(
+ topological=row["topological_ordering"], stream=row["stream_ordering"]
+ )
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = RoomStreamToken(topo, stream - 1)
- internal.after = RoomStreamToken(topo, stream)
+ internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+ internal.after = RoomStreamToken(topological=topo, stream=stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- results["topological_ordering"] - 1, results["stream_ordering"]
+ topological=results["topological_ordering"] - 1,
+ stream=results["stream_ordering"],
)
after_token = RoomStreamToken(
- results["topological_ordering"], results["stream_ordering"]
+ topological=results["topological_ordering"],
+ stream=results["stream_ordering"],
)
rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5c5372a825..5555b53575 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
+ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
+
class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
@staticmethod
- def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
- row["status"] = TaskStatus(row["status"])
- if row["params"] is not None:
- row["params"] = db_to_json(row["params"])
- if row["result"] is not None:
- row["result"] = db_to_json(row["result"])
- return ScheduledTask(**row)
+ def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
+ task_id, action, status, timestamp, resource_id, params, result, error = row
+ return ScheduledTask(
+ id=task_id,
+ action=action,
+ status=TaskStatus(status),
+ timestamp=timestamp,
+ resource_id=resource_id,
+ params=db_to_json(params) if params is not None else None,
+ result=db_to_json(result) if result is not None else None,
+ error=error,
+ )
async def get_scheduled_tasks(
self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
- def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)
- return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+ return (
+ TaskSchedulerWorkerStore._convert_row_to_task(
+ (
+ row["id"],
+ row["action"],
+ row["status"],
+ row["timestamp"],
+ row["resource_id"],
+ row["params"],
+ row["result"],
+ row["error"],
+ )
+ )
+ if row
+ else None
+ )
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 8f70eff809..c4a6475060 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
) -> Mapping[str, Optional[DestinationRetryTimings]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="destinations",
- iterable=destinations,
- column="destination",
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
- desc="get_destination_retry_timings_batch",
+ rows = cast(
+ List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="destinations",
+ iterable=destinations,
+ column="destination",
+ retcols=(
+ "destination",
+ "failure_ts",
+ "retry_last_ts",
+ "retry_interval",
+ ),
+ desc="get_destination_retry_timings_batch",
+ ),
)
return {
- row.pop("destination"): DestinationRetryTimings(**row)
- for row in rows
- if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
+ destination: DestinationRetryTimings(
+ failure_ts, retry_last_ts, retry_interval
+ )
+ for destination, failure_ts, retry_last_ts, retry_interval in rows
+ if retry_last_ts and failure_ts and retry_interval
}
async def set_destination_retry_timings(
@@ -526,7 +536,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
start: int,
limit: int,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
total number of rooms.
@@ -537,12 +547,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: number of rows to retrieve
direction: sort ascending or descending by room_id
Returns:
- A tuple of a dict of rooms and a count of total rooms.
+ A tuple of a list of room tuples and a count of total rooms.
+
+ Each room tuple is room_id, stream_ordering.
"""
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
if direction == Direction.BACKWARDS:
order = "DESC"
else:
@@ -556,14 +568,17 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, [destination])
count = cast(Tuple[int], txn.fetchone())[0]
- rooms = self.db_pool.simple_select_list_paginate_txn(
- txn=txn,
- table="destination_rooms",
- orderby="room_id",
- start=start,
- limit=limit,
- retcols=("room_id", "stream_ordering"),
- order_direction=order,
+ rooms = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ ),
)
return rooms, count
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index f38bedbbcd..919c66f553 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter
# before deleting the session.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="ui_auth_sessions_credentials",
- column="session_id",
- iterable=session_ids,
- keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
- retcols=["result"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ ),
)
# Get the tokens used and how much pending needs to be decremented by.
@@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
- token = db_to_json(r["result"])
+ token = db_to_json(r[0])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
- token_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="registration_tokens",
- column="token",
- iterable=list(token_counts.keys()),
- keyvalues={},
- retcols=["token", "pending"],
+ token_rows = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ ),
)
- for token_row in token_rows:
- token = token_row["token"]
- new_pending = token_row["pending"] - token_counts[token]
+ for token, pending in token_rows:
+ new_pending = pending - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f0dc31fee6..23eb92c514 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -410,25 +410,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
# Next fetch their profiles. Note that not all users have profiles.
- profile_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="profiles",
- column="full_user_id",
- iterable=list(users_to_insert),
- retcols=(
- "full_user_id",
- "displayname",
- "avatar_url",
+ profile_rows = cast(
+ List[Tuple[str, Optional[str], Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="profiles",
+ column="full_user_id",
+ iterable=list(users_to_insert),
+ retcols=(
+ "full_user_id",
+ "displayname",
+ "avatar_url",
+ ),
+ keyvalues={},
),
- keyvalues={},
)
profiles = {
- row["full_user_id"]: _UserDirProfile(
- row["full_user_id"],
- row["displayname"],
- row["avatar_url"],
- )
- for row in profile_rows
+ full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
+ for full_user_id, displayname, avatar_url in profile_rows
}
profiles_to_insert = [
@@ -517,18 +516,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="users",
- column="name",
- iterable=users,
- keyvalues={
- "deactivated": 0,
- },
- retcols=("name", "user_type"),
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="users",
+ column="name",
+ iterable=users,
+ keyvalues={
+ "deactivated": 0,
+ },
+ retcols=("name", "user_type"),
+ ),
)
- return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
+ return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 06fcbe5e54..8bd58c6e3d 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Mapping
+from typing import Iterable, List, Mapping, Tuple, cast
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
Returns:
for each user, whether the user has requested erasure.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="erased_users",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="are_users_erased",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ ),
)
- erased_users = {row["user_id"] for row in rows}
+ erased_users = {row[0] for row in rows}
return {u: u in erased_users for u in user_ids}
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 6984d11352..09d2a8c5b3 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,7 +13,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
@@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- retcols=("state_group",),
+ rows = cast(
+ List[Tuple[int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ ),
)
remaining_state_groups = {
- row["state_group"]
- for row in rows
- if row["state_group"] not in state_groups_to_delete
+ state_group
+ for state_group, in rows
+ if state_group not in state_groups_to_delete
}
logger.info(
@@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- desc="get_previous_state_groups",
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_previous_state_groups",
+ ),
)
- return {row["state_group"]: row["prev_state_group"] for row in rows}
+ return dict(rows)
async def purge_room_state(
self, room_id: str, state_groups_to_delete: Collection[int]
diff --git a/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql b/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql
new file mode 100644
index 0000000000..fc948166e6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/82/04_add_indices_for_purging_rooms.sql
@@ -0,0 +1,20 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (8204, 'e2e_room_keys_index_room_id', '{}');
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (8204, 'room_account_data_index_room_id', '{}');
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index d7084d2358..609a0978a9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Iterator, Tuple
+from typing import TYPE_CHECKING, Sequence, Tuple
import attr
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import StreamToken
+from synapse.types import StreamKeyType, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -37,9 +37,14 @@ class _EventSourcesInner:
receipt: ReceiptEventSource
account_data: AccountDataEventSource
- def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
- for attribute in attr.fields(_EventSourcesInner):
- yield attribute.name, getattr(self, attribute.name)
+ def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
+ return [
+ (StreamKeyType.ROOM, self.room),
+ (StreamKeyType.PRESENCE, self.presence),
+ (StreamKeyType.TYPING, self.typing),
+ (StreamKeyType.RECEIPT, self.receipt),
+ (StreamKeyType.ACCOUNT_DATA, self.account_data),
+ ]
class EventSources:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 76b0e3e694..09a88c86a7 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -22,8 +22,8 @@ from typing import (
Any,
ClassVar,
Dict,
- Final,
List,
+ Literal,
Mapping,
Match,
MutableMapping,
@@ -34,6 +34,7 @@ from typing import (
Type,
TypeVar,
Union,
+ overload,
)
import attr
@@ -60,6 +61,8 @@ from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
+ from typing_extensions import Self
+
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
@@ -436,7 +439,78 @@ def map_username_to_mxid_localpart(
@attr.s(frozen=True, slots=True, order=False)
-class RoomStreamToken:
+class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
+ """An abstract stream token class for streams that supports multiple
+ writers.
+
+ This works by keeping track of the stream position of each writer,
+ represented by a default `stream` attribute and a map of instance name to
+ stream position of any writers that are ahead of the default stream
+ position.
+ """
+
+ stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
+
+ instance_map: "immutabledict[str, int]" = attr.ib(
+ factory=immutabledict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(immutabledict),
+ ),
+ kw_only=True,
+ )
+
+ @classmethod
+ @abc.abstractmethod
+ async def parse(cls, store: "DataStore", string: str) -> "Self":
+ """Parse the string representation of the token."""
+ ...
+
+ @abc.abstractmethod
+ async def to_string(self, store: "DataStore") -> str:
+ """Serialize the token into its string representation."""
+ ...
+
+ def copy_and_advance(self, other: "Self") -> "Self":
+ """Return a new token such that if an event is after both this token and
+ the other token, then its after the returned token too.
+ """
+
+ max_stream = max(self.stream, other.stream)
+
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return attr.evolve(
+ self, stream=max_stream, instance_map=immutabledict(instance_map)
+ )
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+ The corresponding "min" position is, by definition just `self.stream`.
+
+ This is used to handle tokens that have non-empty `instance_map`, and so
+ reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
+
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+ """Get the stream position that the given writer was at at this token."""
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+
+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken(AbstractMultiWriterStreamToken):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
@@ -513,16 +587,8 @@ class RoomStreamToken:
topological: Optional[int] = attr.ib(
validator=attr.validators.optional(attr.validators.instance_of(int)),
- )
- stream: int = attr.ib(validator=attr.validators.instance_of(int))
-
- instance_map: "immutabledict[str, int]" = attr.ib(
- factory=immutabledict,
- validator=attr.validators.deep_mapping(
- key_validator=attr.validators.instance_of(str),
- value_validator=attr.validators.instance_of(int),
- mapping_validator=attr.validators.instance_of(immutabledict),
- ),
+ kw_only=True,
+ default=None,
)
def __attrs_post_init__(self) -> None:
@@ -582,17 +648,7 @@ class RoomStreamToken:
if self.topological or other.topological:
raise Exception("Can't advance topological tokens")
- max_stream = max(self.stream, other.stream)
-
- instance_map = {
- instance: max(
- self.instance_map.get(instance, self.stream),
- other.instance_map.get(instance, other.stream),
- )
- for instance in set(self.instance_map).union(other.instance_map)
- }
-
- return RoomStreamToken(None, max_stream, immutabledict(instance_map))
+ return super().copy_and_advance(other)
def as_historical_tuple(self) -> Tuple[int, int]:
"""Returns a tuple of `(topological, stream)` for historical tokens.
@@ -618,16 +674,6 @@ class RoomStreamToken:
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)
- def get_max_stream_pos(self) -> int:
- """Get the maximum stream position referenced in this token.
-
- The corresponding "min" position is, by definition just `self.stream`.
-
- This is used to handle tokens that have non-empty `instance_map`, and so
- reference stream positions after the `self.stream` position.
- """
- return max(self.instance_map.values(), default=self.stream)
-
async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
@@ -649,20 +695,20 @@ class RoomStreamToken:
return "s%d" % (self.stream,)
-class StreamKeyType:
+class StreamKeyType(Enum):
"""Known stream types.
A stream is a list of entities ordered by an incrementing "stream token".
"""
- ROOM: Final = "room_key"
- PRESENCE: Final = "presence_key"
- TYPING: Final = "typing_key"
- RECEIPT: Final = "receipt_key"
- ACCOUNT_DATA: Final = "account_data_key"
- PUSH_RULES: Final = "push_rules_key"
- TO_DEVICE: Final = "to_device_key"
- DEVICE_LIST: Final = "device_list_key"
+ ROOM = "room_key"
+ PRESENCE = "presence_key"
+ TYPING = "typing_key"
+ RECEIPT = "receipt_key"
+ ACCOUNT_DATA = "account_data_key"
+ PUSH_RULES = "push_rules_key"
+ TO_DEVICE = "to_device_key"
+ DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
@@ -784,7 +830,7 @@ class StreamToken:
def room_stream_id(self) -> int:
return self.room_key.stream
- def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
+ def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
@@ -797,35 +843,68 @@ class StreamToken:
return new_token
new_token = self.copy_and_replace(key, new_value)
- new_id = int(getattr(new_token, key))
- old_id = int(getattr(self, key))
+ new_id = new_token.get_field(key)
+ old_id = self.get_field(key)
if old_id < new_id:
return new_token
else:
return self
- def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
- return attr.evolve(self, **{key: new_value})
+ def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
+ return attr.evolve(self, **{key.value: new_value})
+ @overload
+ def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
+ ...
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+ @overload
+ def get_field(
+ self,
+ key: Literal[
+ StreamKeyType.ACCOUNT_DATA,
+ StreamKeyType.DEVICE_LIST,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.PUSH_RULES,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.TYPING,
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+ ],
+ ) -> int:
+ ...
+ @overload
+ def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ ...
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class PersistedEventPosition:
- """Position of a newly persisted event with instance that persisted it.
+ def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ """Returns the stream ID for the given key."""
+ return getattr(self, key.value)
- This can be used to test whether the event is persisted before or after a
- RoomStreamToken.
- """
+
+StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedPosition:
+ """Position of a newly persisted row with instance that persisted it."""
instance_name: str
stream: int
- def persisted_after(self, token: RoomStreamToken) -> bool:
+ def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition(PersistedPosition):
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
def to_room_stream_token(self) -> RoomStreamToken:
"""Converts the position to a room stream token such that events
persisted in the same room after this position will be after the
@@ -836,7 +915,7 @@ class PersistedEventPosition:
"""
# Doing the naive thing satisfies the desired properties described in
# the docstring.
- return RoomStreamToken(None, self.stream)
+ return RoomStreamToken(stream=self.stream)
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 0e1f907667..547202c96b 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -170,10 +170,10 @@ class RetryDestinationLimiter:
database in milliseconds, or zero if the last request was
successful.
backoff_on_404: Back off if we get a 404
-
backoff_on_failure: set to False if we should not increase the
retry interval on a failure.
-
+ notifier: A notifier used to mark servers as up.
+ replication_client A replication client used to mark servers as up.
backoff_on_all_error_codes: Whether we should back off on any
error code.
"""
@@ -237,6 +237,9 @@ class RetryDestinationLimiter:
else:
valid_err_code = False
+ # Whether previous requests to the destination had been failing.
+ previously_failing = bool(self.failure_ts)
+
if success:
# We connected successfully.
if not self.retry_interval:
@@ -282,6 +285,9 @@ class RetryDestinationLimiter:
if self.failure_ts is None:
self.failure_ts = retry_last_ts
+ # Whether the current request to the destination had been failing.
+ currently_failing = bool(self.failure_ts)
+
async def store_retry_timings() -> None:
try:
await self.store.set_destination_retry_timings(
@@ -291,17 +297,15 @@ class RetryDestinationLimiter:
self.retry_interval,
)
- if self.notifier:
- # Inform the relevant places that the remote server is back up.
- self.notifier.notify_remote_server_up(self.destination)
-
- if self.replication_client:
- # If we're on a worker we try and inform master about this. The
- # replication client doesn't hook into the notifier to avoid
- # infinite loops where we send a `REMOTE_SERVER_UP` command to
- # master, which then echoes it back to us which in turn pokes
- # the notifier.
- self.replication_client.send_remote_server_up(self.destination)
+ # If the server was previously failing, but is no longer.
+ if previously_failing and not currently_failing:
+ if self.notifier:
+ # Inform the relevant places that the remote server is back up.
+ self.notifier.notify_remote_server_up(self.destination)
+
+ if self.replication_client:
+ # Inform other workers that the remote server is up.
+ self.replication_client.send_remote_server_up(self.destination)
except Exception:
logger.exception("Failed to store destination_retry_timings")
|