summary refs log tree commit diff
path: root/synapse/storage/databases/main/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/session.py')
-rw-r--r--synapse/storage/databases/main/session.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 0000000000..172f27d109
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+    """
+    A store for generic session data.
+
+    Each type of session should provide a unique type (to separate sessions).
+
+    Sessions are automatically removed when they expire.
+    """
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Create a background job for culling expired sessions.
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+    async def create_session(
+        self, session_type: str, value: JsonDict, expiry_ms: int
+    ) -> str:
+        """
+        Creates a new pagination session for the room hierarchy endpoint.
+
+        Args:
+            session_type: The type for this session.
+            value: The value to store.
+            expiry_ms: How long before an item is evicted from the cache
+                in milliseconds. Default is 0, indicating items never get
+                evicted based on time.
+
+        Returns:
+            The newly created session ID.
+
+        Raises:
+            StoreError if a unique session ID cannot be generated.
+        """
+        # autogen a session ID and try to create it. We may clash, so just
+        # try a few times till one goes through, giving up eventually.
+        attempts = 0
+        while attempts < 5:
+            session_id = stringutils.random_string(24)
+
+            try:
+                await self.db_pool.simple_insert(
+                    table="sessions",
+                    values={
+                        "session_id": session_id,
+                        "session_type": session_type,
+                        "value": json_encoder.encode(value),
+                        "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+                    },
+                    desc="create_session",
+                )
+
+                return session_id
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+        raise StoreError(500, "Couldn't generate a session ID.")
+
+    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+        """
+        Retrieve data stored with create_session
+
+        Args:
+            session_type: The type for this session.
+            session_id: The session ID returned from create_session.
+
+        Raises:
+            StoreError if the session cannot be found.
+        """
+
+        def _get_session(
+            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+        ) -> JsonDict:
+            # This includes the expiry time since items are only periodically
+            # deleted, not upon expiry.
+            select_sql = """
+            SELECT value FROM sessions WHERE
+            session_type = ? AND session_id = ? AND expiry_time_ms > ?
+            """
+            txn.execute(select_sql, [session_type, session_id, ts])
+            row = txn.fetchone()
+
+            if not row:
+                raise StoreError(404, "No session")
+
+            return db_to_json(row[0])
+
+        return await self.db_pool.runInteraction(
+            "get_session",
+            _get_session,
+            session_type,
+            session_id,
+            self._clock.time_msec(),
+        )
+
+    @wrap_as_background_process("delete_expired_sessions")
+    async def _delete_expired_sessions(self) -> None:
+        """Remove sessions with expiry dates that have passed."""
+
+        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "delete_expired_sessions",
+            _delete_expired_sessions_txn,
+            self._clock.time_msec(),
+        )