diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index db22fab23e..6a2baa7841 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
db_conn, "presence_stream", "stream_id"
)
+ self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
@@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+ # Delete old rows to stop database from getting really big
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+ for states in batch_iter(presence_states, 50):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
+
# Actually insert new rows
self.db_pool.simple_insert_many_txn(
txn,
@@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
],
)
- # Delete old rows to stop database from getting really big
- sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
-
- for states in batch_iter(presence_states, 50):
- clause, args = make_in_list_sql_clause(
- self.database_engine, "user_id", [s.user_id for s in states]
- )
- txn.execute(sql + clause, [stream_id] + list(args))
-
async def get_all_presence_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
@@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ async def should_user_receive_full_presence_with_token(
+ self,
+ user_id: str,
+ from_token: int,
+ ) -> bool:
+ """Check whether the given user should receive full presence using the stream token
+ they're updating from.
+
+ Args:
+ user_id: The ID of the user to check.
+ from_token: The stream token included in their /sync token.
+
+ Returns:
+ True if the user should have full presence sent to them, False otherwise.
+ """
+
+ def _should_user_receive_full_presence_with_token_txn(txn):
+ sql = """
+ SELECT 1 FROM users_to_send_full_presence_to
+ WHERE user_id = ?
+ AND presence_stream_id >= ?
+ """
+ txn.execute(sql, (user_id, from_token))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "should_user_receive_full_presence_with_token",
+ _should_user_receive_full_presence_with_token_txn,
+ )
+
+ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+ """Adds to the list of users who should receive a full snapshot of presence
+ upon their next sync.
+
+ Args:
+ user_ids: An iterable of user IDs.
+ """
+ # Add user entries to the table, updating the presence_stream_id column if the user already
+ # exists in the table.
+ await self.db_pool.simple_upsert_many(
+ table="users_to_send_full_presence_to",
+ key_names=("user_id",),
+ key_values=[(user_id,) for user_id in user_ids],
+ value_names=("presence_stream_id",),
+ # We save the current presence stream ID token along with the user ID entry so
+ # that when a user /sync's, even if they syncing multiple times across separate
+ # devices at different times, each device will receive full presence once - when
+ # the presence stream ID in their sync token is less than the one in the table
+ # for their user ID.
+ value_values=(
+ (self._presence_id_gen.get_current_token(),) for _ in user_ids
+ ),
+ desc="add_users_to_send_full_presence_to",
+ )
+
async def get_presence_for_all_users(
self,
include_offline: bool = True,
|