diff options
-rw-r--r-- | synapse/handlers/auth.py | 21 | ||||
-rw-r--r-- | synapse/util/caches/expiringcache.py | 3 |
2 files changed, 13 insertions, 11 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e7a1bb7246..b00446bec0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -21,6 +21,7 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.util.caches.expiringcache import ExpiringCache from twisted.web.client import PartialDownloadError @@ -52,7 +53,15 @@ class AuthHandler(BaseHandler): LoginType.DUMMY: self._check_dummy_auth, } self.bcrypt_rounds = hs.config.bcrypt_rounds - self.sessions = {} + + # This is not a cache per se, but a store of all current sessions that + # expire after N hours + self.sessions = ExpiringCache( + cache_name="register_sessions", + clock=hs.get_clock(), + expiry_ms=self.SESSION_EXPIRE_MS, + reset_expiry_on_get=True, + ) account_handler = _AccountHandler( hs, check_user_exists=self.check_user_exists @@ -617,16 +626,6 @@ class AuthHandler(BaseHandler): logger.debug("Saving session %s", session) session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - self._prune_sessions() - - def _prune_sessions(self): - for sid, sess in self.sessions.items(): - last_used = 0 - if 'last_used' in sess: - last_used = sess['last_used'] - now = self.hs.get_clock().time_msec() - if last_used < now - AuthHandler.SESSION_EXPIRE_MS: - del self.sessions[sid] def hash(self, password): """Computes a secure hash of password. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cbdde34a57..6ad53a6390 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -94,6 +94,9 @@ class ExpiringCache(object): return entry.value + def __contains__(self, key): + return key in self._cache + def get(self, key, default=None): try: return self[key] |