summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py73
-rw-r--r--synapse/federation/transport/server.py24
-rw-r--r--synapse/handlers/auth.py21
-rw-r--r--synapse/handlers/federation.py19
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/push/push_rule_evaluator.py4
-rw-r--r--synapse/state.py5
-rw-r--r--synapse/storage/_base.py5
-rw-r--r--synapse/storage/client_ips.py6
-rw-r--r--synapse/storage/events.py5
-rw-r--r--synapse/storage/state.py10
-rw-r--r--synapse/util/caches/__init__.py2
-rw-r--r--synapse/util/caches/descriptors.py5
-rw-r--r--synapse/util/caches/expiringcache.py3
-rw-r--r--synapse/util/caches/stream_change_cache.py6
16 files changed, 139 insertions, 60 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 9df7d18993..60af1cbecd 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.21.1"
+__version__ = "0.22.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d23bcecbad..e3da45b416 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,6 +23,8 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes
 from synapse.types import UserID
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -38,6 +40,10 @@ AuthEventTypes = (
 GUEST_DEVICE_ID = "guest_device"
 
 
+class _InvalidMacaroonException(Exception):
+    pass
+
+
 class Auth(object):
     """
     FIXME: This class contains a mix of functions for authenticating users
@@ -50,6 +56,9 @@ class Auth(object):
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
 
+        self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+        register_cache("token_cache", self.token_cache)
+
     @defer.inlineCallbacks
     def check_from_context(self, event, context, do_sig_check=True):
         auth_events_ids = yield self.compute_auth_events(
@@ -266,8 +275,8 @@ class Auth(object):
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(token)
-        except Exception:  # deserialize can throw more-or-less anything
+            user_id, guest = self._parse_and_validate_macaroon(token, rights)
+        except _InvalidMacaroonException:
             # doesn't look like a macaroon: treat it as an opaque token which
             # must be in the database.
             # TODO: it would be nice to get rid of this, but apparently some
@@ -276,19 +285,8 @@ class Auth(object):
             defer.returnValue(r)
 
         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
             user = UserID.from_string(user_id)
 
-            self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token,
-                user_id=user_id,
-            )
-
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id == "guest = true":
-                    guest = True
-
             if guest:
                 # Guest access tokens are not stored in the database (there can
                 # only be one access token per guest, anyway).
@@ -360,6 +358,55 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )
 
+    def _parse_and_validate_macaroon(self, token, rights="access"):
+        """Takes a macaroon and tries to parse and validate it. This is cached
+        if and only if rights == access and there isn't an expiry.
+
+        On invalid macaroon raises _InvalidMacaroonException
+
+        Returns:
+            (user_id, is_guest)
+        """
+        if rights == "access":
+            cached = self.token_cache.get(token, None)
+            if cached:
+                return cached
+
+        try:
+            macaroon = pymacaroons.Macaroon.deserialize(token)
+        except Exception:  # deserialize can throw more-or-less anything
+            # doesn't look like a macaroon: treat it as an opaque token which
+            # must be in the database.
+            # TODO: it would be nice to get rid of this, but apparently some
+            # people use access tokens which aren't macaroons
+            raise _InvalidMacaroonException()
+
+        try:
+            user_id = self.get_user_id_from_macaroon(macaroon)
+
+            has_expiry = False
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id.startswith("time "):
+                    has_expiry = True
+                elif caveat.caveat_id == "guest = true":
+                    guest = True
+
+            self.validate_macaroon(
+                macaroon, rights, self.hs.config.expire_access_token,
+                user_id=user_id,
+            )
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+
+        if not has_expiry and rights == "access":
+            self.token_cache[token] = (user_id, guest)
+
+        return user_id, guest
+
     def get_user_id_from_macaroon(self, macaroon):
         """Retrieve the user_id given by the caveats on the macaroon.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3d676e7d8b..a78f01e442 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -153,12 +153,10 @@ class Authenticator(object):
 class BaseFederationServlet(object):
     REQUIRE_AUTH = True
 
-    def __init__(self, handler, authenticator, ratelimiter, server_name,
-                 room_list_handler):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
-        self.room_list_handler = room_list_handler
 
     def _wrap(self, func):
         authenticator = self.authenticator
@@ -590,7 +588,7 @@ class PublicRoomList(BaseFederationServlet):
         else:
             network_tuple = ThirdPartyInstanceID(None, None)
 
-        data = yield self.room_list_handler.get_local_public_room_list(
+        data = yield self.handler.get_local_public_room_list(
             limit, since_token,
             network_tuple=network_tuple
         )
@@ -611,7 +609,7 @@ class FederationVersionServlet(BaseFederationServlet):
         }))
 
 
-SERVLET_CLASSES = (
+FEDERATION_SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
@@ -634,17 +632,27 @@ SERVLET_CLASSES = (
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     OpenIdUserInfo,
-    PublicRoomList,
     FederationVersionServlet,
 )
 
+ROOM_LIST_CLASSES = (
+    PublicRoomList,
+)
+
 
 def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in SERVLET_CLASSES:
+    for servletclass in FEDERATION_SERVLET_CLASSES:
         servletclass(
             handler=hs.get_replication_layer(),
             authenticator=authenticator,
             ratelimiter=ratelimiter,
             server_name=hs.hostname,
-            room_list_handler=hs.get_room_list_handler(),
+        ).register(resource)
+
+    for servletclass in ROOM_LIST_CLASSES:
+        servletclass(
+            handler=hs.get_room_list_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
         ).register(resource)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e7a1bb7246..b00446bec0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -21,6 +21,7 @@ from synapse.api.constants import LoginType
 from synapse.types import UserID
 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
+from synapse.util.caches.expiringcache import ExpiringCache
 
 from twisted.web.client import PartialDownloadError
 
@@ -52,7 +53,15 @@ class AuthHandler(BaseHandler):
             LoginType.DUMMY: self._check_dummy_auth,
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
-        self.sessions = {}
+
+        # This is not a cache per se, but a store of all current sessions that
+        # expire after N hours
+        self.sessions = ExpiringCache(
+            cache_name="register_sessions",
+            clock=hs.get_clock(),
+            expiry_ms=self.SESSION_EXPIRE_MS,
+            reset_expiry_on_get=True,
+        )
 
         account_handler = _AccountHandler(
             hs, check_user_exists=self.check_user_exists
@@ -617,16 +626,6 @@ class AuthHandler(BaseHandler):
         logger.debug("Saving session %s", session)
         session["last_used"] = self.hs.get_clock().time_msec()
         self.sessions[session["id"]] = session
-        self._prune_sessions()
-
-    def _prune_sessions(self):
-        for sid, sess in self.sessions.items():
-            last_used = 0
-            if 'last_used' in sess:
-                last_used = sess['last_used']
-            now = self.hs.get_clock().time_msec()
-            if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
-                del self.sessions[sid]
 
     def hash(self, password):
         """Computes a secure hash of password.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7ae369a1d..483cb8eac6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -75,6 +75,7 @@ class FederationHandler(BaseHandler):
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
+        self.is_mine_id = hs.is_mine_id
 
         self.replication_layer.set_handler(self)
 
@@ -1072,6 +1073,20 @@ class FederationHandler(BaseHandler):
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
+        membership = event.content.get("membership")
+        if event.type != EventTypes.Member or membership != Membership.INVITE:
+            raise SynapseError(400, "The event was not an m.room.member invite event")
+
+        sender_domain = get_domain_from_id(event.sender)
+        if sender_domain != origin:
+            raise SynapseError(400, "The invite event was not from the server sending it")
+
+        if event.state_key is None:
+            raise SynapseError(400, "The invite event did not have a state key")
+
+        if not self.is_mine_id(event.state_key):
+            raise SynapseError(400, "The invite event must be for this server")
+
         event.internal_metadata.outlier = True
         event.internal_metadata.invite_from_remote = True
 
@@ -1280,7 +1295,7 @@ class FederationHandler(BaseHandler):
             for event in res:
                 # We sign these again because there was a bug where we
                 # incorrectly signed things the first time round
-                if self.hs.is_mine_id(event.event_id):
+                if self.is_mine_id(event.event_id):
                     event.signatures.update(
                         compute_event_signature(
                             event,
@@ -1353,7 +1368,7 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            if self.hs.is_mine_id(event.event_id):
+            if self.is_mine_id(event.event_id):
                 # FIXME: This is a temporary work around where we occasionally
                 # return events slightly differently than when they were
                 # originally signed
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a04f634c5c..24c9ffdb20 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -34,6 +34,7 @@ from canonicaljson import encode_canonical_json
 
 import logging
 import random
+import ujson
 
 logger = logging.getLogger(__name__)
 
@@ -498,6 +499,14 @@ class MessageHandler(BaseHandler):
             logger.warn("Denying new event %r because %s", event, err)
             raise err
 
+        # Ensure that we can round trip before trying to persist in db
+        try:
+            dump = ujson.dumps(event.content)
+            ujson.loads(dump)
+        except:
+            logger.exception("Failed to encode content: %r", event.content)
+            raise
+
         yield self.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4d88046579..172c27c137 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -200,7 +200,9 @@ def _glob_to_re(glob, word_boundary):
         return re.compile(r, flags=re.IGNORECASE)
 
 
-def _flatten_dict(d, prefix=[], result={}):
+def _flatten_dict(d, prefix=[], result=None):
+    if result is None:
+        result = {}
     for key, value in d.items():
         if isinstance(value, basestring):
             result[".".join(prefix + [key])] = value.lower()
diff --git a/synapse/state.py b/synapse/state.py
index 5b386e3183..390799fbd5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
+from synapse.util.caches import CACHE_SIZE_FACTOR
 
 from collections import namedtuple
 from frozendict import frozendict
 
 import logging
 import hashlib
-import os
 
 logger = logging.getLogger(__name__)
 
@@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 51730a88bf..6f54036d67 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,6 +16,7 @@ import logging
 
 from synapse.api.errors import StoreError
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 from synapse.storage.engines import PostgresEngine
@@ -27,10 +28,6 @@ from twisted.internet import defer
 import sys
 import time
 import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 5a88e242e5..3c95e90eca 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -20,7 +20,8 @@ from twisted.internet import defer, reactor
 from ._base import Cache
 from . import background_updates
 
-import os
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
 
 logger = logging.getLogger(__name__)
 
@@ -30,9 +31,6 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class ClientIpStore(background_updates.BackgroundUpdateStore):
     def __init__(self, hs):
         self.client_ip_last_seen = Cache(
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2b7340c1d9..7002b3752e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -403,6 +403,11 @@ class EventsStore(SQLBaseStore):
                         (room_id, ), new_state
                     )
 
+                for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+                    self.get_latest_event_ids_in_room.prefill(
+                        (room_id,), list(latest_event_ids)
+                    )
+
     @defer.inlineCallbacks
     def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
         """Calculates the new forward extremeties for a room given events to
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1e679719b..5673e4aa96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -315,6 +315,12 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+        for event_id, state_group_id in state_groups.iteritems():
+            txn.call_after(
+                self._get_state_group_for_event.prefill,
+                (event_id,), state_group_id
+            )
+
     def _count_state_group_hops_txn(self, txn, state_group):
         """Given a state group, count how many hops there are in the tree.
 
@@ -584,8 +590,8 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=50000)
-    def _get_state_group_for_event(self, room_id, event_id):
+    @cached(max_entries=50000)
+    def _get_state_group_for_event(self, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4a83c46d98..4adae96681 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
 import synapse.metrics
 import os
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
 
 metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index cbdff86596..af65bfe7b8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -16,6 +16,7 @@ import logging
 
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.stringutils import to_ascii
@@ -25,7 +26,6 @@ from . import register_cache
 from twisted.internet import defer
 from collections import namedtuple
 
-import os
 import functools
 import inspect
 import threading
@@ -37,9 +37,6 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class CacheEntry(object):
     __slots__ = [
         "deferred", "sequence", "callbacks", "invalidated"
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index cbdde34a57..6ad53a6390 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -94,6 +94,9 @@ class ExpiringCache(object):
 
         return entry.value
 
+    def __contains__(self, key):
+        return key in self._cache
+
     def get(self, key, default=None):
         try:
             return self[key]
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 609625b322..941d873ab8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,20 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
 
 
 from blist import sorteddict
 import logging
-import os
 
 
 logger = logging.getLogger(__name__)
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class StreamChangeCache(object):
     """Keeps track of the stream positions of the latest change in a set of entities.