diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index baec35ee27..4a883dc166 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state}, ["as_id"]
+ "application_services_state", {"state": state.value}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- return result.get("state")
+ return ApplicationServiceState(result.get("state"))
return None
async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
- "application_services_state", {"as_id": service.id}, {"state": state}
+ "application_services_state", {"as_id": service.id}, {"state": state.value}
)
async def create_appservice_txn(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..838a2a6a3d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
+ async def get_devices_by_auth_provider_session_id(
+ self, auth_provider_id: str, auth_provider_session_id: str
+ ) -> List[Dict[str, Any]]:
+ """Retrieve the list of devices associated with a SSO IdP session ID.
+
+ Args:
+ auth_provider_id: The SSO IdP ID as defined in the server config
+ auth_provider_session_id: The session ID within the IdP
+ Returns:
+ A list of dicts containing the device_id and the user_id of each device
+ """
+ return await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ )
+
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@@ -253,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore):
# add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("m.signing_key_update", result))
+ # also send the unstable version
+ # FIXME: remove this when enough servers have upgraded
results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results
@@ -1070,7 +1093,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+ self,
+ user_id: str,
+ device_id: str,
+ initial_device_display_name: Optional[str],
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -1079,6 +1107,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1145,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+ if auth_provider_id and auth_provider_session_id:
+ await self.db_pool.simple_insert(
+ "device_auth_providers",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="store_device_auth_provider",
+ )
+
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@@ -1168,6 +1210,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_auth_providers",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef01e..9580a40785 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
- LEFT JOIN state_events USING (room_id, event_id)
+ LEFT JOIN state_events AS se USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
- AND state_key IS null
+ AND se.state_key IS null
)
"""
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
+from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
+class BasePushAction(TypedDict):
+ event_id: str
+ actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+ room_id: str
+ stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+ received_ts: Optional[int]
+
+
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[HttpPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[EmailPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index c3440de2cb..4e528612ea 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -124,10 +124,12 @@ class PersistEventsStore:
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
+ *,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
- backfilled: bool = False,
+ use_negative_stream_ordering: bool = False,
+ inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -140,7 +142,14 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
- backfilled
+ use_negative_stream_ordering: Whether to start stream_ordering on
+ the negative side and decrement. This should be set as True
+ for backfilled events because backfilled events get a negative
+ stream ordering so they don't come down incremental `/sync`.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
Returns:
Resolves when the events have been persisted
@@ -162,7 +171,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
- if backfilled:
+ if use_negative_stream_ordering:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -179,13 +188,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
- if not backfilled:
+ if stream < 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@@ -319,8 +328,9 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
+ *,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@@ -333,7 +343,10 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
- backfilled: True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@@ -366,9 +379,7 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(
- txn, events_and_contexts=events_and_contexts, backfilled=backfilled
- )
+ self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -401,7 +412,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# We call this last as it assumes we've inserted the events into
@@ -564,9 +575,9 @@ class PersistEventsStore:
# fetch their auth event info.
while missing_auth_chains:
sql = """
- SELECT event_id, events.type, state_key, chain_id, sequence_number
+ SELECT event_id, events.type, se.state_key, chain_id, sequence_number
FROM events
- INNER JOIN state_events USING (event_id)
+ INNER JOIN state_events AS se USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
@@ -1203,7 +1214,6 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
):
"""Update min_depth for each room
@@ -1211,13 +1221,18 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
- backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
- if not backfilled:
+ # Then update the `stream_ordering` position to mark the latest
+ # event as the front of the room. This should not be done for
+ # backfilled events because backfilled events have negative
+ # stream_ordering and happened in the past so we know that we don't
+ # need to update the stream_ordering tip/front for the room.
+ assert event.internal_metadata.stream_ordering is not None
+ if event.internal_metadata.stream_ordering >= 0:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@@ -1430,7 +1445,12 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
- self, txn, events_and_contexts, all_events_and_contexts, backfilled
+ self,
+ txn,
+ *,
+ events_and_contexts,
+ all_events_and_contexts,
+ inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@@ -1442,7 +1462,10 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
- backfilled (bool): True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
"""
# Insert all the push actions into the event_push_actions table.
@@ -1516,7 +1539,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# Insert event_reference_hashes table.
@@ -1643,8 +1666,19 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database."""
+ def _store_room_members_txn(
+ self, txn, events, *, inhibit_local_membership_updates: bool = False
+ ):
+ """
+ Store a room member in the database.
+ Args:
+ txn: The transaction to use.
+ events: List of events to store.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
+ """
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1687,7 +1721,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
- and not backfilled
+ and not inhibit_local_membership_updates
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4cefc0a07e..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1408,10 +1408,10 @@ class EventsWorkerStore(SQLBaseStore):
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1449,11 +1449,11 @@ class EventsWorkerStore(SQLBaseStore):
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1507,10 +1507,10 @@ class EventsWorkerStore(SQLBaseStore):
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1537,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1762,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
+
+ async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a backward gap of missing events.
+ <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question has any of its prev_events listed as a
+ # backward extremity, it's next to a gap.
+ #
+ # We can't just check the backward edges in `event_edges` because
+ # when we persist events, we will also record the prev_events as
+ # edges to the event in question regardless of whether we have those
+ # prev_events yet. We need to check whether those prev_events are
+ # backward extremities, also known as gaps, that need to be
+ # backfilled.
+ backward_extremity_query = """
+ SELECT 1 FROM event_backward_extremities
+ WHERE
+ room_id = ?
+ AND %s
+ LIMIT 1
+ """
+
+ # If the event in question is a backward extremity or has any of its
+ # prev_events listed as a backward extremity, it's next to a
+ # backward gap.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "event_id",
+ [event.event_id] + list(event.prev_event_ids()),
+ )
+
+ txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+ backward_extremities = txn.fetchall()
+
+ # We consider any backward extremity as a backward gap
+ if len(backward_extremities):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_backward_gap_txn",
+ is_event_next_to_backward_gap_txn,
+ )
+
+ async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a forward gap of missing events.
+ The gap in front of the latest events is not considered a gap.
+ <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
+ <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question is a forward extremity, we will just
+ # consider any potential forward gap as not a gap since it's one of
+ # the latest events in the room.
+ #
+ # `event_forward_extremities` does not include backfilled or outlier
+ # events so we can't rely on it to find forward gaps. We can only
+ # use it to determine whether a message is the latest in the room.
+ #
+ # We can't combine this query with the `forward_edge_query` below
+ # because if the event in question has no forward edges (isn't
+ # referenced by any other event's prev_events) but is in
+ # `event_forward_extremities`, we don't want to return 0 rows and
+ # say it's next to a gap.
+ forward_extremity_query = """
+ SELECT 1 FROM event_forward_extremities
+ WHERE
+ room_id = ?
+ AND event_id = ?
+ LIMIT 1
+ """
+
+ # Check to see whether the event in question is already referenced
+ # by another event. If we don't see any edges, we're next to a
+ # forward gap.
+ forward_edge_query = """
+ SELECT 1 FROM event_edges
+ /* Check to make sure the event referencing our event in question is not rejected */
+ LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+ WHERE
+ event_edges.room_id = ?
+ AND event_edges.prev_event_id = ?
+ /* It's not a valid edge if the event referencing our event in
+ * question is rejected.
+ */
+ AND rejections.event_id IS NULL
+ LIMIT 1
+ """
+
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ forward_extremities = txn.fetchall()
+ if len(forward_extremities):
+ return False
+
+ # If there are no forward edges to the event in question (another
+ # event hasn't referenced this event in their prev_events), then we
+ # assume there is a forward gap in the history.
+ txn.execute(forward_edge_query, (event.room_id, event.event_id))
+ forward_edges = txn.fetchall()
+ if not len(forward_edges):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_gap_txn",
+ is_event_next_to_gap_txn,
+ )
+
+ async def get_event_id_for_timestamp(
+ self, room_id: str, timestamp: int, direction: str
+ ) -> Optional[str]:
+ """Find the closest event to the given timestamp in the given direction.
+
+ Args:
+ room_id: Room to fetch the event from
+ timestamp: The point in time (inclusive) we should navigate from in
+ the given direction to find the closest event.
+ direction: ["f"|"b"] to indicate whether we should navigate forward
+ or backward from the given timestamp to find the closest event.
+
+ Returns:
+ The closest event_id otherwise None if we can't find any event in
+ the given direction.
+ """
+
+ sql_template = """
+ SELECT event_id FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ origin_server_ts %s ?
+ AND room_id = ?
+ /* Make sure event is not rejected */
+ AND rejections.event_id IS NULL
+ ORDER BY origin_server_ts %s
+ LIMIT 1;
+ """
+
+ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
+
+ txn.execute(
+ sql_template % (comparison_operator, order), (timestamp, room_id)
+ )
+ row = txn.fetchone()
+ if row:
+ (event_id,) = row
+ return event_id
+
+ return None
+
+ if direction not in ("f", "b"):
+ raise ValueError("Unknown direction: %s" % (direction,))
+
+ return await self.db_pool.runInteraction(
+ "get_event_id_for_timestamp_txn",
+ get_event_id_for_timestamp_txn,
+ )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944bf..91b0576b85 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
- should_delete_expr = "state_key IS NULL"
+ should_delete_expr = "state_events.state_key IS NULL"
should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831d6..6b2a8d06a6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND c.membership = ?
"""
else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND m.membership = ?
"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d17..57aab55259 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
oldest `limit` events.
Returns:
- The list of events (in ascending order) and the token from the start
+ The list of events (in ascending stream order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return [], from_key
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
+ """Fetch membership events for a given user.
+
+ All such events whose stream ordering `s` lies in the range
+ `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+ order.
+ """
+ # Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return []
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
Returns:
A list of events and a token pointing to the start of the returned
- events. The events returned are in ascending order.
+ events. The events returned are in ascending topological order.
"""
rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
import logging
from collections import namedtuple
+from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
)
+class DestinationSortOrder(Enum):
+ """Enum to define the sorting method used when returning destinations."""
+
+ DESTINATION = "destination"
+ RETRY_LAST_TS = "retry_last_ts"
+ RETTRY_INTERVAL = "retry_interval"
+ FAILURE_TS = "failure_ts"
+ LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destinations = [row[0] for row in txn]
return destinations
+
+ async def get_destinations_paginate(
+ self,
+ start: int,
+ limit: int,
+ destination: Optional[str] = None,
+ order_by: str = DestinationSortOrder.DESTINATION.value,
+ direction: str = "f",
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of destinations.
+ This will return a json list of destinations and the
+ total number of destinations matching the filter criteria.
+
+ Args:
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ destination: search string in destination
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
+ Returns:
+ A tuple of a list of mappings from destination to information
+ and a count of total destinations.
+ """
+
+ def get_destinations_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ order_by_column = DestinationSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ args = []
+ where_statement = ""
+ if destination:
+ args.extend(["%" + destination.lower() + "%"])
+ where_statement = "WHERE LOWER(destination) LIKE ?"
+
+ sql_base = f"FROM destinations {where_statement} "
+ sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = f"""
+ SELECT destination, retry_last_ts, retry_interval, failure_ts,
+ last_successful_stream_ordering
+ {sql_base}
+ ORDER BY {order_by_column} {order}, destination ASC
+ LIMIT ? OFFSET ?
+ """
+ txn.execute(sql, args + [limit, start])
+ destinations = self.db_pool.cursor_to_dict(txn)
+ return destinations, count
+
+ return await self.db_pool.runInteraction(
+ "get_destinations_paginate_txn", get_destinations_paginate_txn
+ )
|