diff --git a/synapse/__init__.py b/synapse/__init__.py
index 9df7d18993..60af1cbecd 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.21.1"
+__version__ = "0.22.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9dbc7993df..f8266d1c81 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
-from synapse.util import logcontext
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -39,6 +40,10 @@ AuthEventTypes = (
GUEST_DEVICE_ID = "guest_device"
+class _InvalidMacaroonException(Exception):
+ pass
+
+
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
@@ -51,6 +56,9 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+ self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+ register_cache("token_cache", self.token_cache)
+
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
@@ -144,17 +152,8 @@ class Auth(object):
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-
- logger.debug("calling resolve_state_groups from check_host_in_room")
- entry = yield self.state.resolve_state_groups(
- room_id, latest_event_ids
- )
-
- ret = yield self.store.is_host_joined(
- room_id, host, entry.state_group, entry.state
- )
- defer.returnValue(ret)
+ latest_event_ids = yield self.store.is_host_joined(room_id, host)
+ defer.returnValue(latest_event_ids)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
@@ -209,7 +208,7 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
- logcontext.preserve_fn(self.store.insert_client_ip)(
+ self.store.insert_client_ip(
user=user,
access_token=access_token,
ip=ip_addr,
@@ -276,8 +275,8 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
- macaroon = pymacaroons.Macaroon.deserialize(token)
- except Exception: # deserialize can throw more-or-less anything
+ user_id, guest = self._parse_and_validate_macaroon(token, rights)
+ except _InvalidMacaroonException:
# doesn't look like a macaroon: treat it as an opaque token which
# must be in the database.
# TODO: it would be nice to get rid of this, but apparently some
@@ -286,19 +285,8 @@ class Auth(object):
defer.returnValue(r)
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
- self.validate_macaroon(
- macaroon, rights, self.hs.config.expire_access_token,
- user_id=user_id,
- )
-
- guest = False
- for caveat in macaroon.caveats:
- if caveat.caveat_id == "guest = true":
- guest = True
-
if guest:
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
@@ -370,6 +358,55 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
+ def _parse_and_validate_macaroon(self, token, rights="access"):
+ """Takes a macaroon and tries to parse and validate it. This is cached
+ if and only if rights == access and there isn't an expiry.
+
+ On invalid macaroon raises _InvalidMacaroonException
+
+ Returns:
+ (user_id, is_guest)
+ """
+ if rights == "access":
+ cached = self.token_cache.get(token, None)
+ if cached:
+ return cached
+
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ except Exception: # deserialize can throw more-or-less anything
+ # doesn't look like a macaroon: treat it as an opaque token which
+ # must be in the database.
+ # TODO: it would be nice to get rid of this, but apparently some
+ # people use access tokens which aren't macaroons
+ raise _InvalidMacaroonException()
+
+ try:
+ user_id = self.get_user_id_from_macaroon(macaroon)
+
+ has_expiry = False
+ guest = False
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith("time "):
+ has_expiry = True
+ elif caveat.caveat_id == "guest = true":
+ guest = True
+
+ self.validate_macaroon(
+ macaroon, rights, self.hs.config.expire_access_token,
+ user_id=user_id,
+ )
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
+ if not has_expiry and rights == "access":
+ self.token_cache[token] = (user_id, guest)
+
+ return user_id, guest
+
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 9b72c649ac..09bc1935f1 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -24,6 +24,7 @@ from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
@@ -33,7 +34,6 @@ from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
@@ -65,8 +65,8 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
TransactionStore,
+ SlavedClientIpStore,
BaseSlavedStore,
- ClientIpStore, # After BaseSlavedStore because the constructor is different
):
pass
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index e51a69074d..03327dc47a 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -51,7 +51,7 @@ import sys
import logging
import gc
-logger = logging.getLogger("synapse.app.appservice")
+logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 26c4416956..f57ec784fe 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -23,13 +23,13 @@ from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.httpresourcetree import create_resource_tree
@@ -60,10 +60,10 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
+ SlavedClientIpStore,
TransactionStore,
BaseSlavedStore,
MediaRepositoryStore,
- ClientIpStore,
):
pass
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 13c00ef2ba..4bdd99a966 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -29,6 +29,7 @@ from synapse.rest.client.v1 import events
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
@@ -42,7 +43,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
@@ -77,9 +77,9 @@ class SynchrotronSlavedStore(
SlavedPresenceStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
+ SlavedClientIpStore,
RoomStore,
BaseSlavedStore,
- ClientIpStore, # After BaseSlavedStore because the constructor is different
):
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
new file mode 100644
index 0000000000..8c6300db9d
--- /dev/null
+++ b/synapse/app/user_dir.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.server import HomeServer
+from synapse.config._base import ConfigError
+from synapse.config.logger import setup_logging
+from synapse.config.homeserver import HomeServerConfig
+from synapse.crypto import context_factory
+from synapse.http.site import SynapseSite
+from synapse.http.server import JsonResource
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha import user_directory
+from synapse.storage.engines import create_engine
+from synapse.storage.user_directory import UserDirectoryStore
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from synapse import events
+
+from twisted.internet import reactor
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+import gc
+
+logger = logging.getLogger("synapse.app.user_dir")
+
+
+class UserDirectorySlaveStore(
+ SlavedEventStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedClientIpStore,
+ UserDirectoryStore,
+ BaseSlavedStore,
+):
+ def __init__(self, db_conn, hs):
+ super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
+ self._current_state_delta_pos = events_max
+
+ def stream_positions(self):
+ result = super(UserDirectorySlaveStore, self).stream_positions()
+ result["current_state_deltas"] = self._current_state_delta_pos
+ return result
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "current_state_deltas":
+ self._current_state_delta_pos = token
+ for row in rows:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.room_id, token
+ )
+ return super(UserDirectorySlaveStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
+
+
+class UserDirectoryServer(HomeServer):
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ user_directory.register_servlets(self, resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, Resource())
+
+ for address in bind_addresses:
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=address
+ )
+
+ logger.info("Synapse user_dir now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ bind_addresses = listener["bind_addresses"]
+
+ for address in bind_addresses:
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=address
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return UserDirectoryReplicationHandler(self)
+
+
+class UserDirectoryReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
+ self.user_directory = hs.get_user_directory_handler()
+
+ def on_rdata(self, stream_name, token, rows):
+ super(UserDirectoryReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ if stream_name == "current_state_deltas":
+ preserve_fn(self.user_directory.notify_new_event)()
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse user directory", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.user_dir"
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ if config.update_user_directory:
+ sys.stderr.write(
+ "\nThe update_user_directory must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``update_user_directory: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.update_user_directory = True
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ps = UserDirectoryServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ps.setup()
+ ps.start_listening(config.worker_listeners)
+
+ def run():
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
+ logger.info("Running")
+ change_resource_limit(config.soft_file_limit)
+ if config.gc_thresholds:
+ gc.set_threshold(*config.gc_thresholds)
+ reactor.run()
+
+ def start():
+ ps.get_datastore().start_profiling()
+ ps.get_state_handler().start_caching()
+
+ reactor.callWhenRunning(start)
+
+ if config.worker_daemonize:
+ daemon = Daemonize(
+ app="synapse-user-dir",
+ pid=config.worker_pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 7346206bb1..b989007314 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -241,6 +241,16 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
+ def get_exlusive_user_regexes(self):
+ """Get the list of regexes used to determine if a user is exclusively
+ registered by the AS
+ """
+ return [
+ regex_obj["regex"]
+ for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+ if regex_obj["exclusive"]
+ ]
+
def is_rate_limited(self):
return self.rate_limited
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 0f890fc04a..b22cacf8dc 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -33,6 +33,7 @@ from .jwt import JWTConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig
from .workers import WorkerConfig
+from .push import PushConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -40,7 +41,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig,):
+ WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
pass
diff --git a/synapse/config/push.py b/synapse/config/push.py
new file mode 100644
index 0000000000..9c68318b40
--- /dev/null
+++ b/synapse/config/push.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class PushConfig(Config):
+ def read_config(self, config):
+ self.push_redact_content = False
+
+ push_config = config.get("email", {})
+ self.push_redact_content = push_config.get("redact_content", False)
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Control how push messages are sent to google/apple to notifications.
+ # Normally every message said in a room with one or more people using
+ # mobile devices will be posted to a push server hosted by matrix.org
+ # which is registered with google and apple in order to allow push
+ # notifications to be sent to these mobile devices.
+ #
+ # Setting redact_content to true will make the push messages contain no
+ # message content which will provide increased privacy. This is a
+ # temporary solution pending improvements to Android and iPhone apps
+ # to get content from the app rather than the notification.
+ #
+ # For modern android devices the notification content will still appear
+ # because it is loaded by the app. iPhone, however will send a
+ # notification saying only that a message arrived and who it came from.
+ #
+ #push:
+ # redact_content: false
+ """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 3910b9dc31..28b4e5f50c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -35,6 +35,10 @@ class ServerConfig(Config):
# "disable" federation
self.send_federation = config.get("send_federation", True)
+ # Whether to update the user directory or not. This should be set to
+ # false only if we are updating the user directory in a worker
+ self.update_user_directory = config.get("update_user_directory", True)
+
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
if self.public_baseurl is not None:
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index a15198e05d..003eaba893 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -187,6 +187,7 @@ class TransactionQueue(object):
prev_id for prev_id, _ in event.prev_events
],
)
+ destinations = set(destinations)
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e7a1bb7246..b00446bec0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -21,6 +21,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
+from synapse.util.caches.expiringcache import ExpiringCache
from twisted.web.client import PartialDownloadError
@@ -52,7 +53,15 @@ class AuthHandler(BaseHandler):
LoginType.DUMMY: self._check_dummy_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
- self.sessions = {}
+
+ # This is not a cache per se, but a store of all current sessions that
+ # expire after N hours
+ self.sessions = ExpiringCache(
+ cache_name="register_sessions",
+ clock=hs.get_clock(),
+ expiry_ms=self.SESSION_EXPIRE_MS,
+ reset_expiry_on_get=True,
+ )
account_handler = _AccountHandler(
hs, check_user_exists=self.check_user_exists
@@ -617,16 +626,6 @@ class AuthHandler(BaseHandler):
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- self._prune_sessions()
-
- def _prune_sessions(self):
- for sid, sess in self.sessions.items():
- last_used = 0
- if 'last_used' in sess:
- last_used = sess['last_used']
- now = self.hs.get_clock().time_msec()
- if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
- del self.sessions[sid]
def hash(self, password):
"""Computes a secure hash of password.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 982cda3edf..ed60d494ff 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id) for device_id in device_map.keys())
+ user_id, device_id=None
)
devices = device_map.values()
@@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler):
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id),)
+ user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 52d97dfbf3..483cb8eac6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -43,7 +43,6 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
-from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room
from twisted.internet import defer
@@ -75,6 +74,8 @@ class FederationHandler(BaseHandler):
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
+ self.action_generator = hs.get_action_generator()
+ self.is_mine_id = hs.is_mine_id
self.replication_layer.set_handler(self)
@@ -832,7 +833,11 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_event_auth(self, event_id):
- auth = yield self.store.get_auth_chain([event_id])
+ event = yield self.store.get_event(event_id)
+ auth = yield self.store.get_auth_chain(
+ [auth_id for auth_id, _ in event.auth_events],
+ include_given=True
+ )
for event in auth:
event.signatures.update(
@@ -1047,9 +1052,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id)
state_ids = context.prev_state_ids.values()
- auth_chain = yield self.store.get_auth_chain(set(
- [event.event_id] + state_ids
- ))
+ auth_chain = yield self.store.get_auth_chain(state_ids)
state = yield self.store.get_events(context.prev_state_ids.values())
@@ -1066,6 +1069,24 @@ class FederationHandler(BaseHandler):
"""
event = pdu
+ is_blocked = yield self.store.is_room_blocked(event.room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
+ membership = event.content.get("membership")
+ if event.type != EventTypes.Member or membership != Membership.INVITE:
+ raise SynapseError(400, "The event was not an m.room.member invite event")
+
+ sender_domain = get_domain_from_id(event.sender)
+ if sender_domain != origin:
+ raise SynapseError(400, "The invite event was not from the server sending it")
+
+ if event.state_key is None:
+ raise SynapseError(400, "The invite event did not have a state key")
+
+ if not self.is_mine_id(event.state_key):
+ raise SynapseError(400, "The invite event must be for this server")
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
@@ -1100,6 +1121,9 @@ class FederationHandler(BaseHandler):
user_id,
"leave"
)
+ # Mark as outlier as we don't have any state for this event; we're not
+ # even in the room.
+ event.internal_metadata.outlier = True
event = self._sign_event(event)
# Try the host that we succesfully called /make_leave/ on first for
@@ -1271,7 +1295,7 @@ class FederationHandler(BaseHandler):
for event in res:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
+ if self.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
@@ -1344,7 +1368,7 @@ class FederationHandler(BaseHandler):
)
if event:
- if self.hs.is_mine_id(event.event_id):
+ if self.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
@@ -1389,8 +1413,7 @@ class FederationHandler(BaseHandler):
)
if not event.internal_metadata.is_outlier():
- action_generator = ActionGenerator(self.hs)
- yield action_generator.handle_push_actions_for_event(
+ yield self.action_generator.handle_push_actions_for_event(
event, context
)
@@ -1599,7 +1622,11 @@ class FederationHandler(BaseHandler):
pass
# Now get the current auth_chain for the event.
- local_auth_chain = yield self.store.get_auth_chain([event_id])
+ event = yield self.store.get_event(event_id)
+ local_auth_chain = yield self.store.get_auth_chain(
+ [auth_id for auth_id, _ in event.auth_events],
+ include_given=True
+ )
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
@@ -1792,7 +1819,9 @@ class FederationHandler(BaseHandler):
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
)
- local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+ local_auth_chain = yield self.store.get_auth_chain(
+ auth_ids, include_given=True
+ )
try:
# 2. Get remote difference.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 196925edad..24c9ffdb20 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -20,7 +20,6 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
-from synapse.push.action_generator import ActionGenerator
from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
@@ -35,6 +34,7 @@ from canonicaljson import encode_canonical_json
import logging
import random
+import ujson
logger = logging.getLogger(__name__)
@@ -54,6 +54,8 @@ class MessageHandler(BaseHandler):
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
+ self.action_generator = hs.get_action_generator()
+
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id)
@@ -497,6 +499,14 @@ class MessageHandler(BaseHandler):
logger.warn("Denying new event %r because %s", event, err)
raise err
+ # Ensure that we can round trip before trying to persist in db
+ try:
+ dump = ujson.dumps(event.content)
+ ujson.loads(dump)
+ except:
+ logger.exception("Failed to encode content: %r", event.content)
+ raise
+
yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
@@ -590,8 +600,7 @@ class MessageHandler(BaseHandler):
"Changing the room create event is forbidden",
)
- action_generator = ActionGenerator(self.hs)
- yield action_generator.handle_push_actions_for_event(
+ yield self.action_generator.handle_push_actions_for_event(
event, context
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d2a0d6520a..5698d28088 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -61,7 +61,7 @@ class RoomCreationHandler(BaseHandler):
}
@defer.inlineCallbacks
- def create_room(self, requester, config):
+ def create_room(self, requester, config, ratelimit=True):
""" Creates a new room.
Args:
@@ -75,7 +75,8 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- yield self.ratelimit(requester)
+ if ratelimit:
+ yield self.ratelimit(requester)
if "room_alias_name" in config:
for wchar in string.whitespace:
@@ -167,6 +168,7 @@ class RoomCreationHandler(BaseHandler):
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
+ power_level_content_override=config.get("power_level_content_override", {})
)
if "name" in config:
@@ -245,7 +247,8 @@ class RoomCreationHandler(BaseHandler):
invite_list,
initial_state,
creation_content,
- room_alias
+ room_alias,
+ power_level_content_override,
):
def create(etype, content, **kwargs):
e = {
@@ -291,7 +294,15 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- if (EventTypes.PowerLevels, '') not in initial_state:
+ # We treat the power levels override specially as this needs to be one
+ # of the first events that get sent into a room.
+ pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+ if pl_content is not None:
+ yield send(
+ etype=EventTypes.PowerLevels,
+ content=pl_content,
+ )
+ else:
power_level_content = {
"users": {
creator_id: 100,
@@ -316,6 +327,8 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list:
power_level_content["users"][invitee] = 100
+ power_level_content.update(power_level_content_override)
+
yield send(
etype=EventTypes.PowerLevels,
content=power_level_content,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1ca88517a2..b3f979b246 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,6 +203,11 @@ class RoomMemberHandler(BaseHandler):
if not remote_room_hosts:
remote_room_hosts = []
+ if effective_membership_state not in ("leave", "ban",):
+ is_blocked = yield self.store.is_room_blocked(room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
@@ -369,6 +374,11 @@ class RoomMemberHandler(BaseHandler):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if event.membership not in (Membership.LEAVE, Membership.BAN):
+ is_blocked = yield self.store.is_room_blocked(room_id)
+ if is_blocked:
+ raise SynapseError(403, "This room has been blocked on this server")
+
yield message_handler.handle_new_client_event(
requester,
event,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c0205da1a9..91c6c6be3c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -117,6 +117,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
+ "device_one_time_keys_count", # Dict of algorithm to count for one time keys
+ # for this device
])):
__slots__ = []
@@ -550,6 +552,14 @@ class SyncHandler(object):
sync_result_builder
)
+ device_id = sync_config.device_id
+ one_time_key_counts = {}
+ if device_id:
+ user_id = sync_config.user.to_string()
+ one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -558,6 +568,7 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
+ device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b7818af5c..82dedbbc99 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -89,7 +89,7 @@ class TypingHandler(object):
until = self._member_typing_until.get(member, None)
if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id)
- preserve_fn(self._stopped_typing)(member)
+ self._stopped_typing(member)
continue
# Check if we need to resend a keep alive over federation for this
@@ -147,7 +147,7 @@ class TypingHandler(object):
# No point sending another notification
defer.returnValue(None)
- yield self._push_update(
+ self._push_update(
member=member,
typing=True,
)
@@ -171,7 +171,7 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id)
- yield self._stopped_typing(member)
+ self._stopped_typing(member)
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -180,7 +180,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member)
- @defer.inlineCallbacks
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
@@ -189,16 +188,15 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
- yield self._push_update(
+ self._push_update(
member=member,
typing=False,
)
- @defer.inlineCallbacks
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- yield self._push_remote(member, typing)
+ preserve_fn(self._push_remote)(member, typing)
self._push_update_local(
member=member,
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
new file mode 100644
index 0000000000..2a49456bfc
--- /dev/null
+++ b/synapse/handlers/user_directory.py
@@ -0,0 +1,641 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.storage.roommember import ProfileInfo
+from synapse.util.metrics import Measure
+from synapse.util.async import sleep
+
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoyHandler(object):
+ """Handles querying of and keeping updated the user_directory.
+
+ N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
+
+ The user directory is filled with users who this server can see are joined to a
+ world_readable or publically joinable room. We keep a database table up to date
+ by streaming changes of the current state and recalculating whether users should
+ be in the directory or not when necessary.
+
+ For each user in the directory we also store a room_id which is public and that the
+ user is joined to. This allows us to ignore history_visibility and join_rules changes
+ for that user in all other public rooms, as we know they'll still be in at least
+ one public room.
+ """
+
+ INITIAL_SLEEP_MS = 50
+ INITIAL_SLEEP_COUNT = 100
+ INITIAL_BATCH_SIZE = 100
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.update_user_directory = hs.config.update_user_directory
+
+ # When start up for the first time we need to populate the user_directory.
+ # This is a set of user_id's we've inserted already
+ self.initially_handled_users = set()
+ self.initially_handled_users_in_public = set()
+
+ self.initially_handled_users_share = set()
+ self.initially_handled_users_share_private_room = set()
+
+ # The current position in the current_state_delta stream
+ self.pos = None
+
+ # Guard to ensure we only process deltas one at a time
+ self._is_processing = False
+
+ if self.update_user_directory:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # We kick this off so that we don't have to wait for a change before
+ # we start populating the user directory
+ self.clock.call_later(0, self.notify_new_event)
+
+ def search_users(self, user_id, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ return self.store.search_user_dir(user_id, search_term, limit)
+
+ @defer.inlineCallbacks
+ def notify_new_event(self):
+ """Called when there may be more deltas to process
+ """
+ if not self.update_user_directory:
+ return
+
+ if self._is_processing:
+ return
+
+ self._is_processing = True
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # If self.pos is None then means we haven't fetched it from DB
+ if self.pos is None:
+ self.pos = yield self.store.get_user_directory_stream_pos()
+
+ # If still None then we need to do the initial fill of directory
+ if self.pos is None:
+ yield self._do_initial_spam()
+ self.pos = yield self.store.get_user_directory_stream_pos()
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "user_dir_delta"):
+ deltas = yield self.store.get_current_state_deltas(self.pos)
+ if not deltas:
+ return
+
+ logger.info("Handling %d state deltas", len(deltas))
+ yield self._handle_deltas(deltas)
+
+ self.pos = deltas[-1]["stream_id"]
+ yield self.store.update_user_directory_stream_pos(self.pos)
+
+ @defer.inlineCallbacks
+ def _do_initial_spam(self):
+ """Populates the user_directory from the current state of the DB, used
+ when synapse first starts with user_directory support
+ """
+ new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
+
+ # Delete any existing entries just in case there are any
+ yield self.store.delete_all_from_user_dir()
+
+ # We process by going through each existing room at a time.
+ room_ids = yield self.store.get_all_rooms()
+
+ logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
+ num_processed_rooms = 1
+
+ for room_id in room_ids:
+ logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
+ yield self._handle_intial_room(room_id)
+ num_processed_rooms += 1
+ yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+
+ logger.info("Processed all rooms.")
+
+ self.initially_handled_users = None
+ self.initially_handled_users_in_public = None
+ self.initially_handled_users_share = None
+ self.initially_handled_users_share_private_room = None
+
+ yield self.store.update_user_directory_stream_pos(new_pos)
+
+ @defer.inlineCallbacks
+ def _handle_intial_room(self, room_id):
+ """Called when we initially fill out user_directory one room at a time
+ """
+ is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
+ if not is_in_room:
+ return
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ user_ids = set(users_with_profile)
+ unhandled_users = user_ids - self.initially_handled_users
+
+ yield self.store.add_profiles_to_user_dir(
+ room_id, {
+ user_id: users_with_profile[user_id] for user_id in unhandled_users
+ }
+ )
+
+ self.initially_handled_users |= unhandled_users
+
+ if is_public:
+ yield self.store.add_users_to_public_room(
+ room_id,
+ user_ids=user_ids - self.initially_handled_users_in_public
+ )
+ self.initially_handled_users_in_public |= user_ids
+
+ # We now go and figure out the new users who share rooms with user entries
+ # We sleep aggressively here as otherwise it can starve resources.
+ # We also batch up inserts/updates, but try to avoid too many at once.
+ to_insert = set()
+ to_update = set()
+ count = 0
+ for user_id in user_ids:
+ if count % self.INITIAL_SLEEP_COUNT == 0:
+ yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+
+ if not self.is_mine_id(user_id):
+ count += 1
+ continue
+
+ if self.store.get_if_app_services_interested_in_user(user_id):
+ count += 1
+ continue
+
+ for other_user_id in user_ids:
+ if user_id == other_user_id:
+ continue
+
+ if count % self.INITIAL_SLEEP_COUNT == 0:
+ yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ count += 1
+
+ user_set = (user_id, other_user_id)
+
+ if user_set in self.initially_handled_users_share_private_room:
+ continue
+
+ if user_set in self.initially_handled_users_share:
+ if is_public:
+ continue
+ to_update.add(user_set)
+ else:
+ to_insert.add(user_set)
+
+ if is_public:
+ self.initially_handled_users_share.add(user_set)
+ else:
+ self.initially_handled_users_share_private_room.add(user_set)
+
+ if len(to_insert) > self.INITIAL_BATCH_SIZE:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+ to_insert.clear()
+
+ if len(to_update) > self.INITIAL_BATCH_SIZE:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+ to_update.clear()
+
+ if to_insert:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+ to_insert.clear()
+
+ if to_update:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+ to_update.clear()
+
+ @defer.inlineCallbacks
+ def _handle_deltas(self, deltas):
+ """Called with the state deltas to process
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ prev_event_id = delta["prev_event_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ # For join rule and visibility changes we need to check if the room
+ # may have become public or not and add/remove the users in said room
+ if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
+ yield self._handle_room_publicity_change(
+ room_id, prev_event_id, event_id, typ,
+ )
+ elif typ == EventTypes.Member:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="membership",
+ public_value=Membership.JOIN,
+ )
+
+ if change is None:
+ # Handle any profile changes
+ yield self._handle_profile_change(
+ state_key, room_id, prev_event_id, event_id,
+ )
+ continue
+
+ if not change:
+ # Need to check if the server left the room entirely, if so
+ # we might need to remove all the users in that room
+ is_in_room = yield self.store.is_host_joined(
+ room_id, self.server_name,
+ )
+ if not is_in_room:
+ logger.info("Server left room: %r", room_id)
+ # Fetch all the users that we marked as being in user
+ # directory due to being in the room and then check if
+ # need to remove those users or not
+ user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
+ for user_id in user_ids:
+ yield self._handle_remove_user(room_id, user_id)
+ return
+ else:
+ logger.debug("Server is still in room: %r", room_id)
+
+ if change: # The user joined
+ event = yield self.store.get_event(event_id, allow_none=True)
+ profile = ProfileInfo(
+ avatar_url=event.content.get("avatar_url"),
+ display_name=event.content.get("displayname"),
+ )
+
+ yield self._handle_new_user(room_id, state_key, profile)
+ else: # The user left
+ yield self._handle_remove_user(room_id, state_key)
+ else:
+ logger.debug("Ignoring irrelevant type: %r", typ)
+
+ @defer.inlineCallbacks
+ def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
+ """Handle a room having potentially changed from/to world_readable/publically
+ joinable.
+
+ Args:
+ room_id (str)
+ prev_event_id (str|None): The previous event before the state change
+ event_id (str|None): The new event after the state change
+ typ (str): Type of the event
+ """
+ logger.debug("Handling change for %s: %s", typ, room_id)
+
+ if typ == EventTypes.RoomHistoryVisibility:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="history_visibility",
+ public_value="world_readable",
+ )
+ elif typ == EventTypes.JoinRules:
+ change = yield self._get_key_change(
+ prev_event_id, event_id,
+ key_name="join_rule",
+ public_value=JoinRules.PUBLIC,
+ )
+ else:
+ raise Exception("Invalid event type")
+ # If change is None, no change. True => become world_readable/public,
+ # False => was world_readable/public
+ if change is None:
+ logger.debug("No change")
+ return
+
+ # There's been a change to or from being world readable.
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ logger.debug("Change: %r, is_public: %r", change, is_public)
+
+ if change and not is_public:
+ # If we became world readable but room isn't currently public then
+ # we ignore the change
+ return
+ elif not change and is_public:
+ # If we stopped being world readable but are still public,
+ # ignore the change
+ return
+
+ if change:
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ for user_id, profile in users_with_profile.iteritems():
+ yield self._handle_new_user(room_id, user_id, profile)
+ else:
+ users = yield self.store.get_users_in_public_due_to_room(room_id)
+ for user_id in users:
+ yield self._handle_remove_user(room_id, user_id)
+
+ @defer.inlineCallbacks
+ def _handle_new_user(self, room_id, user_id, profile):
+ """Called when we might need to add user to directory
+
+ Args:
+ room_id (str): room_id that user joined or started being public that
+ user_id (str)
+ """
+ logger.debug("Adding user to dir, %r", user_id)
+
+ row = yield self.store.get_user_in_directory(user_id)
+ if not row:
+ yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ if is_public:
+ row = yield self.store.get_user_in_public_room(user_id)
+ if not row:
+ yield self.store.add_users_to_public_room(room_id, [user_id])
+ else:
+ logger.debug("Not adding user to public dir, %r", user_id)
+
+ # Now we update users who share rooms with users. We do this by getting
+ # all the current users in the room and seeing which aren't already
+ # marked in the database as sharing with `user_id`
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ to_insert = set()
+ to_update = set()
+
+ is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+
+ # First, if they're our user then we need to update for every user
+ if self.is_mine_id(user_id) and not is_appservice:
+ # Returns a map of other_user_id -> shared_private. We only need
+ # to update mappings if for users that either don't share a room
+ # already (aren't in the map) or, if the room is private, those that
+ # only share a public room.
+ user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
+ user_id
+ )
+
+ for other_user_id in users_with_profile:
+ if user_id == other_user_id:
+ continue
+
+ shared_is_private = user_ids_shared.get(other_user_id)
+ if shared_is_private is True:
+ # We've already marked in the database they share a private room
+ continue
+ elif shared_is_private is False:
+ # They already share a public room, so only update if this is
+ # a private room
+ if not is_public:
+ to_update.add((user_id, other_user_id))
+ elif shared_is_private is None:
+ # This is the first time they both share a room
+ to_insert.add((user_id, other_user_id))
+
+ # Next we need to update for every local user in the room
+ for other_user_id in users_with_profile:
+ if user_id == other_user_id:
+ continue
+
+ is_appservice = self.store.get_if_app_services_interested_in_user(
+ other_user_id
+ )
+ if self.is_mine_id(other_user_id) and not is_appservice:
+ shared_is_private = yield self.store.get_if_users_share_a_room(
+ other_user_id, user_id,
+ )
+ if shared_is_private is True:
+ # We've already marked in the database they share a private room
+ continue
+ elif shared_is_private is False:
+ # They already share a public room, so only update if this is
+ # a private room
+ if not is_public:
+ to_update.add((other_user_id, user_id))
+ elif shared_is_private is None:
+ # This is the first time they both share a room
+ to_insert.add((other_user_id, user_id))
+
+ if to_insert:
+ yield self.store.add_users_who_share_room(
+ room_id, not is_public, to_insert,
+ )
+
+ if to_update:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, to_update,
+ )
+
+ @defer.inlineCallbacks
+ def _handle_remove_user(self, room_id, user_id):
+ """Called when we might need to remove user to directory
+
+ Args:
+ room_id (str): room_id that user left or stopped being public that
+ user_id (str)
+ """
+ logger.debug("Maybe removing user %r", user_id)
+
+ row = yield self.store.get_user_in_directory(user_id)
+ update_user_dir = row and row["room_id"] == room_id
+
+ row = yield self.store.get_user_in_public_room(user_id)
+ update_user_in_public = row and row["room_id"] == room_id
+
+ if (update_user_in_public or update_user_dir):
+ # XXX: Make this faster?
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ for j_room_id in rooms:
+ if (not update_user_in_public and not update_user_dir):
+ break
+
+ is_in_room = yield self.store.is_host_joined(
+ j_room_id, self.server_name,
+ )
+
+ if not is_in_room:
+ continue
+
+ if update_user_dir:
+ update_user_dir = False
+ yield self.store.update_user_in_user_dir(user_id, j_room_id)
+
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ j_room_id
+ )
+
+ if update_user_in_public and is_public:
+ yield self.store.update_user_in_public_user_list(user_id, j_room_id)
+ update_user_in_public = False
+
+ if update_user_dir:
+ yield self.store.remove_from_user_dir(user_id)
+ elif update_user_in_public:
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ # Now handle users_who_share_rooms.
+
+ # Get a list of user tuples that were in the DB due to this room and
+ # users (this includes tuples where the other user matches `user_id`)
+ user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
+ user_id, room_id,
+ )
+
+ for user_id, other_user_id in user_tuples:
+ # For each user tuple get a list of rooms that they still share,
+ # trying to find a private room, and update the entry in the DB
+ rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
+
+ # If they dont share a room anymore, remove the mapping
+ if not rooms:
+ yield self.store.remove_user_who_share_room(
+ user_id, other_user_id,
+ )
+ continue
+
+ found_public_share = None
+ for j_room_id in rooms:
+ is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ j_room_id
+ )
+
+ if is_public:
+ found_public_share = j_room_id
+ else:
+ found_public_share = None
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, [(user_id, other_user_id)],
+ )
+ break
+
+ if found_public_share:
+ yield self.store.update_users_who_share_room(
+ room_id, not is_public, [(user_id, other_user_id)],
+ )
+
+ @defer.inlineCallbacks
+ def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ """Check member event changes for any profile changes and update the
+ database if there are.
+ """
+ if not prev_event_id or not event_id:
+ return
+
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ event = yield self.store.get_event(event_id, allow_none=True)
+
+ if not prev_event or not event:
+ return
+
+ if event.membership != Membership.JOIN:
+ return
+
+ prev_name = prev_event.content.get("displayname")
+ new_name = event.content.get("displayname")
+
+ prev_avatar = prev_event.content.get("avatar_url")
+ new_avatar = event.content.get("avatar_url")
+
+ if prev_name != new_name or prev_avatar != new_avatar:
+ yield self.store.update_profile_in_user_dir(
+ user_id, new_name, new_avatar, room_id,
+ )
+
+ @defer.inlineCallbacks
+ def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ """Given two events check if the `key_name` field in content changed
+ from not matching `public_value` to doing so.
+
+ For example, check if `history_visibility` (`key_name`) changed from
+ `shared` to `world_readable` (`public_value`).
+
+ Returns:
+ None if the field in the events either both match `public_value`
+ or if neither do, i.e. there has been no change.
+ True if it didnt match `public_value` but now does
+ False if it did match `public_value` but now doesn't
+ """
+ prev_event = None
+ event = None
+ if prev_event_id:
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+
+ if event_id:
+ event = yield self.store.get_event(event_id, allow_none=True)
+
+ if not event and not prev_event:
+ logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
+ defer.returnValue(None)
+
+ prev_value = None
+ value = None
+
+ if prev_event:
+ prev_value = prev_event.content.get(key_name)
+
+ if event:
+ value = event.content.get(key_name)
+
+ logger.debug("prev_value: %r -> value: %r", prev_value, value)
+
+ if value == public_value and prev_value != public_value:
+ defer.returnValue(True)
+ elif value != public_value and prev_value == public_value:
+ defer.returnValue(False)
+ else:
+ defer.returnValue(None)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 14715878c5..7ef3d526b1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -412,7 +412,7 @@ def set_cors_headers(request):
)
request.setHeader(
"Access-Control-Allow-Headers",
- "Origin, X-Requested-With, Content-Type, Accept"
+ "Origin, X-Requested-With, Content-Type, Accept, Authorization"
)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 48566187ab..385208b574 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -251,7 +251,8 @@ class Notifier(object):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
preserve_fn(self.appservice_handler.notify_interested_services)(
- room_stream_id)
+ room_stream_id
+ )
if self.federation_sender:
preserve_fn(self.federation_sender.notify_new_events)(
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 3f75d3f921..fe09d50d55 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from .bulk_push_rule_evaluator import evaluator_for_event
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
@@ -24,11 +24,12 @@ import logging
logger = logging.getLogger(__name__)
-class ActionGenerator:
+class ActionGenerator(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user.
@@ -38,16 +39,11 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
- with Measure(self.clock, "evaluator_for_event"):
- bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context
- )
-
with Measure(self.clock, "action_for_event_by_user"):
- actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+ actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
event, context
)
context.push_actions = [
- (uid, actions) for uid, actions in actions_by_user.items()
+ (uid, actions) for uid, actions in actions_by_user.iteritems()
]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index f943ff640f..9a96e6fe8f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,60 +19,83 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients_context
+from synapse.api.constants import EventTypes, Membership
+from synapse.util.caches.descriptors import cached
+from synapse.util.async import Linearizer
+from collections import namedtuple
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
- rules_by_user = yield store.bulk_get_push_rules_for_room(
- event, context
- )
-
- # if this event is an invite event, we may need to run rules for the user
- # who's been invited, otherwise they won't get told they've been invited
- if event.type == 'm.room.member' and event.content['membership'] == 'invite':
- invited_user = event.state_key
- if invited_user and hs.is_mine_id(invited_user):
- has_pusher = yield store.user_has_pusher(invited_user)
- if has_pusher:
- rules_by_user = dict(rules_by_user)
- rules_by_user[invited_user] = yield store.get_push_rules_for_user(
- invited_user
- )
- defer.returnValue(BulkPushRuleEvaluator(
- event.room_id, rules_by_user, store
- ))
+rules_by_room = {}
-class BulkPushRuleEvaluator:
+class BulkPushRuleEvaluator(object):
+ """Calculates the outcome of push rules for an event for all users in the
+ room at once.
"""
- Runs push rules for all users in a room.
- This is faster than running PushRuleEvaluator for each user because it
- fetches all the rules for all the users in one (batched) db query
- rather than doing multiple queries per-user. It currently uses
- the same logic to run the actual rules, but could be optimised further
- (see https://matrix.org/jira/browse/SYN-562)
- """
- def __init__(self, room_id, rules_by_user, store):
- self.room_id = room_id
- self.rules_by_user = rules_by_user
- self.store = store
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def _get_rules_for_event(self, event, context):
+ """This gets the rules for all users in the room at the time of the event,
+ as well as the push rules for the invitee if the event is an invite.
+
+ Returns:
+ dict of user_id -> push_rules
+ """
+ room_id = event.room_id
+ rules_for_room = self._get_rules_for_room(room_id)
+
+ rules_by_user = yield rules_for_room.get_rules(event, context)
+
+ # if this event is an invite event, we may need to run rules for the user
+ # who's been invited, otherwise they won't get told they've been invited
+ if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+ invited = event.state_key
+ if invited and self.hs.is_mine_id(invited):
+ has_pusher = yield self.store.user_has_pusher(invited)
+ if has_pusher:
+ rules_by_user = dict(rules_by_user)
+ rules_by_user[invited] = yield self.store.get_push_rules_for_user(
+ invited
+ )
+
+ defer.returnValue(rules_by_user)
+
+ @cached()
+ def _get_rules_for_room(self, room_id):
+ """Get the current RulesForRoom object for the given room id
+
+ Returns:
+ RulesForRoom
+ """
+ # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+ # before any lookup methods get called on it as otherwise there may be
+ # a race if invalidate_all gets called (which assumes its in the cache)
+ return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
+ """Given an event and context, evaluate the push rules and return
+ the results
+
+ Returns:
+ dict of user_id -> action
+ """
+ rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {}
# None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all
# actually in the room.
- user_tuples = [
- (u, False) for u in self.rules_by_user.keys()
- ]
+ user_tuples = [(u, False) for u in rules_by_user]
filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: context}
@@ -86,7 +109,7 @@ class BulkPushRuleEvaluator:
condition_cache = {}
- for uid, rules in self.rules_by_user.items():
+ for uid, rules in rules_by_user.iteritems():
display_name = None
profile_info = room_members.get(uid)
if profile_info:
@@ -138,3 +161,240 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
return False
return True
+
+
+class RulesForRoom(object):
+ """Caches push rules for users in a room.
+
+ This efficiently handles users joining/leaving the room by not invalidating
+ the entire cache for the room.
+ """
+
+ def __init__(self, hs, room_id, rules_for_room_cache):
+ """
+ Args:
+ hs (HomeServer)
+ room_id (str)
+ rules_for_room_cache(Cache): The cache object that caches these
+ RoomsForUser objects.
+ """
+ self.room_id = room_id
+ self.is_mine_id = hs.is_mine_id
+ self.store = hs.get_datastore()
+
+ self.linearizer = Linearizer(name="rules_for_room")
+
+ self.member_map = {} # event_id -> (user_id, state)
+ self.rules_by_user = {} # user_id -> rules
+
+ # The last state group we updated the caches for. If the state_group of
+ # a new event comes along, we know that we can just return the cached
+ # result.
+ # On invalidation of the rules themselves (if the user changes them),
+ # we invalidate everything and set state_group to `object()`
+ self.state_group = object()
+
+ # A sequence number to keep track of when we're allowed to update the
+ # cache. We bump the sequence number when we invalidate the cache. If
+ # the sequence number changes while we're calculating stuff we should
+ # not update the cache with it.
+ self.sequence = 0
+
+ # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+ # owned by AS's, or remote users, etc. (I.e. users we will never need to
+ # calculate push for)
+ # These never need to be invalidated as we will never set up push for
+ # them.
+ self.uninteresting_user_set = set()
+
+ # We need to be clever on the invalidating caches callbacks, as
+ # otherwise the invalidation callback holds a reference to the object,
+ # potentially causing it to leak.
+ # To get around this we pass a function that on invalidations looks ups
+ # the RoomsForUser entry in the cache, rather than keeping a reference
+ # to self around in the callback.
+ self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
+
+ @defer.inlineCallbacks
+ def get_rules(self, event, context):
+ """Given an event context return the rules for all users who are
+ currently in the room.
+ """
+ state_group = context.state_group
+
+ with (yield self.linearizer.queue(())):
+ if state_group and self.state_group == state_group:
+ logger.debug("Using cached rules for %r", self.room_id)
+ defer.returnValue(self.rules_by_user)
+
+ ret_rules_by_user = {}
+ missing_member_event_ids = {}
+ if state_group and self.state_group == context.prev_group:
+ # If we have a simple delta then we can reuse most of the previous
+ # results.
+ ret_rules_by_user = self.rules_by_user
+ current_state_ids = context.delta_ids
+ else:
+ current_state_ids = context.current_state_ids
+
+ logger.debug(
+ "Looking for member changes in %r %r", state_group, current_state_ids
+ )
+
+ # Loop through to see which member events we've seen and have rules
+ # for and which we need to fetch
+ for key in current_state_ids:
+ typ, user_id = key
+ if typ != EventTypes.Member:
+ continue
+
+ if user_id in self.uninteresting_user_set:
+ continue
+
+ if not self.is_mine_id(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ if self.store.get_if_app_services_interested_in_user(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ event_id = current_state_ids[key]
+
+ res = self.member_map.get(event_id, None)
+ if res:
+ user_id, state = res
+ if state == Membership.JOIN:
+ rules = self.rules_by_user.get(user_id, None)
+ if rules:
+ ret_rules_by_user[user_id] = rules
+ continue
+
+ # If a user has left a room we remove their push rule. If they
+ # joined then we readd it later in _update_rules_with_member_event_ids
+ ret_rules_by_user.pop(user_id, None)
+ missing_member_event_ids[user_id] = event_id
+
+ if missing_member_event_ids:
+ # If we have some memebr events we haven't seen, look them up
+ # and fetch push rules for them if appropriate.
+ logger.debug("Found new member events %r", missing_member_event_ids)
+ yield self._update_rules_with_member_event_ids(
+ ret_rules_by_user, missing_member_event_ids, state_group, event
+ )
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "Returning push rules for %r %r",
+ self.room_id, ret_rules_by_user.keys(),
+ )
+ defer.returnValue(ret_rules_by_user)
+
+ @defer.inlineCallbacks
+ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
+ state_group, event):
+ """Update the partially filled rules_by_user dict by fetching rules for
+ any newly joined users in the `member_event_ids` list.
+
+ Args:
+ ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+ updated with any new rules.
+ member_event_ids (list): List of event ids for membership events that
+ have happened since the last time we filled rules_by_user
+ state_group: The state group we are currently computing push rules
+ for. Used when updating the cache.
+ """
+ sequence = self.sequence
+
+ rows = yield self.store._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids.values(),
+ retcols=('user_id', 'membership', 'event_id'),
+ keyvalues={},
+ batch_size=500,
+ desc="_get_rules_for_member_event_ids",
+ )
+
+ members = {
+ row["event_id"]: (row["user_id"], row["membership"])
+ for row in rows
+ }
+
+ # If the event is a join event then it will be in current state evnts
+ # map but not in the DB, so we have to explicitly insert it.
+ if event.type == EventTypes.Member:
+ for event_id in member_event_ids.itervalues():
+ if event_id == event.event_id:
+ members[event_id] = (event.state_key, event.membership)
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Found members %r: %r", self.room_id, members.values())
+
+ interested_in_user_ids = set(
+ user_id for user_id, membership in members.itervalues()
+ if membership == Membership.JOIN
+ )
+
+ logger.debug("Joined: %r", interested_in_user_ids)
+
+ if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+ interested_in_user_ids,
+ on_invalidate=self.invalidate_all_cb,
+ )
+
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ )
+
+ logger.debug("With pushers: %r", user_ids)
+
+ users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+ self.room_id, on_invalidate=self.invalidate_all_cb,
+ )
+
+ logger.debug("With receipts: %r", users_with_receipts)
+
+ # any users with pushers must be ours: they have pushers
+ for uid in users_with_receipts:
+ if uid in interested_in_user_ids:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.store.bulk_get_push_rules(
+ user_ids, on_invalidate=self.invalidate_all_cb,
+ )
+
+ ret_rules_by_user.update(
+ item for item in rules_by_user.iteritems() if item[0] is not None
+ )
+
+ self.update_cache(sequence, members, ret_rules_by_user, state_group)
+
+ def invalidate_all(self):
+ # Note: Don't hand this function directly to an invalidation callback
+ # as it keeps a reference to self and will stop this instance from being
+ # GC'd if it gets dropped from the rules_to_user cache. Instead use
+ # `self.invalidate_all_cb`
+ logger.debug("Invalidating RulesForRoom for %r", self.room_id)
+ self.sequence += 1
+ self.state_group = object()
+ self.member_map = {}
+ self.rules_by_user = {}
+
+ def update_cache(self, sequence, members, rules_by_user, state_group):
+ if sequence == self.sequence:
+ self.member_map.update(members)
+ self.rules_by_user = rules_by_user
+ self.state_group = state_group
+
+
+class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
+ # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
+ # which namedtuple does for us (i.e. two _CacheContext are the same if
+ # their caches and keys match). This is important in particular to
+ # dedupe when we add callbacks to lru cache nodes, otherwise the number
+ # of callbacks would grow.
+ def __call__(self):
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
+ if rules:
+ rules.invalidate_all()
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c7afd11111..a69dda7b09 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext
-from mailer import Mailer
logger = logging.getLogger(__name__)
@@ -56,8 +55,10 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
- def __init__(self, hs, pusherdict):
+ def __init__(self, hs, pusherdict, mailer):
self.hs = hs
+ self.mailer = mailer
+
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id']
@@ -73,16 +74,6 @@ class EmailPusher(object):
self.processing = False
- if self.hs.config.email_enable_notifs:
- if 'data' in pusherdict and 'brand' in pusherdict['data']:
- app_name = pusherdict['data']['brand']
- else:
- app_name = self.hs.config.email_app_name
-
- self.mailer = Mailer(self.hs, app_name)
- else:
- self.mailer = None
-
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c0f8176e3d..8a5d473108 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -275,7 +275,7 @@ class HttpPusher(object):
if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id
- if 'content' in event:
+ if not self.hs.config.push_redact_content and 'content' in event:
d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index f83aa7625c..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
class Mailer(object):
- def __init__(self, hs, app_name):
+ def __init__(self, hs, app_name, notif_template_html, notif_template_text):
self.hs = hs
+ self.notif_template_html = notif_template_html
+ self.notif_template_text = notif_template_text
+
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
- loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
+
logger.info("Created Mailer for app_name %s" % app_name)
- env = jinja2.Environment(loader=loader)
- env.filters["format_ts"] = format_ts_filter
- env.filters["mxc_to_http"] = self.mxc_to_http_filter
- self.notif_template_html = env.get_template(
- self.hs.config.email_notif_template_html
- )
- self.notif_template_text = env.get_template(
- self.hs.config.email_notif_template_text
- )
@defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address,
@@ -481,28 +475,6 @@ class Mailer(object):
urllib.urlencode(params),
)
- def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- serverAndMediaId = value[6:]
- fragment = None
- if '#' in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
- fragment = "#" + fragment
-
- params = {
- "width": width,
- "height": height,
- "method": resize_method,
- }
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- self.hs.config.public_baseurl,
- serverAndMediaId,
- urllib.urlencode(params),
- fragment or "",
- )
-
def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -543,3 +515,52 @@ def string_ordinal_total(s):
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+ """Load the jinja2 email templates from disk
+
+ Returns:
+ (notif_template_html, notif_template_text)
+ """
+ logger.info("loading jinja2")
+
+ loader = jinja2.FileSystemLoader(config.email_template_dir)
+ env = jinja2.Environment(loader=loader)
+ env.filters["format_ts"] = format_ts_filter
+ env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+ notif_template_html = env.get_template(
+ config.email_notif_template_html
+ )
+ notif_template_text = env.get_template(
+ config.email_notif_template_text
+ )
+
+ return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+ def mxc_to_http_filter(value, width, height, resize_method="crop"):
+ if value[0:6] != "mxc://":
+ return ""
+
+ serverAndMediaId = value[6:]
+ fragment = None
+ if '#' in serverAndMediaId:
+ (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+ fragment = "#" + fragment
+
+ params = {
+ "width": width,
+ "height": height,
+ "method": resize_method,
+ }
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ config.public_baseurl,
+ serverAndMediaId,
+ urllib.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..491f27bded 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
+ from synapse.push.mailer import Mailer, load_jinja2_templates
except:
pass
-def create_pusher(hs, pusherdict):
- logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+ def __init__(self, hs):
+ self.hs = hs
- PUSHER_TYPES = {
- "http": HttpPusher,
- }
+ self.pusher_types = {
+ "http": HttpPusher,
+ }
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
- PUSHER_TYPES["email"] = EmailPusher
- logger.info("defined email pusher type")
+ logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+ if hs.config.email_enable_notifs:
+ self.mailers = {} # app_name -> Mailer
- if pusherdict['kind'] in PUSHER_TYPES:
- logger.info("found pusher")
- return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+ templates = load_jinja2_templates(hs.config)
+ self.notif_template_html, self.notif_template_text = templates
+
+ self.pusher_types["email"] = self._create_email_pusher
+
+ logger.info("defined email pusher type")
+
+ def create_pusher(self, pusherdict):
+ logger.info("trying to create_pusher for %r", pusherdict)
+
+ if pusherdict['kind'] in self.pusher_types:
+ logger.info("found pusher")
+ return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+ def _create_email_pusher(self, _hs, pusherdict):
+ app_name = self._app_name_from_pusherdict(pusherdict)
+ mailer = self.mailers.get(app_name)
+ if not mailer:
+ mailer = Mailer(
+ hs=self.hs,
+ app_name=app_name,
+ notif_template_html=self.notif_template_html,
+ notif_template_text=self.notif_template_text,
+ )
+ self.mailers[app_name] = mailer
+ return EmailPusher(self.hs, pusherdict, mailer)
+
+ def _app_name_from_pusherdict(self, pusherdict):
+ if 'data' in pusherdict and 'brand' in pusherdict['data']:
+ app_name = pusherdict['data']['brand']
+ else:
+ app_name = self.hs.config.email_app_name
+
+ return app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..43cb6e9c01 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
-import pusher
+from .pusher import PusherFactory
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
+ self.pusher_factory = PusherFactory(_hs)
self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
- pusher.create_pusher(self.hs, {
+ self.pusher_factory.create_pusher({
"id": None,
"user_name": user_id,
"kind": kind,
@@ -186,7 +187,7 @@ class PusherPool:
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
try:
- p = pusher.create_pusher(self.hs, pusherdict)
+ p = self.pusher_factory.create_pusher(pusherdict)
except:
logger.exception("Couldn't start a pusher: caught Exception")
continue
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a374f2f1a2..0d3f31a50c 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -16,6 +16,7 @@
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
+from synapse.storage.appservice import _make_exclusive_regex
class SlavedApplicationServiceStore(BaseSlavedStore):
@@ -25,6 +26,7 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
hs.config.server_name,
hs.config.app_service_config_files
)
+ self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
@@ -38,3 +40,6 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__
+ get_if_app_services_interested_in_user = (
+ DataStore.get_if_app_services_interested_in_user.__func__
+ )
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
new file mode 100644
index 0000000000..65250285e8
--- /dev/null
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import Cache
+
+
+class SlavedClientIpStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedClientIpStore, self).__init__(db_conn, hs)
+
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
+ )
+
+ def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
+ now = int(self._clock.time_msec())
+ user_id = user.to_string()
+ key = (user_id, access_token, ip)
+
+ try:
+ last_seen = self.client_ip_last_seen.get(key)
+ except KeyError:
+ last_seen = None
+
+ # Rate-limited inserts
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+ return
+
+ self.hs.get_tcp_replication().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 4d4a435471..7687867aee 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,6 +16,7 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
+from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -45,6 +46,7 @@ class SlavedDeviceStore(BaseSlavedStore):
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
+ count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index fcaf58b93b..94ebbffc1b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore):
get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"]
)
+ get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
+ _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = (
@@ -151,8 +153,7 @@ class SlavedEventStore(BaseSlavedStore):
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
- is_host_joined = DataStore.is_host_joined.__func__
- _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
+ is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 90fb6c1336..6d2513c4e2 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -20,6 +20,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ UserIpCommand,
)
from .protocol import ClientReplicationStreamProtocol
@@ -178,6 +179,12 @@ class ReplicationClientHandler(object):
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
+ def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ """Tell the master that the user made a request.
+ """
+ cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+ self.send_command(cmd)
+
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 84d2a2272a..a009214e43 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -304,6 +304,36 @@ class InvalidateCacheCommand(Command):
return " ".join((self.cache_func, json.dumps(self.keys)))
+class UserIpCommand(Command):
+ """Sent periodically when a worker sees activity from a client.
+
+ Format::
+
+ USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
+ """
+ NAME = "USER_IP"
+
+ def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ self.user_id = user_id
+ self.access_token = access_token
+ self.ip = ip
+ self.user_agent = user_agent
+ self.device_id = device_id
+ self.last_seen = last_seen
+
+ @classmethod
+ def from_line(cls, line):
+ user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
+
+ return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
+
+ def to_line(self):
+ return " ".join((
+ self.user_id, self.access_token, self.ip, self.device_id,
+ str(self.last_seen), self.user_agent,
+ ))
+
+
# Map of command name to command type.
COMMAND_MAP = {
cmd.NAME: cmd
@@ -320,6 +350,7 @@ COMMAND_MAP = {
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
+ UserIpCommand,
)
}
@@ -342,5 +373,6 @@ VALID_CLIENT_COMMANDS = (
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
+ UserIpCommand.NAME,
ErrorCommand.NAME,
)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 9fee2a484b..062272f8dd 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -406,6 +406,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ def on_USER_IP(self, cmd):
+ self.streamer.on_user_ip(
+ cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+ cmd.last_seen,
+ )
+
@defer.inlineCallbacks
def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a streams.
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 8b2c4c3043..3ea3ca5a6f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -35,6 +35,7 @@ user_sync_counter = metrics.register_counter("user_sync")
federation_ack_counter = metrics.register_counter("federation_ack")
remove_pusher_counter = metrics.register_counter("remove_pusher")
invalidate_cache_counter = metrics.register_counter("invalidate_cache")
+user_ip_cache_counter = metrics.register_counter("user_ip_cache")
logger = logging.getLogger(__name__)
@@ -67,6 +68,7 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
# Current connections.
self.connections = []
@@ -99,7 +101,7 @@ class ReplicationStreamer(object):
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
- hs.get_notifier().add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -237,6 +239,15 @@ class ReplicationStreamer(object):
invalidate_cache_counter.inc()
getattr(self.store, cache_func).invalidate(tuple(keys))
+ @measure_func("repl.on_user_ip")
+ def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ """The client saw a user request
+ """
+ user_ip_cache_counter.inc()
+ self.store.insert_client_ip(
+ user_id, access_token, ip, user_agent, device_id, last_seen,
+ )
+
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 369d5f2428..fbafe12cc2 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -112,6 +112,12 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
"data_type", # str
"data", # dict
))
+CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
+ "room_id", # str
+ "type", # str
+ "state_key", # str
+ "event_id", # str, optional
+))
class Stream(object):
@@ -443,6 +449,21 @@ class AccountDataStream(Stream):
defer.returnValue(results)
+class CurrentStateDeltaStream(Stream):
+ """Current state for a room was changed
+ """
+ NAME = "current_state_deltas"
+ ROW_TYPE = CurrentStateDeltaStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_max_current_state_delta_stream_id
+ self.update_function = store.get_all_updated_current_state_deltas
+
+ super(CurrentStateDeltaStream, self).__init__(hs)
+
+
STREAMS_MAP = {
stream.NAME: stream
for stream in (
@@ -460,5 +481,6 @@ STREAMS_MAP = {
FederationStream,
TagAccountDataStream,
AccountDataStream,
+ CurrentStateDeltaStream,
)
}
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index aa8d874f96..3d809d181b 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -51,6 +51,7 @@ from synapse.rest.client.v2_alpha import (
devices,
thirdparty,
sendtodevice,
+ user_directory,
)
from synapse.http.server import JsonResource
@@ -100,3 +101,4 @@ class ClientRestResource(JsonResource):
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
+ user_directory.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 29fcd72375..7d786e8de3 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -15,8 +15,9 @@
from twisted.internet import defer
+from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -157,6 +158,142 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
+class ShutdownRoomRestServlet(ClientV1RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ """
+ PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violatation will be blocked."
+ )
+
+ def __init__(self, hs):
+ super(ShutdownRoomRestServlet, self).__init__(hs)
+ self.store = hs.get_datastore()
+ self.handlers = hs.get_handlers()
+ self.state = hs.get_state_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ content = parse_json_object_from_request(request)
+
+ new_room_user_id = content.get("new_room_user_id")
+ if not new_room_user_id:
+ raise SynapseError(400, "Please provide field `new_room_user_id`")
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info = yield self.handlers.room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {
+ "users_default": -10,
+ },
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ msg_handler = self.handlers.message_handler
+ yield msg_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info("Shutting down room %r", room_id)
+
+ yield self.store.block_room(room_id, requester_user_id)
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ kicked_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ target_requester = create_requester(user_id)
+ yield self.handlers.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False
+ )
+
+ yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
+
+ yield self.handlers.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False
+ )
+
+ kicked_users.append(user_id)
+
+ aliases_for_room = yield self.store.get_aliases_for_room(room_id)
+
+ yield self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ defer.returnValue((200, {
+ "kicked_users": kicked_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ }))
+
+
+class QuarantineMediaInRoom(ClientV1RestServlet):
+ """Quarantines all media in a room so that no one can download it via
+ this server.
+ """
+ PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
+
+ def __init__(self, hs):
+ super(QuarantineMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ num_quarantined = yield self.store.quarantine_media_ids_in_room(
+ room_id, requester.user.to_string(),
+ )
+
+ defer.returnValue((200, {"num_quarantined": num_quarantined}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This need a user have a administrator access in Synapse.
@@ -353,3 +490,5 @@ def register_servlets(hs, http_server):
ResetPasswordRestServlet(hs).register(http_server)
GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
+ ShutdownRoomRestServlet(hs).register(http_server)
+ QuarantineMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 771e127ab9..83e209d18f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -192,6 +192,7 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
+ "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
new file mode 100644
index 0000000000..6e012da4aa
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectorySearchRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/user_directory/search$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(UserDirectorySearchRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ user_id = requester.user.to_string()
+
+ body = parse_json_object_from_request(request)
+
+ limit = body.get("limit", 10)
+ limit = min(limit, 50)
+
+ try:
+ search_term = body["search_term"]
+ except:
+ raise SynapseError(400, "`search_term` is required field")
+
+ results = yield self.user_directory_handler.search_users(
+ user_id, search_term, limit,
+ )
+
+ defer.returnValue((200, results))
+
+
+def register_servlets(hs, http_server):
+ UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6788375e85..6879249c8a 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -66,14 +66,19 @@ class DownloadResource(Resource):
@defer.inlineCallbacks
def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
- if not media_info:
+ if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
- file_path = self.filepaths.local_media_filepath(media_id)
+ if media_info["url_cache"]:
+ # TODO: Check the file still exists, if it doesn't we can redownload
+ # it from the url `media_info["url_cache"]`
+ file_path = self.filepaths.url_cache_filepath(media_id)
+ else:
+ file_path = self.filepaths.local_media_filepath(media_id)
yield respond_with_file(
request, media_type, file_path, media_length,
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 0137458f71..d92b7ff337 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -71,3 +71,21 @@ class MediaFilePaths(object):
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)
+
+ def url_cache_filepath(self, media_id):
+ return os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4], media_id[4:]
+ )
+
+ def url_cache_thumbnail(self, media_id, width, height, content_type,
+ method):
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s-%s" % (
+ width, height, top_level_type, sub_type, method
+ )
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ file_name
+ )
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index caca96c222..0ea1248ce6 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -135,6 +135,8 @@ class MediaRepository(object):
media_info = yield self._download_remote_file(
server_name, media_id
)
+ elif media_info["quarantined_by"]:
+ raise NotFoundError()
else:
self.recently_accessed_remotes.add((server_name, media_id))
yield self.store.update_cached_last_access_time(
@@ -184,6 +186,7 @@ class MediaRepository(object):
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
@@ -323,13 +326,17 @@ class MediaRepository(object):
defer.returnValue(t_path)
@defer.inlineCallbacks
- def _generate_local_thumbnails(self, media_id, media_info):
+ def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
- input_path = self.filepaths.local_media_filepath(media_id)
+ if url_cache:
+ input_path = self.filepaths.url_cache_filepath(media_id)
+ else:
+ input_path = self.filepaths.local_media_filepath(media_id)
+
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -357,9 +364,14 @@ class MediaRepository(object):
for t_width, t_height, t_type in scales:
t_method = "scale"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
+ if url_cache:
+ t_path = self.filepaths.url_cache_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ else:
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
@@ -374,9 +386,14 @@ class MediaRepository(object):
# thumbnail.
continue
t_method = "crop"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
+ if url_cache:
+ t_path = self.filepaths.url_cache_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ else:
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 99760d622f..b81a336c5d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -164,7 +164,7 @@ class PreviewUrlResource(Resource):
if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails(
- media_info['filesystem_id'], media_info
+ media_info['filesystem_id'], media_info, url_cache=True,
)
og = {
@@ -210,7 +210,7 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
- image_info['filesystem_id'], image_info
+ image_info['filesystem_id'], image_info, url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -256,7 +256,7 @@ class PreviewUrlResource(Resource):
# XXX: horrible duplication with base_resource's _download_remote_file()
file_id = random_string(24)
- fname = self.filepaths.local_media_filepath(file_id)
+ fname = self.filepaths.url_cache_filepath(file_id)
self.media_repo._makedirs(fname)
try:
@@ -303,6 +303,7 @@ class PreviewUrlResource(Resource):
upload_name=download_name,
media_length=length,
user_id=user,
+ url_cache=url,
)
except Exception as e:
@@ -434,6 +435,8 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og['og:description'] = summarize_paragraphs(text_nodes)
+ else:
+ og['og:description'] = summarize_paragraphs([og['og:description']])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d8f54adc99..68d56b2b10 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -81,7 +81,7 @@ class ThumbnailResource(Resource):
method, m_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info:
+ if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -101,9 +101,16 @@ class ThumbnailResource(Resource):
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
+ if media_info["url_cache"]:
+ # TODO: Check the file still exists, if it doesn't we can redownload
+ # it from the url `media_info["url_cache"]`
+ file_path = self.filepaths.url_cache_thumbnail(
+ media_id, t_width, t_height, t_type, t_method,
+ )
+ else:
+ file_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method,
+ )
yield respond_with_file(request, t_type, file_path)
else:
@@ -117,7 +124,7 @@ class ThumbnailResource(Resource):
desired_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info:
+ if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -134,9 +141,18 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type, desired_method,
- )
+ if media_info["url_cache"]:
+ # TODO: Check the file still exists, if it doesn't we can redownload
+ # it from the url `media_info["url_cache"]`
+ file_path = self.filepaths.url_cache_thumbnail(
+ media_id, desired_width, desired_height, desired_type,
+ desired_method,
+ )
+ else:
+ file_path = self.filepaths.local_media_thumbnail(
+ media_id, desired_width, desired_height, desired_type,
+ desired_method,
+ )
yield respond_with_file(request, desired_type, file_path)
return
diff --git a/synapse/server.py b/synapse/server.py
index 12754c89ae..a38e5179e0 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -49,9 +49,11 @@ from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler
+from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
+from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.state import StateHandler
@@ -135,6 +137,8 @@ class HomeServer(object):
'macaroon_generator',
'tcp_replication',
'read_marker_handler',
+ 'action_generator',
+ 'user_directory_handler',
]
def __init__(self, hostname, **kwargs):
@@ -299,6 +303,12 @@ class HomeServer(object):
def build_tcp_replication(self):
raise NotImplementedError()
+ def build_action_generator(self):
+ return ActionGenerator(self)
+
+ def build_user_directory_handler(self):
+ return UserDirectoyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/state.py b/synapse/state.py
index 02fee47f39..390799fbd5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
+from synapse.util.caches import CACHE_SIZE_FACTOR
from collections import namedtuple
from frozendict import frozendict
import logging
import hashlib
-import os
logger = logging.getLogger(__name__)
@@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
@@ -170,9 +167,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_state(
- room_id, entry.state_id, entry.state
- )
+ joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@defer.inlineCallbacks
@@ -181,9 +176,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
- joined_hosts = yield self.store.get_joined_hosts(
- room_id, entry.state_id, entry.state
- )
+ joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts)
@defer.inlineCallbacks
@@ -195,12 +188,12 @@ class StateHandler(object):
Returns:
synapse.events.snapshot.EventContext:
"""
- context = EventContext()
if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
+ context = EventContext()
if old_state:
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
@@ -219,6 +212,7 @@ class StateHandler(object):
defer.returnValue(context)
if old_state:
+ context = EventContext()
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
@@ -239,19 +233,13 @@ class StateHandler(object):
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
- if event.is_state():
- entry = yield self.resolve_state_groups(
- event.room_id, [e for e, _ in event.prev_events],
- event_type=event.type,
- state_key=event.state_key,
- )
- else:
- entry = yield self.resolve_state_groups(
- event.room_id, [e for e, _ in event.prev_events],
- )
+ entry = yield self.resolve_state_groups(
+ event.room_id, [e for e, _ in event.prev_events],
+ )
curr_state = entry.state
+ context = EventContext()
context.prev_state_ids = curr_state
if event.is_state():
context.state_group = self.store.get_next_state_group()
@@ -264,10 +252,14 @@ class StateHandler(object):
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
- context.prev_group = entry.prev_group
- context.delta_ids = entry.delta_ids
- if context.delta_ids is not None:
- context.delta_ids = dict(context.delta_ids)
+ if entry.state_group:
+ context.prev_group = entry.state_group
+ context.delta_ids = {
+ key: event.event_id
+ }
+ elif entry.prev_group:
+ context.prev_group = entry.prev_group
+ context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id
else:
if entry.state_group is None:
@@ -284,7 +276,7 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
- def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
+ def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -309,11 +301,13 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
+ prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
- prev_group=name,
- delta_ids={},
+ prev_group=prev_group,
+ delta_ids=delta_ids,
))
with (yield self.resolve_linearizer.queue(group_names)):
@@ -357,20 +351,18 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
- if state_group is None:
- # Worker instances don't have access to this method, but we want
- # to set the state_group on the main instance to increase cache
- # hits.
- if hasattr(self.store, "get_next_state_group"):
- state_group = self.store.get_next_state_group()
+
+ # TODO: We want to create a state group for this set of events, to
+ # increase cache hits, but we need to make sure that it doesn't
+ # end up as a prev_group without being added to the database
prev_group = None
delta_ids = None
- for old_group, old_ids in state_groups_ids.items():
- if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+ for old_group, old_ids in state_groups_ids.iteritems():
+ if not set(new_state) - set(old_ids):
n_delta_ids = {
k: v
- for k, v in new_state.items()
+ for k, v in new_state.iteritems()
if old_ids.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2970df138b..b92472df33 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
from .client_ips import ClientIpStore
+from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
@@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
+ UserDirectoryStore,
):
def __init__(self, db_conn, hs):
@@ -221,11 +223,24 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[]
+ after_callbacks=[],
+ final_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -289,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
- def get_user_ip_and_agents(self, user):
- return self._simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user.to_string()},
- retcols=[
- "access_token", "ip", "user_agent", "last_seen"
- ],
- desc="get_user_ip_and_agents",
- )
-
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 58b73af7d2..6f54036d67 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,6 +16,7 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
@@ -27,10 +28,6 @@ from twisted.internet import defer
import sys
import time
import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__)
@@ -52,13 +49,17 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+ __slots__ = [
+ "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ ]
- def __init__(self, txn, name, database_engine, after_callbacks):
+ def __init__(self, txn, name, database_engine, after_callbacks,
+ final_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
+ object.__setattr__(self, "final_callbacks", final_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -67,6 +68,9 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
+ def call_finally(self, callback, *args, **kwargs):
+ self.final_callbacks.append((callback, args, kwargs))
+
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -217,8 +221,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, logging_context,
- func, *args, **kwargs):
+ def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+ logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -237,7 +241,8 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks
+ txn, name, self.database_engine, after_callbacks,
+ final_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -298,6 +303,7 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
after_callbacks = []
+ final_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
@@ -309,7 +315,7 @@ class SQLBaseStore(object):
current_context.copy_to(context)
return self._new_transaction(
- conn, desc, after_callbacks, current_context,
+ conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
@@ -318,9 +324,13 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
- finally:
+
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
+ finally:
+ for after_callback, after_args, after_kwargs in final_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -425,6 +435,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
+ def _simple_insert_many(self, table, values, desc):
+ return self.runInteraction(
+ desc, self._simple_insert_many_txn, table, values
+ )
+
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
@@ -936,7 +951,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_finally(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..c63935cb07 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
import simplejson as json
from twisted.internet import defer
@@ -26,6 +27,25 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+def _make_exclusive_regex(services_cache):
+ # We precompie a regex constructed from all the regexes that the AS's
+ # have registered for exclusive users.
+ exclusive_user_regexes = [
+ regex.pattern
+ for service in services_cache
+ for regex in service.get_exlusive_user_regexes()
+ ]
+ if exclusive_user_regexes:
+ exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
+ exclusive_user_regex = re.compile(exclusive_user_regex)
+ else:
+ # We handle this case specially otherwise the constructed regex
+ # will always match
+ exclusive_user_regex = None
+
+ return exclusive_user_regex
+
+
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
@@ -35,17 +55,18 @@ class ApplicationServiceStore(SQLBaseStore):
hs.hostname,
hs.config.app_service_config_files
)
+ self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
- """Check if the user is one associated with an app service
+ """Check if the user is one associated with an app service (exclusively)
"""
- for service in self.services_cache:
- if service.is_interested_in_user(user_id):
- return True
- return False
+ if self.exclusive_user_regex:
+ return bool(self.exclusive_user_regex.match(user_id))
+ else:
+ return False
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 747d2df622..fc468ea185 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,11 +15,14 @@
import logging
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import Cache
from . import background_updates
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -33,7 +36,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
- max_entries=5000,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
)
super(ClientIpStore, self).__init__(hs)
@@ -45,7 +48,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
- @defer.inlineCallbacks
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -57,34 +67,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
+ return
self.client_ip_last_seen.prefill(key, now)
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
+ self._batch_row_update[key] = (user_agent, device_id, now)
+
+ def _update_client_ips_batch(self):
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
+ def _update_client_ips_batch_txn(self, txn, to_update):
+ self.database_engine.lock_table(txn, "user_ips")
+
+ for entry in to_update.iteritems():
+ (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+ user_id (str)
+ device_id (str): If None fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -95,6 +119,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -103,23 +128,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+ ret[(user_id, device_id)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
if not where_clauses:
return []
@@ -147,3 +183,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d9936c88bb..bb27fd1f70 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_pokes
+ FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
@@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore):
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # First we DELETE all rows such that only the latest row for each
- # (destination, user_id is left. We do this by selecting first and
- # deleting.
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
sql = """
- SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
- HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND stream_id < ?
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
"""
txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows)
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
- # Mark everything that is left as sent
sql = """
- UPDATE device_lists_outbound_pokes SET sent = ?
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
- txn.execute(sql, (True, destination, stream_id,))
+ txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
@@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore):
)
)
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # of last successful sent so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 9caaf81f2c..79e7c540ad 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -170,3 +170,17 @@ class DirectoryStore(SQLBaseStore):
"room_alias",
desc="get_aliases_for_room",
)
+
+ def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+ def _update_aliases_for_room_txn(txn):
+ sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
+ txn.execute(sql, (new_room_id, creator, old_room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (old_room_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (new_room_id,)
+ )
+ return self.runInteraction(
+ "_update_aliases_for_room_txn", _update_aliases_for_room_txn
+ )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index e00f31da2b..2cebb203c6 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys
],
)
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
@@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
- yield self._simple_delete(
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_device_keys_by_device"
- )
- yield self._simple_delete(
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_one_time_keys_by_device"
+ def delete_e2e_keys_by_device_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
+ return self.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 519059c306..e8133de2fa 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def get_auth_chain(self, event_ids):
- return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+ def get_auth_chain(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
+
+ Returns:
+ list of events
+ """
+ return self.get_auth_chain_ids(
+ event_ids, include_given=include_given,
+ ).addCallback(self._get_events)
+
+ def get_auth_chain_ids(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def get_auth_chain_ids(self, event_ids):
+ Returns:
+ list of event_ids
+ """
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_ids
+ event_ids, include_given
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
- results = set()
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ if include_given:
+ results = set(event_ids)
+ else:
+ results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+ @defer.inlineCallbacks
+ def _background_delete_non_state_event_auth(self, progress, batch_size):
+ def delete_event_auth(txn):
+ target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+ max_stream_id = progress.get("max_stream_id_exclusive")
+
+ if not target_min_stream_id or not max_stream_id:
+ txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ target_min_stream_id = rows[0][0]
+
+ txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ max_stream_id = rows[0][0]
+
+ min_stream_id = max_stream_id - batch_size
+
+ sql = """
+ DELETE FROM event_auth
+ WHERE event_id IN (
+ SELECT event_id FROM events
+ LEFT JOIN state_events USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND state_key IS null
+ )
+ """
+
+ txn.execute(sql, (min_stream_id, max_stream_id,))
+
+ new_progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+ )
+
+ return min_stream_id >= target_min_stream_id
+
+ result = yield self.runInteraction(
+ self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+ defer.returnValue(batch_size)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f29d71589d..7002b3752e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -403,6 +403,11 @@ class EventsStore(SQLBaseStore):
(room_id, ), new_state
)
+ for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
+
@defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
"""Calculates the new forward extremeties for a room given events to
@@ -647,9 +652,10 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities.
"""
- self._update_current_state_txn(txn, current_state_for_room)
-
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+ self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -712,7 +718,7 @@ class EventsStore(SQLBaseStore):
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room):
+ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple
txn.executemany(
@@ -734,6 +740,29 @@ class EventsStore(SQLBaseStore):
],
)
+ state_deltas = {key: None for key in to_delete}
+ state_deltas.update(to_insert)
+
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_delta_stream",
+ values=[
+ {
+ "stream_id": max_stream_order,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": ev_id,
+ "prev_event_id": to_delete.get(key, None),
+ }
+ for key, ev_id in state_deltas.iteritems()
+ ]
+ )
+
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ room_id, max_stream_order,
+ )
+
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
@@ -742,11 +771,7 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in to_delete.iterkeys()
- if ev_type == EventTypes.Member
- )
- members_changed.update(
- state_key for ev_type, state_key in to_insert.iterkeys()
+ state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
@@ -755,6 +780,11 @@ class EventsStore(SQLBaseStore):
txn, self.get_rooms_for_user, (member,)
)
+ for host in set(get_domain_from_id(u) for u in members_changed):
+ self._invalidate_cache_and_stream(
+ txn, self.is_host_joined, (room_id, host)
+ )
+
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
@@ -1119,6 +1149,7 @@ class EventsStore(SQLBaseStore):
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
+ if event.is_state()
],
)
@@ -1418,7 +1449,7 @@ class EventsStore(SQLBaseStore):
]
rows = self._new_transaction(
- conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
)
row_dict = {
@@ -2243,6 +2274,24 @@ class EventsStore(SQLBaseStore):
defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+ def get_max_current_state_delta_stream_id(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index a2ccc66ea7..78b1e30945 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -19,6 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from canonicaljson import encode_canonical_json
import simplejson as json
@@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore):
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
- def_json = json.dumps(user_filter).encode("utf-8")
+ def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
+ "SELECT filter_id FROM user_filters "
+ "WHERE user_id = ? AND filter_json = ?"
+ )
+ txn.execute(sql, (user_localpart, def_json))
+ filter_id_response = txn.fetchone()
+ if filter_id_response is not None:
+ return filter_id_response[0]
+
+ sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 4c0f82353d..82bb61b811 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -30,13 +30,16 @@ class MediaRepositoryStore(SQLBaseStore):
return self._simple_select_one(
"local_media_repository",
{"media_id": media_id},
- ("media_type", "media_length", "upload_name", "created_ts"),
+ (
+ "media_type", "media_length", "upload_name", "created_ts",
+ "quarantined_by", "url_cache",
+ ),
allow_none=True,
desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
- media_length, user_id):
+ media_length, user_id, url_cache=None):
return self._simple_insert(
"local_media_repository",
{
@@ -46,6 +49,7 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
+ "url_cache": url_cache,
},
desc="store_local_media",
)
@@ -138,7 +142,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_origin": origin, "media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
- "filesystem_id",
+ "filesystem_id", "quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6e623843d5..72b670b83b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 41
+SCHEMA_VERSION = 43
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 0a819d32c5..8758b1c0c7 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index efb90c3c91..f42b8014c7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore):
return
# Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
if res:
if isinstance(res, defer.Deferred) and res.called:
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 5d543652bb..23688430b7 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -24,6 +24,7 @@ from .engines import PostgresEngine, Sqlite3Engine
import collections
import logging
import ujson as json
+import re
logger = logging.getLogger(__name__)
@@ -507,3 +508,98 @@ class RoomStore(SQLBaseStore):
))
else:
defer.returnValue(None)
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ yield self._simple_insert(
+ table="blocked_rooms",
+ values={
+ "room_id": room_id,
+ "user_id": user_id,
+ },
+ desc="block_room",
+ )
+ self.is_room_blocked.invalidate((room_id,))
+
+ def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ """For a room loops through all events with media and quarantines
+ the associated media
+ """
+ def _get_media_ids_in_room(txn):
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+
+ total_media_quarantined = 0
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ local_media_mxcs = []
+ remote_media_mxcs = []
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ # Now update all the tables to set the quarantined_by flag
+
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_media_mxcs
+ )
+ )
+
+ total_media_quarantined += len(local_media_mxcs)
+ total_media_quarantined += len(remote_media_mxcs)
+
+ return total_media_quarantined
+
+ return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0829ae5bee..457ca288d0 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore
+from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
@@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context,
)
- def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
- room_id, state_group, state_ids,
+ room_id, state_group, state_entry.state, context=state_entry,
)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@@ -499,42 +501,40 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(users_in_room)
- def is_host_joined(self, room_id, host, state_group, state_ids):
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
+ @cachedInlineCallbacks(max_entries=10000)
+ def is_host_joined(self, room_id, host):
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
- return self._is_host_joined(
- room_id, host, state_group, state_ids
- )
+ sql = """
+ SELECT state_key FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE membership = 'join'
+ AND type = 'm.room.member'
+ AND c.room_id = ?
+ AND state_key LIKE ?
+ LIMIT 1
+ """
- @cachedInlineCallbacks(num_args=3)
- def _is_host_joined(self, room_id, host, state_group, current_state_ids):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
- for (etype, state_key), event_id in current_state_ids.items():
- if etype == EventTypes.Member:
- try:
- if get_domain_from_id(state_key) != host:
- continue
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
+ rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- defer.returnValue(True)
+ if not rows:
+ defer.returnValue(False)
- defer.returnValue(False)
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
- def get_joined_hosts(self, room_id, state_group, state_ids):
+ defer.returnValue(True)
+
+ def get_joined_hosts(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -543,33 +543,20 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_hosts(
- room_id, state_group, state_ids
+ room_id, state_group, state_entry.state, state_entry=state_entry,
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+ # @defer.inlineCallbacks
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- joined_hosts = set()
- for etype, state_key in current_state_ids:
- if etype == EventTypes.Member:
- try:
- host = get_domain_from_id(state_key)
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
-
- if host in joined_hosts:
- continue
-
- event_id = current_state_ids[(etype, state_key)]
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- joined_hosts.add(intern_string(host))
+ cache = self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
defer.returnValue(joined_hosts)
@@ -647,3 +634,75 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)
+
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
+ """
+ if state_entry.state_group == self.state_group:
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ with (yield self.linearizer.queue(())):
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
+ for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry,
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ def __len__(self):
+ return self._len
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql
new file mode 100644
index 0000000000..d28851aff8
--- /dev/null
+++ b/synapse/storage/schema/delta/42/current_state_delta.sql
@@ -0,0 +1,26 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE current_state_delta_stream (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT, -- Is null if the key was removed
+ prev_event_id TEXT -- Is null if the key was added
+);
+
+CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql
new file mode 100644
index 0000000000..9ab8c14fa3
--- /dev/null
+++ b/synapse/storage/schema/delta/42/device_list_last_id.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Table of last stream_id that we sent to destination for user_id. This is
+-- used to fill out the `prev_id` fields of outbound device list updates.
+CREATE TABLE device_lists_outbound_last_success (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_lists_outbound_last_success
+ SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_pokes
+ WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
+ GROUP BY destination, user_id;
+
+CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
+ destination, user_id, stream_id
+);
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql
new file mode 100644
index 0000000000..b8821ac759
--- /dev/null
+++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_auth_state_only', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
new file mode 100644
index 0000000000..ea6a18196d
--- /dev/null
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+
+logger = logging.getLogger(__name__)
+
+
+BOTH_TABLES = """
+CREATE TABLE user_directory_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_directory (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
+
+CREATE TABLE users_in_pubic_room (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL -- A room_id that we know is public
+);
+
+CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
+CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
+"""
+
+
+POSTGRES_TABLE = """
+CREATE TABLE user_directory_search (
+ user_id TEXT NOT NULL,
+ vector tsvector
+);
+
+CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
+CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
+"""
+
+
+SQLITE_TABLE = """
+CREATE VIRTUAL TABLE user_directory_search
+ USING fts4 ( user_id, value );
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(BOTH_TABLES.splitlines()):
+ cur.execute(statement)
+
+ if isinstance(database_engine, PostgresEngine):
+ for statement in get_statements(POSTGRES_TABLE.splitlines()):
+ cur.execute(statement)
+ elif isinstance(database_engine, Sqlite3Engine):
+ for statement in get_statements(SQLITE_TABLE.splitlines()):
+ cur.execute(statement)
+ else:
+ raise Exception("Unrecognized database engine")
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql
new file mode 100644
index 0000000000..0e3cd143ff
--- /dev/null
+++ b/synapse/storage/schema/delta/43/blocked_rooms.sql
@@ -0,0 +1,21 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE blocked_rooms (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL -- Admin who blocked the room
+);
+
+CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql
new file mode 100644
index 0000000000..630907ec4f
--- /dev/null
+++ b/synapse/storage/schema/delta/43/quarantine_media.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
+ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/schema/delta/43/url_cache.sql
new file mode 100644
index 0000000000..45ebe020da
--- /dev/null
+++ b/synapse/storage/schema/delta/43/url_cache.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
new file mode 100644
index 0000000000..4501d90cbb
--- /dev/null
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Table keeping track of who shares a room with who. We only keep track
+-- of this for local users, so `user_id` is local users only (but we do keep track
+-- of which remote users share a room)
+CREATE TABLE users_who_share_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room
+);
+
+
+CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
+CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
+CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
+
+
+-- Make sure that we popualte the table initially
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..5673e4aa96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from collections import namedtuple
import logging
@@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+ def _get_state_group_delta_txn(txn):
+ prev_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self._simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcols=("type", "state_key", "event_id",)
+ )
+
+ return _GetStateGroupDelta(prev_group, {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in delta_ids
+ })
+ return self.runInteraction(
+ "get_state_group_delta",
+ _get_state_group_delta_txn,
+ )
+
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
+ is_in_db = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": context.prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (context.prev_group,)
+ )
+
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
@@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
],
)
+ for event_id, state_group_id in state_groups.iteritems():
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
+
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
@@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=50000)
- def _get_state_group_for_event(self, room_id, event_id):
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
@@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
+
for typ, state_key in types:
+ key = (typ, state_key)
if state_key is None:
type_to_key[typ] = None
- missing_types.add((typ, state_key))
+ missing_types.add(key)
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict_ids:
- missing_types.add((typ, state_key))
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types.add(key)
sentinel = object()
@@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
return True
return False
- got_all = not (missing_types or types is None)
+ got_all = is_all or not missing_types
return {
k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all
@@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict_ids
@@ -653,19 +725,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in group_to_state_dict.iteritems():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {
- (intern_string(etype), intern_string(state_key)): None
- for (etype, state_key) in types
- }
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
+ state_dict = results[group]
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
key=group,
value=state_dict,
full=(types is None),
+ known_absent=types,
)
- # Remove all the entries with None values. The None values were just
- # used for bookkeeping in the cache.
- for group, state_dict in results.iteritems():
- results[group] = {
- key: event_id
- for key, event_id in state_dict.iteritems()
- if event_id
- }
-
defer.returnValue(results)
def get_next_state_group(self):
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
new file mode 100644
index 0000000000..2a4db3f03c
--- /dev/null
+++ b/synapse/storage/user_directory.py
@@ -0,0 +1,743 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id, get_localpart_from_id
+
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryStore(SQLBaseStore):
+ @cachedInlineCallbacks(cache_context=True)
+ def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ """Check if the room is either world_readable or publically joinable
+ """
+ current_state_ids = yield self.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_id:
+ join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ if join_rule_ev:
+ if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
+ defer.returnValue(True)
+
+ hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_id:
+ hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ if hist_vis_ev:
+ if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def add_users_to_public_room(self, room_id, user_ids):
+ """Add user to the list of users in public rooms
+
+ Args:
+ room_id (str): A room_id that all users are in that is world_readable
+ or publically joinable
+ user_ids (list(str)): Users to add
+ """
+ yield self._simple_insert_many(
+ table="users_in_pubic_room",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+ for user_id in user_ids
+ ],
+ desc="add_users_to_public_room"
+ )
+ for user_id in user_ids:
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def add_profiles_to_user_dir(self, room_id, users_with_profile):
+ """Add profiles to the user directory
+
+ Args:
+ room_id (str): A room_id that all users are joined to
+ users_with_profile (dict): Users to add to directory in the form of
+ mapping of user_id -> ProfileInfo
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the loclpart most highly, then display name and finally
+ # server name
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ args = (
+ (
+ user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ profile.display_name,
+ )
+ for user_id, profile in users_with_profile.iteritems()
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT INTO user_directory_search(user_id, value)
+ VALUES (?,?)
+ """
+ args = (
+ (
+ user_id,
+ "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+ )
+ for user_id, p in users_with_profile.iteritems()
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ def _add_profiles_to_user_dir_txn(txn):
+ txn.executemany(sql, args)
+ self._simple_insert_many_txn(
+ txn,
+ table="user_directory",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ "display_name": profile.display_name,
+ "avatar_url": profile.avatar_url,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ ]
+ )
+ for user_id in users_with_profile:
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+
+ return self.runInteraction(
+ "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_user_dir(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_user_dir",
+ )
+ self.get_user_in_directory.invalidate((user_id,))
+
+ def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
+ def _update_profile_in_user_dir_txn(txn):
+ new_entry = self._simple_upsert_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ insertion_values={"room_id": room_id},
+ values={"display_name": display_name, "avatar_url": avatar_url},
+ lock=False, # We're only inserter
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the loclpart most highly, then display name and finally
+ # server name
+ if new_entry:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id, get_localpart_from_id(user_id),
+ get_domain_from_id(user_id), display_name,
+ )
+ )
+ else:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ display_name, user_id,
+ )
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ value = "%s %s" % (user_id, display_name,) if display_name else user_id
+ self._simple_upsert_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ values={"value": value},
+ lock=False, # We're only inserter
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.runInteraction(
+ "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_public_user_list(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_public_user_list",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ )
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_user_in_public_room.invalidate, (user_id,)
+ )
+ return self.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn,
+ )
+
+ @defer.inlineCallbacks
+ def remove_from_user_in_public_room(self, user_id):
+ yield self._simple_delete(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ desc="remove_from_user_in_public_room",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def get_users_in_public_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ return self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_public_due_to_room",
+ )
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ user_ids_dir = yield self._simple_select_onecol(
+ table="user_directory",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_pub = yield self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share = yield self._simple_select_onecol(
+ table="users_who_share_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_dir)
+ user_ids.update(user_ids_pub)
+ user_ids.update(user_ids_share)
+
+ defer.returnValue(user_ids)
+
+ @defer.inlineCallbacks
+ def get_all_rooms(self):
+ """Get all room_ids we've ever known about, in ascending order of "size"
+ """
+ sql = """
+ SELECT room_id FROM current_state_events
+ GROUP BY room_id
+ ORDER BY count(*) ASC
+ """
+ rows = yield self._execute("get_all_rooms", None, sql)
+ defer.returnValue([room_id for room_id, in rows])
+
+ def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
+ """Insert entries into the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ """
+ def _add_users_who_share_room_txn(txn):
+ self._simple_insert_many_txn(
+ txn,
+ table="users_who_share_rooms",
+ values=[
+ {
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ "room_id": room_id,
+ "share_private": share_private,
+ }
+ for user_id, other_user_id in user_id_tuples
+ ],
+ )
+ for user_id, other_user_id in user_id_tuples:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "add_users_who_share_room", _add_users_who_share_room_txn
+ )
+
+ def update_users_who_share_room(self, room_id, share_private, user_id_sets):
+ """Updates entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ """
+ def _update_users_who_share_room_txn(txn):
+ sql = """
+ UPDATE users_who_share_rooms
+ SET room_id = ?, share_private = ?
+ WHERE user_id = ? AND other_user_id = ?
+ """
+ txn.executemany(
+ sql,
+ (
+ (room_id, share_private, uid, oid)
+ for uid, oid in user_id_sets
+ )
+ )
+ for user_id, other_user_id in user_id_sets:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "update_users_who_share_room", _update_users_who_share_room_txn
+ )
+
+ def remove_user_who_share_room(self, user_id, other_user_id):
+ """Deletes entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ """
+ def _remove_user_who_share_room_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ )
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+
+ return self.runInteraction(
+ "remove_user_who_share_room", _remove_user_who_share_room_txn
+ )
+
+ @cached(max_entries=500000)
+ def get_if_users_share_a_room(self, user_id, other_user_id):
+ """Gets if users share a room.
+
+ Args:
+ user_id (str): Must be a local user_id
+ other_user_id (str)
+
+ Returns:
+ bool|None: None if they don't share a room, otherwise whether they
+ share a private room or not.
+ """
+ return self._simple_select_one_onecol(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ retcol="share_private",
+ allow_none=True,
+ desc="get_if_users_share_a_room",
+ )
+
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ def get_users_who_share_room_from_dir(self, user_id):
+ """Returns the set of users who share a room with `user_id`
+
+ Args:
+ user_id(str): Must be a local user
+
+ Returns:
+ dict: user_id -> share_private mapping
+ """
+ rows = yield self._simple_select_list(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("other_user_id", "share_private",),
+ desc="get_users_who_share_room_with_user",
+ )
+
+ defer.returnValue({
+ row["other_user_id"]: row["share_private"]
+ for row in rows
+ })
+
+ def get_users_in_share_dir_with_room_id(self, user_id, room_id):
+ """Get all user tuples that are in the users_who_share_rooms due to the
+ given room_id.
+
+ Returns:
+ [(user_id, other_user_id)]: where one of the two will match the given
+ user_id.
+ """
+ sql = """
+ SELECT user_id, other_user_id FROM users_who_share_rooms
+ WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
+ """
+ return self._execute(
+ "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_in_common_for_users(self, user_id, other_user_id):
+ """Given two user_ids find out the list of rooms they share.
+ """
+ sql = """
+ SELECT room_id FROM (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) AS f1 INNER JOIN (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) f2 USING (room_id)
+ """
+
+ rows = yield self._execute(
+ "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+ )
+
+ defer.returnValue([room_id for room_id, in rows])
+
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_pubic_room")
+ txn.execute("DELETE FROM users_who_share_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+ txn.call_after(self.get_user_in_public_room.invalidate_all)
+ txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
+ txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+ return self.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self._simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id", "display_name", "avatar_url",),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ @cached()
+ def get_user_in_public_room(self, user_id):
+ return self._simple_select_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id",),
+ allow_none=True,
+ desc="get_user_in_public_room",
+ )
+
+ def get_user_directory_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="get_user_directory_stream_pos",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+ # We arbitarily limit to 100 stream_id entries to ensure we don't
+ # select toooo many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+ # We arbitarily limit to 100 entries to ensure we don't
+ # select toooo many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ @defer.inlineCallbacks
+ def search_user_dir(self, user_id, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+
+ # We order by rank and then if they have profile info
+ # The ranking algorithm is hand tweaked for "best" results. Broadly
+ # the idea is we give a higher weight to exact matches.
+ # The array of numbers are the weights for the various part of the
+ # search: (domain, _, display name, localpart)
+ sql = """
+ SELECT d.user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users_in_pubic_room AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ WHERE
+ (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ AND vector @@ to_tsquery('english', ?)
+ ORDER BY
+ (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+ * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (
+ 3 * ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ + ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ )
+ DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ search_query = _parse_query_sqlite(search_term)
+
+ sql = """
+ SELECT d.user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users_in_pubic_room AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ WHERE
+ (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ AND value MATCH ?
+ ORDER BY
+ rank(matchinfo(user_directory_search)) DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (user_id, search_query, limit + 1)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ results = yield self._execute(
+ "search_user_dir", self.cursor_to_dict, sql, *args
+ )
+
+ limited = len(results) > limit
+
+ defer.returnValue({
+ "limited": limited,
+ "results": results,
+ })
+
+
+def _parse_query_sqlite(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+
+ We specifically add both a prefix and non prefix matching term so that
+ exact matches get ranked higher.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+
+
+def _parse_query_postgres(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+ exact = " & ".join("%s" % (result,) for result in results)
+ prefix = " & ".join("%s:*" % (result,) for result in results)
+
+ return both, exact, prefix
diff --git a/synapse/types.py b/synapse/types.py
index 445bdcb4d7..111948540d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -62,6 +62,13 @@ def get_domain_from_id(string):
return string[idx + 1:]
+def get_localpart_from_id(string):
+ idx = string.find(":")
+ if idx == -1:
+ raise SynapseError(400, "Invalid ID: %r" % (string,))
+ return string[1:idx]
+
+
class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain"))
):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4a83c46d98..4adae96681 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
import synapse.metrics
import os
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 48dcbafeef..af65bfe7b8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -16,6 +16,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
@@ -25,7 +26,6 @@ from . import register_cache
from twisted.internet import defer
from collections import namedtuple
-import os
import functools
import inspect
import threading
@@ -37,9 +37,6 @@ logger = logging.getLogger(__name__)
_CacheSentinel = object()
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
@@ -404,6 +401,7 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.invalidate_all = cache.invalidate_all
wrapped.cache = cache
+ wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped
@@ -451,8 +449,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
def __get__(self, obj, objtype=None):
-
- cache = getattr(obj, self.cached_method_name).cache
+ cached_method = getattr(obj, self.cached_method_name)
+ cache = cached_method.cache
+ num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
@@ -469,12 +468,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
results = {}
cached_defers = {}
missing = []
- for arg in list_args:
+
+ # If the cache takes a single arg then that is used as the key,
+ # otherwise a tuple is used.
+ if num_args == 1:
+ def cache_get(arg):
+ return cache.get(arg, callback=invalidate_callback)
+ else:
key = list(keyargs)
- key[self.list_pos] = arg
+ def cache_get(arg):
+ key[self.list_pos] = arg
+ return cache.get(tuple(key), callback=invalidate_callback)
+
+ for arg in list_args:
try:
- res = cache.get(tuple(key), callback=invalidate_callback)
+ res = cache_get(arg)
+
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -505,17 +515,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
observer = ObservableDeferred(observer)
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
+ if num_args == 1:
+ cache.set(
+ arg, observer,
+ callback=invalidate_callback
+ )
+
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, arg)
+ else:
+ key = list(keyargs)
+ key[self.list_pos] = arg
+ cache.set(
+ tuple(key), observer,
+ callback=invalidate_callback
+ )
+
+ def invalidate(f, key):
+ cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..d4105822b3 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
logger = logging.getLogger(__name__)
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+ """Returned when getting an entry from the cache
+
+ Attributes:
+ full (bool): Whether the cache has the full or dict or just some keys.
+ If not full then not all requested keys will necessarily be present
+ in `value`
+ known_absent (set): Keys that were looked up in the dict and were not
+ there.
+ value (dict): The full or partial dict value
+ """
def __len__(self):
return len(self.value)
@@ -58,21 +68,31 @@ class DictionaryCache(object):
)
def get(self, key, dict_keys=None):
+ """Fetch an entry out of the cache
+
+ Args:
+ key
+ dict_key(list): If given a set of keys then return only those keys
+ that exist in the cache.
+
+ Returns:
+ DictionaryEntry
+ """
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
self.metrics.inc_hits()
if dict_keys is None:
- return DictionaryEntry(entry.full, dict(entry.value))
+ return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
else:
- return DictionaryEntry(entry.full, {
+ return DictionaryEntry(entry.full, entry.known_absent, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
self.metrics.inc_misses()
- return DictionaryEntry(False, {})
+ return DictionaryEntry(False, set(), {})
def invalidate(self, key):
self.check_thread()
@@ -87,19 +107,34 @@ class DictionaryCache(object):
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, full=False):
+ def update(self, sequence, key, value, full=False, known_absent=None):
+ """Updates the entry in the cache
+
+ Args:
+ sequence
+ key
+ value (dict): The value to update the cache with.
+ full (bool): Whether the given value is the full dict, or just a
+ partial subset there of. If not full then any existing entries
+ for the key will be updated.
+ known_absent (set): Set of keys that we know don't exist in the full
+ dict.
+ """
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
+ if known_absent is None:
+ known_absent = set()
if full:
- self._insert(key, value)
+ self._insert(key, value, known_absent)
else:
- self._update_or_insert(key, value)
+ self._update_or_insert(key, value, known_absent)
- def _update_or_insert(self, key, value):
- entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+ def _update_or_insert(self, key, value, known_absent):
+ entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
+ entry.known_absent.update(known_absent)
- def _insert(self, key, value):
- self.cache[key] = DictionaryEntry(True, value)
+ def _insert(self, key, value, known_absent):
+ self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index cbdde34a57..6ad53a6390 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -94,6 +94,9 @@ class ExpiringCache(object):
return entry.value
+ def __contains__(self, key):
+ return key in self._cache
+
def get(self, key, default=None):
try:
return self[key]
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 70fe00ce0b..941d873ab8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,20 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import register_cache
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
from blist import sorteddict
import logging
-import os
logger = logging.getLogger(__name__)
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
@@ -89,6 +85,21 @@ class StreamChangeCache(object):
return result
+ def has_any_entity_changed(self, stream_pos):
+ """Returns if any entity has changed
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ self.metrics.inc_hits()
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ return i < len(keys)
+ else:
+ self.metrics.inc_misses()
+ return True
+
def get_all_entities_changed(self, stream_pos):
"""Returns all entites that have had new things since the given
position. If the position is too old it will return None.
|