summary refs log tree commit diff
path: root/synapse/storage/databases/main/presence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/presence.py')
-rw-r--r--synapse/storage/databases/main/presence.py92
1 files changed, 90 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c207d917b1..db22fab23e 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,16 +12,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import PresenceState, UserPresenceState
+from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class PresenceStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: Connection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._can_persist_presence = (
+            hs.get_instance_name() in hs.config.worker.writers.presence
+        )
+
+        if isinstance(database.engine, PostgresEngine):
+            self._presence_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="presence_stream",
+                instance_name=self._instance_name,
+                tables=[("presence_stream", "instance_name", "stream_id")],
+                sequence_name="presence_stream_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._presence_id_gen = StreamIdGenerator(
+                db_conn, "presence_stream", "stream_id"
+            )
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
+            db_conn,
+            "presence_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._presence_id_gen.get_current_token(),
+        )
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache",
+            min_presence_val,
+            prefilled_cache=presence_cache_prefill,
+        )
+
     async def update_presence(self, presence_states):
+        assert self._can_persist_presence
+
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
@@ -57,6 +110,7 @@ class PresenceStore(SQLBaseStore):
                     "last_user_sync_ts": state.last_user_sync_ts,
                     "status_msg": state.status_msg,
                     "currently_active": state.currently_active,
+                    "instance_name": self._instance_name,
                 }
                 for stream_id, state in zip(stream_orderings, presence_states)
             ],
@@ -216,3 +270,37 @@ class PresenceStore(SQLBaseStore):
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
+
+    def _get_active_presence(self, db_conn: Connection):
+        """Fetch non-offline presence from the database so that we can register
+        the appropriate time outs.
+        """
+
+        sql = (
+            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+            " WHERE state != ?"
+        )
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (PresenceState.OFFLINE,))
+        rows = self.db_pool.cursor_to_dict(txn)
+        txn.close()
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return [UserPresenceState(**row) for row in rows]
+
+    def take_presence_startup_info(self):
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
+        return active_on_startup
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.presence_stream_cache.entity_has_changed(row.user_id, token)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super().process_replication_rows(stream_name, instance_name, token, rows)