1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
class SessionStore(SQLBaseStore):
    """
    A store for generic session data.

    Each type of session should provide a unique type (to separate sessions).

    Expired sessions are never returned by `get_session`; they are physically
    deleted by a periodic background job, which is only scheduled on workers
    configured to run background tasks.
    """

    # How many randomly generated session IDs to try before giving up in
    # `create_session`.
    _SESSION_ID_ATTEMPTS = 5

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ) -> None:
        super().__init__(database, db_conn, hs)

        # Create a background job for culling expired sessions (every 30 minutes).
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)

    async def create_session(
        self, session_type: str, value: JsonDict, expiry_ms: int
    ) -> str:
        """
        Creates a new session of the given type, storing `value` in it.

        Args:
            session_type: The type for this session.
            value: The value to store. It is serialised to JSON.
            expiry_ms: How long, in milliseconds from now, until the session
                expires. There is no default: a value of 0 (or less) creates a
                session that is already expired and so can never be retrieved
                by `get_session`.

        Returns:
            The newly created session ID.

        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # autogen a session ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        for _ in range(self._SESSION_ID_ATTEMPTS):
            session_id = stringutils.random_string(24)
            try:
                await self.db_pool.simple_insert(
                    table="sessions",
                    values={
                        "session_id": session_id,
                        "session_type": session_type,
                        "value": json_encoder.encode(value),
                        "expiry_time_ms": self._clock.time_msec() + expiry_ms,
                    },
                    desc="create_session",
                )
                return session_id
            except self.db_pool.engine.module.IntegrityError:
                # Treated as a session ID clash: retry with a fresh random ID.
                continue

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
        """
        Retrieve data stored with create_session.

        Args:
            session_type: The type for this session.
            session_id: The session ID returned from create_session.

        Returns:
            The value originally passed to create_session, decoded from JSON.

        Raises:
            StoreError (404) if the session cannot be found or has expired.
        """

        def _get_session(
            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
        ) -> JsonDict:
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()

            if not row:
                raise StoreError(404, "No session")

            return db_to_json(row[0])

        return await self.db_pool.runInteraction(
            "get_session",
            _get_session,
            session_type,
            session_id,
            self._clock.time_msec(),
        )

    @wrap_as_background_process("delete_expired_sessions")
    async def _delete_expired_sessions(self) -> None:
        """Remove sessions with expiry dates that have passed."""

        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
            # `<=` is the exact complement of the `>` filter in get_session, so
            # only rows that get_session would already ignore are deleted.
            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
            txn.execute(sql, (ts,))

        await self.db_pool.runInteraction(
            "delete_expired_sessions",
            _delete_expired_sessions_txn,
            self._clock.time_msec(),
        )
|