diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 6c13ada5df..6eff83e5f8 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
+import ujson as json
+
class Filtering(object):
@@ -149,6 +151,9 @@ class FilterCollection(object):
"include_leave", False
)
+ def __repr__(self):
+ return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
+
def get_filter_json(self):
return self._filter_json
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6928d9d3e4..e5066c48ef 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -50,16 +50,14 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer
from twisted.application import service
-from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
-from synapse.http.server import JsonResource, RootRedirect
+from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
@@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.federation.transport.server import TransportLayerServer
from synapse import events
@@ -95,80 +94,37 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
-class SynapseHomeServer(HomeServer):
-
- def build_http_client(self):
- return MatrixFederationHttpClient(self)
-
- def build_client_resource(self):
- return ClientRestResource(self)
-
- def build_resource_for_federation(self):
- return JsonResource(self)
-
- def build_resource_for_web_client(self):
- webclient_path = self.get_config().web_client_location
- if not webclient_path:
- try:
- import syweb
- except ImportError:
- quit_with_error(
- "Could not find a webclient.\n\n"
- "Please either install the matrix-angular-sdk or configure\n"
- "the location of the source to serve via the configuration\n"
- "option `web_client_location`\n\n"
- "To install the `matrix-angular-sdk` via pip, run:\n\n"
- " pip install '%(dep)s'\n"
- "\n"
- "You can also disable hosting of the webclient via the\n"
- "configuration option `web_client`\n"
- % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
- )
- syweb_path = os.path.dirname(syweb.__file__)
- webclient_path = os.path.join(syweb_path, "webclient")
- # GZip is disabled here due to
- # https://twistedmatrix.com/trac/ticket/7678
- # (It can stay enabled for the API resources: they call
- # write() with the whole body and then finish() straight
- # after and so do not trigger the bug.
- # GzipFile was removed in commit 184ba09
- # return GzipFile(webclient_path) # TODO configurable?
- return File(webclient_path) # TODO configurable?
-
- def build_resource_for_static_content(self):
- # This is old and should go away: not going to bother adding gzip
- return File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- )
-
- def build_resource_for_content_repo(self):
- return ContentRepoResource(
- self, self.config.uploads_path, self.auth, self.content_addr
- )
-
- def build_resource_for_media_repository(self):
- return MediaRepositoryResource(self)
-
- def build_resource_for_server_key(self):
- return LocalKey(self)
-
- def build_resource_for_server_key_v2(self):
- return KeyApiV2Resource(self)
-
- def build_resource_for_metrics(self):
- if self.get_config().enable_metrics:
- return MetricsResource(self)
- else:
- return None
-
- def build_db_pool(self):
- name = self.db_config["name"]
+def build_resource_for_web_client(hs):
+ webclient_path = hs.get_config().web_client_location
+ if not webclient_path:
+ try:
+ import syweb
+ except ImportError:
+ quit_with_error(
+ "Could not find a webclient.\n\n"
+ "Please either install the matrix-angular-sdk or configure\n"
+ "the location of the source to serve via the configuration\n"
+ "option `web_client_location`\n\n"
+ "To install the `matrix-angular-sdk` via pip, run:\n\n"
+ " pip install '%(dep)s'\n"
+ "\n"
+ "You can also disable hosting of the webclient via the\n"
+ "configuration option `web_client`\n"
+ % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
+ )
+ syweb_path = os.path.dirname(syweb.__file__)
+ webclient_path = os.path.join(syweb_path, "webclient")
+ # GZip is disabled here due to
+ # https://twistedmatrix.com/trac/ticket/7678
+ # (It can stay enabled for the API resources: they call
+ # write() with the whole body and then finish() straight
+ # after and so do not trigger the bug.
+ # GzipFile was removed in commit 184ba09
+ # return GzipFile(webclient_path) # TODO configurable?
+ return File(webclient_path) # TODO configurable?
- return adbapi.ConnectionPool(
- name,
- **self.db_config.get("args", {})
- )
+class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
@@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls:
return
- metrics_resource = self.get_resource_for_metrics()
-
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
- client_resource = self.get_client_resource()
+ client_resource = ClientRestResource(self)
if res["compress"]:
client_resource = gz_wrap(client_resource)
@@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer):
if name == "federation":
resources.update({
- FEDERATION_PREFIX: self.get_resource_for_federation(),
+ FEDERATION_PREFIX: TransportLayerServer(self),
})
if name in ["static", "client"]:
resources.update({
- STATIC_PREFIX: self.get_resource_for_static_content(),
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ ),
})
if name in ["media", "federation", "client"]:
resources.update({
- MEDIA_PREFIX: self.get_resource_for_media_repository(),
- CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
+ MEDIA_PREFIX: MediaRepositoryResource(self),
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path, self.auth, self.content_addr
+ ),
})
if name in ["keys", "federation"]:
resources.update({
- SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
- SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
+ SERVER_KEY_PREFIX: LocalKey(self),
+ SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
if name == "webclient":
- resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
+ resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
- if name == "metrics" and metrics_resource:
- resources[METRICS_PREFIX] = metrics_resource
+ if name == "metrics" and self.get_config().enable_metrics:
+ resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources)
if tls:
@@ -296,6 +254,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
+ def get_db_conn(self):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
def quit_with_error(error_string):
message_lines = error_string.split("\n")
@@ -432,13 +402,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
- db_conn = database_engine.module.connect(
- **{
- k: v for k, v in config.database_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
-
+ db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
@@ -453,13 +417,17 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
+ hs.setup()
hs.start_listening()
- hs.get_pusherpool().start()
- hs.get_state_handler().start_caching()
- hs.get_datastore().start_profiling()
- hs.get_datastore().start_doing_background_updates()
- hs.get_replication_layer().start_get_pdu_cache()
+ def start():
+ hs.get_pusherpool().start()
+ hs.get_state_handler().start_caching()
+ hs.get_datastore().start_profiling()
+ hs.get_datastore().start_doing_background_updates()
+ hs.get_replication_layer().start_get_pdu_cache()
+
+ reactor.callWhenRunning(start)
return hs
@@ -675,7 +643,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B.
Args:
- resource (Resource): The *parent* Resource
+ resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 0bfb79d09f..979fdf2431 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -17,15 +17,10 @@
"""
from .replication import ReplicationLayer
-from .transport import TransportLayer
+from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver):
- transport = TransportLayer(
- homeserver,
- homeserver.hostname,
- server=homeserver.get_resource_for_federation(),
- client=homeserver.get_http_client()
- )
+ transport = TransportLayerClient(homeserver)
return ReplicationLayer(homeserver, transport)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 6e0be8ef15..3e062a5eab 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
- self.transport_layer.register_received_handler(self)
- self.transport_layer.register_request_handler(self)
self.federation_client = self
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 155a7d5870..d9fcc520a0 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
-
-from .server import TransportLayerServer
-from .client import TransportLayerClient
-
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-
-class TransportLayer(TransportLayerServer, TransportLayerClient):
- """This is a basic implementation of the transport layer that translates
- transactions and other requests to/from HTTP.
-
- Attributes:
- server_name (str): Local home server host
-
- server (synapse.http.server.HttpServer): the http server to
- register listeners on
-
- client (synapse.http.client.HttpClient): the http client used to
- send requests
-
- request_handler (TransportRequestHandler): The handler to fire when we
- receive requests for data.
-
- received_handler (TransportReceivedHandler): The handler to fire when
- we receive data.
- """
-
- def __init__(self, homeserver, server_name, server, client):
- """
- Args:
- server_name (str): Local home server host
- server (synapse.protocol.http.HttpServer): the http server to
- register listeners on
- client (synapse.protocol.http.HttpClient): the http client used to
- send requests
- """
- self.keyring = homeserver.get_keyring()
- self.clock = homeserver.get_clock()
- self.server_name = server_name
- self.server = server
- self.client = client
- self.request_handler = None
- self.received_handler = None
-
- self.ratelimiter = FederationRateLimiter(
- self.clock,
- window_size=homeserver.config.federation_rc_window_size,
- sleep_limit=homeserver.config.federation_rc_sleep_limit,
- sleep_msec=homeserver.config.federation_rc_sleep_delay,
- reject_limit=homeserver.config.federation_rc_reject_limit,
- concurrent_requests=homeserver.config.federation_rc_concurrent,
- )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 949d01dea8..2b5d40ea7f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers"""
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.client = hs.get_http_client()
+
@log_function
def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8dca0a7f6b..65e054f7dd 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
-from synapse.util.logutils import log_function
+from synapse.http.server import JsonResource
+from synapse.util.ratelimitutils import FederationRateLimiter
import functools
import logging
@@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__)
-class TransportLayerServer(object):
+class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
+ def __init__(self, hs):
+ self.hs = hs
+ self.clock = hs.get_clock()
+
+ super(TransportLayerServer, self).__init__(hs)
+
+ self.authenticator = Authenticator(hs)
+ self.ratelimiter = FederationRateLimiter(
+ self.clock,
+ window_size=hs.config.federation_rc_window_size,
+ sleep_limit=hs.config.federation_rc_sleep_limit,
+ sleep_msec=hs.config.federation_rc_sleep_delay,
+ reject_limit=hs.config.federation_rc_reject_limit,
+ concurrent_requests=hs.config.federation_rc_concurrent,
+ )
+
+ self.register_servlets()
+
+ def register_servlets(self):
+ register_servlets(
+ self.hs,
+ resource=self,
+ ratelimiter=self.ratelimiter,
+ authenticator=self.authenticator,
+ )
+
+
+class Authenticator(object):
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.server_name = hs.hostname
+
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request):
@@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
- @log_function
- def register_received_handler(self, handler):
- """ Register a handler that will be fired when we receive data.
-
- Args:
- handler (TransportReceivedHandler)
- """
- FederationSendServlet(
- handler,
- authenticator=self,
- ratelimiter=self.ratelimiter,
- server_name=self.server_name,
- ).register(self.server)
-
- @log_function
- def register_request_handler(self, handler):
- """ Register a handler that will be fired when we get asked for data.
-
- Args:
- handler (TransportRequestHandler)
- """
- for servletclass in SERVLET_CLASSES:
- servletclass(
- handler,
- authenticator=self,
- ratelimiter=self.ratelimiter,
- ).register(self.server)
-
class BaseFederationServlet(object):
- def __init__(self, handler, authenticator, ratelimiter):
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
@@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs):
- super(FederationSendServlet, self).__init__(handler, **kwargs)
+ super(FederationSendServlet, self).__init__(
+ handler, server_name=server_name, **kwargs
+ )
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = (
+ FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
@@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
)
+
+
+def register_servlets(hs, resource, authenticator, ratelimiter):
+ for servletclass in SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_replication_layer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6c19d6ae8c..2ce1e9d6c7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1186,7 +1186,13 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(e, auth_events=auth_for_e)
- except AuthError as err:
+ except SynapseError as err:
+ # we may get SynapseErrors here as well as AuthErrors. For
+ # instance, there are a couple of (ancient) events in some
+ # rooms whose senders do not have the correct sigil; these
+ # cause SynapseErrors in auth.check. We don't want to give up
+ # the attempt to federate altogether in such cases.
+
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ff800f8af1..82c8cb5f0c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError, AuthError, Codes
+from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -105,8 +105,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
- if room_token.topological is None:
- raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
@@ -117,27 +115,31 @@ class MessageHandler(BaseHandler):
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
+
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
+
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < room_token.topological:
- source_config.from_key = str(leave_token)
-
- if source_config.direction == "f":
- if source_config.to_key is None:
- source_config.to_key = str(leave_token)
- else:
- to_token = RoomStreamToken.parse(source_config.to_key)
- if leave_token.topological < to_token.topological:
- source_config.to_key = str(leave_token)
-
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, room_token.topological
- )
events, next_key = yield data_source.get_pagination_rows(
requester.user, source_config, room_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 328c049b03..075566417f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
)
-class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
+class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
@@ -298,46 +298,19 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token
)
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, ephemeral_by_room
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id, sync_config,
+ now_token=now_token,
+ since_token=timeline_since_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ all_ephemeral_by_room=ephemeral_by_room,
+ batch=batch,
+ full_state=True,
)
- unread_notifications = {}
- if notifs is not None:
- unread_notifications["notification_count"] = len(notifs)
- unread_notifications["highlight_count"] = len([
- 1 for notif in notifs if _action_has_highlight(notif["actions"])
- ])
-
- current_state = yield self.get_state_at(room_id, now_token)
-
- current_state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- current_state.values()
- )
- }
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
- )
-
- ephemeral = sync_config.filter_collection.filter_room_ephemeral(
- ephemeral_by_room.get(room_id, [])
- )
-
- defer.returnValue(JoinedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=current_state,
- ephemeral=ephemeral,
- account_data=account_data,
- unread_notifications=unread_notifications,
- ))
+ defer.returnValue(room_sync)
def account_data_for_user(self, account_data):
account_data_events = []
@@ -429,44 +402,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room))
- @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state
Returns:
- A Deferred JoinedSyncResult.
+ A Deferred ArchivedSyncResult.
"""
- batch = yield self.load_filtered_recents(
- room_id, sync_config, leave_token, since_token=timeline_since_token
+ return self.incremental_sync_for_archived_room(
+ sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
+ account_data_by_room, full_state=True, leave_token=leave_token,
)
- leave_state = yield self.store.get_state_for_event(leave_event_id)
-
- leave_state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- leave_state.values()
- )
- }
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
- )
-
- defer.returnValue(ArchivedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=leave_state,
- account_data=account_data,
- ))
-
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
@@ -512,154 +461,127 @@ class SyncHandler(BaseHandler):
sync_config.user
)
- timeline_limit = sync_config.filter_collection.timeline_limit()
+ user_id = sync_config.user.to_string()
- room_events, _ = yield self.store.get_room_events_stream(
- sync_config.user.to_string(),
- from_key=since_token.room_key,
- to_key=now_token.room_key,
- limit=timeline_limit + 1,
- )
+ timeline_limit = sync_config.filter_collection.timeline_limit()
tags_by_room = yield self.store.get_updated_tags(
- sync_config.user.to_string(),
+ user_id,
since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
- sync_config.user.to_string(),
+ user_id,
since_token.account_data_key,
)
)
- joined = []
- archived = []
- if len(room_events) <= timeline_limit:
- # There is no gap in any of the rooms. Therefore we can just
- # partition the new events by room and return them.
- logger.debug("Got %i events for incremental sync - not limited",
- len(room_events))
-
- invite_events = []
- leave_events = []
- events_by_room_id = {}
- for event in room_events:
- events_by_room_id.setdefault(event.room_id, []).append(event)
- if event.room_id not in joined_room_ids:
- if (event.type == EventTypes.Member
- and event.state_key == sync_config.user.to_string()):
- if event.membership == Membership.INVITE:
- invite_events.append(event)
- elif event.membership in (Membership.LEAVE, Membership.BAN):
- leave_events.append(event)
-
- for room_id in joined_room_ids:
- recents = events_by_room_id.get(room_id, [])
- logger.debug("Events for room %s: %r", room_id, recents)
- state = {
- (event.type, event.state_key): event
- for event in recents if event.is_state()}
- limited = False
-
- if recents:
- prev_batch = now_token.copy_and_replace(
- "room_key", recents[0].internal_metadata.before
- )
- else:
- prev_batch = now_token
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- logger.debug("User has just joined %s: needs full state",
- room_id)
- state = yield self.get_state_at(room_id, now_token)
- # the timeline is inherently limited if we've just joined
- limited = True
-
- recents = sync_config.filter_collection.filter_room_timeline(recents)
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state.values()
- )
- }
-
- acc_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
+ # Get a list of membership change events that have happened.
+ rooms_changed = yield self.store.get_room_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key
+ )
- acc_data = sync_config.filter_collection.filter_room_account_data(
- acc_data
- )
+ mem_change_events_by_room_id = {}
+ for event in rooms_changed:
+ mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
- ephemeral = sync_config.filter_collection.filter_room_ephemeral(
- ephemeral_by_room.get(room_id, [])
- )
+ newly_joined_rooms = []
+ archived = []
+ invited = []
+ for room_id, events in mem_change_events_by_room_id.items():
+ non_joins = [e for e in events if e.membership != Membership.JOIN]
+ has_join = len(non_joins) != len(events)
+
+ # We want to figure out if we joined the room at some point since
+ # the last sync (even if we have since left). This is to make sure
+ # we do send down the room, and with full state, where necessary
+ if room_id in joined_room_ids or has_join:
+ old_state = yield self.get_state_at(room_id, since_token)
+ old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+ if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
+ newly_joined_rooms.append(room_id)
+
+ if room_id in joined_room_ids:
+ continue
+
+ if not non_joins:
+ continue
- room_sync = JoinedSyncResult(
- room_id=room_id,
- timeline=TimelineBatch(
- events=recents,
- prev_batch=prev_batch,
- limited=limited,
- ),
- state=state,
- ephemeral=ephemeral,
- account_data=acc_data,
- unread_notifications={},
+ # Only bother if we're still currently invited
+ should_invite = non_joins[-1].membership == Membership.INVITE
+ if should_invite:
+ room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if room_sync:
+ invited.append(room_sync)
+
+ # Always include leave/ban events. Just take the last one.
+ # TODO: How do we handle ban -> leave in same batch?
+ leave_events = [
+ e for e in non_joins
+ if e.membership in (Membership.LEAVE, Membership.BAN)
+ ]
+
+ if leave_events:
+ leave_event = leave_events[-1]
+ room_sync = yield self.incremental_sync_for_archived_room(
+ sync_config, room_id, leave_event.event_id, since_token,
+ tags_by_room, account_data_by_room,
+ full_state=room_id in newly_joined_rooms
)
- logger.debug("Result for room %s: %r", room_id, room_sync)
-
if room_sync:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, all_ephemeral_by_room
- )
+ archived.append(room_sync)
- if notifs is not None:
- notif_dict = room_sync.unread_notifications
- notif_dict["notification_count"] = len(notifs)
- notif_dict["highlight_count"] = len([
- 1 for notif in notifs
- if _action_has_highlight(notif["actions"])
- ])
+ # Get all events for rooms we're currently joined to.
+ room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_ids=joined_room_ids,
+ from_key=since_token.room_key,
+ to_key=now_token.room_key,
+ limit=timeline_limit + 1,
+ )
- joined.append(room_sync)
+ joined = []
+ # We loop through all room ids, even if there are no new events, in case
+ # there are non room events taht we need to notify about.
+ for room_id in joined_room_ids:
+ room_entry = room_to_events.get(room_id, None)
- else:
- logger.debug("Got %i events for incremental sync - hit limit",
- len(room_events))
+ if room_entry:
+ events, start_key = room_entry
- invite_events = yield self.store.get_invites_for_user(
- sync_config.user.to_string()
- )
+ prev_batch_token = now_token.copy_and_replace("room_key", start_key)
- leave_events = yield self.store.get_leave_and_ban_events_for_user(
- sync_config.user.to_string()
- )
+ newly_joined_room = room_id in newly_joined_rooms
+ full_state = newly_joined_room
- for room_id in joined_room_ids:
- room_sync = yield self.incremental_sync_with_gap_for_room(
- room_id, sync_config, since_token, now_token,
- ephemeral_by_room, tags_by_room, account_data_by_room,
- all_ephemeral_by_room=all_ephemeral_by_room,
+ batch = yield self.load_filtered_recents(
+ room_id, sync_config, prev_batch_token,
+ since_token=since_token,
+ recents=events,
+ newly_joined_room=newly_joined_room,
)
- if room_sync:
- joined.append(room_sync)
+ else:
+ batch = TimelineBatch(
+ events=[],
+ prev_batch=since_token,
+ limited=False,
+ )
+ full_state = False
- for leave_event in leave_events:
- room_sync = yield self.incremental_sync_for_archived_room(
- sync_config, leave_event, since_token, tags_by_room,
- account_data_by_room
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id=room_id,
+ sync_config=sync_config,
+ since_token=since_token,
+ now_token=now_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ all_ephemeral_by_room=all_ephemeral_by_room,
+ batch=batch,
+ full_state=full_state,
)
if room_sync:
- archived.append(room_sync)
-
- invited = [
- InvitedSyncResult(room_id=event.room_id, invite=event)
- for event in invite_events
- ]
+ joined.append(room_sync)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@@ -680,28 +602,40 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None):
+ since_token=None, recents=None, newly_joined_room=False):
"""
:returns a Deferred TimelineBatch
"""
- limited = True
- recents = []
filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit()
- load_limit = max(timeline_limit * filtering_factor, 100)
- max_repeat = 3 # Only try a few times per room, otherwise
+ load_limit = max(timeline_limit * filtering_factor, 10)
+ max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key
end_key = room_key
+ limited = recents is None or newly_joined_room or timeline_limit < len(recents)
+
+ if recents is not None:
+ recents = sync_config.filter_collection.filter_room_timeline(recents)
+ recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(),
+ recents,
+ is_peeking=sync_config.is_guest,
+ )
+ else:
+ recents = []
+
+ since_key = None
+ if since_token and not newly_joined_room:
+ since_key = since_token.room_key
+
while limited and len(recents) < timeline_limit and max_repeat:
- events, keys = yield self.store.get_recent_events_for_room(
+ events, end_key = yield self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
- from_token=since_token.room_key if since_token else None,
- end_token=end_key,
+ from_key=since_key,
+ to_key=end_key,
)
- room_key, _ = keys
- end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
@@ -710,8 +644,10 @@ class SyncHandler(BaseHandler):
)
loaded_recents.extend(recents)
recents = loaded_recents
+
if len(events) <= load_limit:
limited = False
+ break
max_repeat -= 1
if len(recents) > timeline_limit:
@@ -724,7 +660,9 @@ class SyncHandler(BaseHandler):
)
defer.returnValue(TimelineBatch(
- events=recents, prev_batch=prev_batch_token, limited=limited
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room
))
@defer.inlineCallbacks
@@ -732,25 +670,12 @@ class SyncHandler(BaseHandler):
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room,
- all_ephemeral_by_room):
- """ Get the incremental delta needed to bring the client up to date for
- the room. Gives the client the most recent events and the changes to
- state.
- Returns:
- A Deferred JoinedSyncResult
- """
- logger.debug("Doing incremental sync for room %s between %s and %s",
- room_id, since_token, now_token)
-
- # TODO(mjark): Check for redactions we might have missed.
-
- batch = yield self.load_filtered_recents(
- room_id, sync_config, now_token, since_token,
- )
-
- logger.debug("Recents %r", batch)
+ all_ephemeral_by_room,
+ batch, full_state=False):
+ if full_state:
+ state = yield self.get_state_at(room_id, now_token)
- if batch.limited:
+ elif batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
@@ -772,17 +697,6 @@ class SyncHandler(BaseHandler):
if just_joined:
state = yield self.get_state_at(room_id, now_token)
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, all_ephemeral_by_room
- )
-
- unread_notifications = {}
- if notifs is not None:
- unread_notifications["notification_count"] = len(notifs)
- unread_notifications["highlight_count"] = len([
- 1 for notif in notifs if _action_has_highlight(notif["actions"])
- ])
-
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
@@ -800,6 +714,7 @@ class SyncHandler(BaseHandler):
ephemeral_by_room.get(room_id, [])
)
+ unread_notifications = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -809,48 +724,64 @@ class SyncHandler(BaseHandler):
unread_notifications=unread_notifications,
)
+ if room_sync:
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config, all_ephemeral_by_room
+ )
+
+ if notifs is not None:
+ unread_notifications["notification_count"] = len(notifs)
+ unread_notifications["highlight_count"] = len([
+ 1 for notif in notifs if _action_has_highlight(notif["actions"])
+ ])
+
logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
- def incremental_sync_for_archived_room(self, sync_config, leave_event,
+ def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room,
- account_data_by_room):
+ account_data_by_room, full_state,
+ leave_token=None):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
- stream_token = yield self.store.get_stream_token_for_event(
- leave_event.event_id
- )
+ if not leave_token:
+ stream_token = yield self.store.get_stream_token_for_event(
+ leave_event_id
+ )
- leave_token = since_token.copy_and_replace("room_key", stream_token)
+ leave_token = since_token.copy_and_replace("room_key", stream_token)
- if since_token.is_after(leave_token):
+ if since_token and since_token.is_after(leave_token):
defer.returnValue(None)
batch = yield self.load_filtered_recents(
- leave_event.room_id, sync_config, leave_token, since_token,
+ room_id, sync_config, leave_token, since_token,
)
logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
- leave_event.event_id
+ leave_event_id
)
- state_at_previous_sync = yield self.get_state_at(
- leave_event.room_id, stream_position=since_token
- )
+ if not full_state:
+ state_at_previous_sync = yield self.get_state_at(
+ room_id, stream_position=since_token
+ )
- state_events_delta = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=state_events_at_leave,
- )
+ state_events_delta = yield self.compute_state_delta(
+ since_token=since_token,
+ previous_state=state_at_previous_sync,
+ current_state=state_events_at_leave,
+ )
+ else:
+ state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
@@ -860,7 +791,7 @@ class SyncHandler(BaseHandler):
}
account_data = self.account_data_for_room(
- leave_event.room_id, tags_by_room, account_data_by_room
+ room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
@@ -868,7 +799,7 @@ class SyncHandler(BaseHandler):
)
room_sync = ArchivedSyncResult(
- room_id=leave_event.room_id,
+ room_id=room_id,
timeline=batch,
state=state_events_delta,
account_data=account_data,
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index d37dc11470..a87628ed85 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -22,7 +22,7 @@ REQUIREMENTS = {
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
- "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
+ "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index cb3ec23872..96633a176c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, e.message)
before = request.args.get("before", None)
- if before and len(before):
- before = before[0]
+ if before:
+ before = _namespaced_rule_id(spec, before[0])
+
after = request.args.get("after", None)
- if after and len(after):
- after = after[0]
+ if after:
+ after = _namespaced_rule_id(spec, after[0])
try:
yield self.hs.get_datastore().add_push_rule(
@@ -452,11 +453,15 @@ def _strip_device_condition(rule):
def _namespaced_rule_id_from_spec(spec):
+ return _namespaced_rule_id(spec, spec['rule_id'])
+
+
+def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
- return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
+ return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
diff --git a/synapse/server.py b/synapse/server.py
index 4a5796b982..5fee7fe130 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -20,6 +20,8 @@
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
+from twisted.enterprise import adbapi
+
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
@@ -36,8 +38,15 @@ from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+
+import logging
+
-class BaseHomeServer(object):
+logger = logging.getLogger(__name__)
+
+
+class HomeServer(object):
"""A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as
@@ -98,39 +107,18 @@ class BaseHomeServer(object):
self.hostname = hostname
self._building = {}
+ self.clock = Clock()
+ self.distributor = Distributor()
+ self.ratelimiter = Ratelimiter()
+
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
- @classmethod
- def _make_dependency_method(cls, depname):
- def _get(self):
- if hasattr(self, depname):
- return getattr(self, depname)
-
- if hasattr(self, "build_%s" % (depname)):
- # Prevent cyclic dependencies from deadlocking
- if depname in self._building:
- raise ValueError("Cyclic dependency while building %s" % (
- depname,
- ))
- self._building[depname] = 1
-
- builder = getattr(self, "build_%s" % (depname))
- dep = builder()
- setattr(self, depname, dep)
-
- del self._building[depname]
-
- return dep
-
- raise NotImplementedError(
- "%s has no %s nor a builder for it" % (
- type(self).__name__, depname,
- )
- )
-
- setattr(BaseHomeServer, "get_%s" % (depname), _get)
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = DataStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
@@ -142,33 +130,9 @@ class BaseHomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
-# Build magic accessors for every dependency
-for depname in BaseHomeServer.DEPENDENCIES:
- BaseHomeServer._make_dependency_method(depname)
-
-
-class HomeServer(BaseHomeServer):
- """A homeserver object that will construct most of its dependencies as
- required.
-
- It still requires the following to be specified by the caller:
- resource_for_client
- resource_for_web_client
- resource_for_federation
- resource_for_content_repo
- http_client
- db_pool
- """
-
- def build_clock(self):
- return Clock()
-
def build_replication_layer(self):
return initialize_http_replication(self)
- def build_datastore(self):
- return DataStore(self)
-
def build_handlers(self):
return Handlers(self)
@@ -179,10 +143,9 @@ class HomeServer(BaseHomeServer):
return Auth(self)
def build_http_client_context_factory(self):
- config = self.get_config()
return (
InsecureInterceptableContextFactory()
- if config.use_insecure_ssl_client_just_for_testing_do_not_use
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
@@ -201,15 +164,9 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self):
return StateHandler(self)
- def build_distributor(self):
- return Distributor()
-
def build_event_sources(self):
return EventSources(self)
- def build_ratelimiter(self):
- return Ratelimiter()
-
def build_keyring(self):
return Keyring(self)
@@ -224,3 +181,55 @@ class HomeServer(BaseHomeServer):
def build_pusherpool(self):
return PusherPool(self)
+
+ def build_http_client(self):
+ return MatrixFederationHttpClient(self)
+
+ def build_db_pool(self):
+ name = self.db_config["name"]
+
+ return adbapi.ConnectionPool(
+ name,
+ **self.db_config.get("args", {})
+ )
+
+
+def _make_dependency_method(depname):
+ def _get(hs):
+ try:
+ return getattr(hs, depname)
+ except AttributeError:
+ pass
+
+ try:
+ builder = getattr(hs, "build_%s" % (depname))
+ except AttributeError:
+ builder = None
+
+ if builder:
+ # Prevent cyclic dependencies from deadlocking
+ if depname in hs._building:
+ raise ValueError("Cyclic dependency while building %s" % (
+ depname,
+ ))
+ hs._building[depname] = 1
+
+ dep = builder()
+ setattr(hs, depname, dep)
+
+ del hs._building[depname]
+
+ return dep
+
+ raise NotImplementedError(
+ "%s has no %s nor a builder for it" % (
+ type(hs).__name__, depname,
+ )
+ )
+
+ setattr(HomeServer, "get_%s" % (depname), _get)
+
+
+# Build magic accessors for every dependency
+for depname in HomeServer.DEPENDENCIES:
+ _make_dependency_method(depname)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..eb88842308 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -46,6 +46,9 @@ from .tags import TagsStore
from .account_data import AccountDataStore
+from util.id_generators import IdGenerator, StreamIdGenerator
+
+
import logging
@@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
- def __init__(self, hs):
- super(DataStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
self.hs = hs
- self.min_token_deferred = self._get_min_token()
- self.min_token = None
+ cur = db_conn.cursor()
+ try:
+ cur.execute("SELECT MIN(stream_ordering) FROM events",)
+ rows = cur.fetchall()
+ self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+ self.min_stream_token = min(self.min_stream_token, -1)
+ finally:
+ cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering"
+ )
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
+ self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+ self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+ super(DataStore, self).__init__(hs)
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 90d7aee94a..5e77320540 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
-from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
- self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -345,7 +333,8 @@ class SQLBaseStore(object):
defer.returnValue(result)
- def cursor_to_dict(self, cursor):
+ @staticmethod
+ def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
- @log_function
- def _simple_insert_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
- def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ @classmethod
+ def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
- ret = self._simple_select_onecol_txn(
+ ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
- def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ @staticmethod
+ def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
- def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+ @classmethod
+ def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
)
txn.execute(sql)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -662,7 +655,8 @@ class SQLBaseStore(object):
defer.returnValue(results)
- def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+ @classmethod
+ def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -699,7 +693,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@@ -726,7 +720,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ @staticmethod
+ def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +738,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
- def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ @staticmethod
+ def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@@ -784,7 +780,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
- def _simple_delete_txn(self, txn, table, keyvalues):
+ @staticmethod
+ def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 9c6597e012..ed6587429b 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
import ujson as json
@@ -23,6 +24,14 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
+ def __init__(self, hs):
+ super(AccountDataStore, self).__init__(hs)
+
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache",
+ self._account_data_id_gen.get_max_token(None),
+ max_size=10000,
+ )
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
@@ -83,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
"""Get all the client account_data for a that's changed.
Args:
@@ -120,6 +129,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room)
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return ({}, {})
+
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -186,6 +201,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..5e85552029 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- start = self.min_token - 1
- self.min_token -= len(events_and_contexts) + 1
- stream_orderings = range(start, self.min_token, -1)
+ start = self.min_stream_token - 1
+ self.min_stream_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
+ self.min_stream_token -= 1
+ stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@@ -214,6 +210,12 @@ class EventsStore(SQLBaseStore):
for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id)
+ if not backfilled:
+ txn.call_after(
+ self._events_stream_cache.entity_has_changed,
+ event.room_id, event.internal_metadata.stream_ordering,
+ )
+
depth_updates = {}
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index f8fc9bdddc..5248736816 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -16,12 +16,13 @@
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json
class FilteringStore(SQLBaseStore):
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 1f51c90ee5..f9a48171ba 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore):
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None)
- relative_to_rule = kwargs.pop("before", after)
+ before = kwargs.pop("before", None)
+ relative_to_rule = before or after
res = self._simple_select_one_txn(
txn,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..8068c73740 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -15,11 +15,10 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
-from blist import sorteddict
import logging
import ujson as json
@@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
- self._receipts_stream_cache = _RoomStreamChangeCache()
+ self._receipts_stream_cache = StreamChangeCache(
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+ )
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -76,8 +77,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids)
if from_key:
- room_ids = yield self._receipts_stream_cache.get_rooms_changed(
- self, room_ids, from_key
+ room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
@@ -220,6 +221,11 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+ txn.call_after(
+ self._receipts_stream_cache.entity_has_changed,
+ room_id, stream_id
+ )
+
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
@@ -307,9 +313,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
- yield self._receipts_stream_cache.room_has_changed(
- self, room_id, stream_id
- )
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
@@ -368,63 +371,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
-
-class _RoomStreamChangeCache(object):
- """Keeps track of the stream_id of the latest change in rooms.
-
- Given a list of rooms and stream key, it will give a subset of rooms that
- may have changed since that key. If the key is too old then the cache
- will simply return all rooms.
- """
- def __init__(self, size_of_cache=10000):
- self._size_of_cache = size_of_cache
- self._room_to_key = {}
- self._cache = sorteddict()
- self._earliest_key = None
- self.name = "ReceiptsRoomChangeCache"
- caches_by_name[self.name] = self._cache
-
- @defer.inlineCallbacks
- def get_rooms_changed(self, store, room_ids, key):
- """Returns subset of room ids that have had new receipts since the
- given key. If the key is too old it will just return the given list.
- """
- if key > (yield self._get_earliest_key(store)):
- keys = self._cache.keys()
- i = keys.bisect_right(key)
-
- result = set(
- self._cache[k] for k in keys[i:]
- ).intersection(room_ids)
-
- cache_counter.inc_hits(self.name)
- else:
- result = room_ids
- cache_counter.inc_misses(self.name)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def room_has_changed(self, store, room_id, key):
- """Informs the cache that the room has been changed at the given key.
- """
- if key > (yield self._get_earliest_key(store)):
- old_key = self._room_to_key.get(room_id, None)
- if old_key:
- key = max(key, old_key)
- self._cache.pop(old_key, None)
- self._cache[key] = room_id
-
- while len(self._cache) > self._size_of_cache:
- k, r = self._cache.popitem()
- self._earliest_key = max(k, self._earliest_key)
- self._room_to_key.pop(r, None)
-
- @defer.inlineCallbacks
- def _get_earliest_key(self, store):
- if self._earliest_key is None:
- self._earliest_key = yield store.get_max_receipt_stream_id()
- self._earliest_key = int(self._earliest_key)
-
- defer.returnValue(self._earliest_key)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index edfecced05..1d3e004c90 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -241,7 +241,7 @@ class RoomMemberStore(SQLBaseStore):
return rows
- @cached()
+ @cached(max_entries=5000)
def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql
new file mode 100644
index 0000000000..200c35e6e2
--- /dev/null
+++ b/synapse/storage/schema/delta/28/events_room_stream.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX events_room_stream on events(room_id, stream_ordering);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..6e81d46c60 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,6 +37,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function
@@ -77,6 +78,12 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
+ def __init__(self, hs):
+ super(StreamStore, self).__init__(hs)
+
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
+ )
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@@ -157,6 +164,135 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+ room_ids = yield self._events_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+
+ if not room_ids:
+ defer.returnValue({})
+
+ results = {}
+ room_ids = list(room_ids)
+ for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
+ res = yield defer.gatherResults([
+ self.get_room_events_stream_for_room(
+ room_id, from_key, to_key, limit
+ ).addCallback(lambda r, rm: (rm, r), room_id)
+ for room_id in room_ids
+ ])
+ results.update(dict(res))
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ defer.returnValue(([], from_key))
+
+ if from_id:
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
+
+ if not has_changed:
+ defer.returnValue(([], from_key))
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC LIMIT ?"
+ )
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+ else:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC LIMIT ?"
+ )
+ txn.execute(sql, (room_id, to_id, limit))
+
+ rows = self.cursor_to_dict(txn)
+
+ ret = self._get_events_txn(
+ txn,
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows, topo_order=False)
+
+ ret.reverse()
+
+ if rows:
+ key = "s%d" % min(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = from_key
+
+ return ret, key
+ res = yield self.runInteraction("get_room_events_stream_for_room", f)
+ defer.returnValue(res)
+
+ def get_room_changes_for_user(self, user_id, from_key, to_key):
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ return defer.succeed([])
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+ else:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, to_id,))
+ rows = self.cursor_to_dict(txn)
+
+ ret = self._get_events_txn(
+ txn,
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ return ret
+
+ return self.runInteraction("get_room_changes_for_user", f)
+
@log_function
def get_room_events_stream(
self,
@@ -174,7 +310,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
- " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+ " WHERE c.room_id IN (%s)"
+ " AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
@@ -434,6 +571,18 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
+ def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ sql = (
+ "SELECT max(topological_ordering) FROM events"
+ " WHERE room_id = ? AND stream_ordering < ?"
+ )
+ return self._execute(
+ "get_max_topological_token_for_stream_and_room", None,
+ sql, room_id, stream_key,
+ ).addCallback(
+ lambda r: r[0][0] if r else 0
+ )
+
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
@@ -444,24 +593,14 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
-
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
-
- logger.debug("min_token is: %s", self.min_token)
-
- defer.returnValue(self.min_token)
-
@staticmethod
- def _set_before_and_after(events, rows):
+ def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows):
stream = row["stream_ordering"]
- topo = event.depth
+ if topo_order:
+ topo = event.depth
+ else:
+ topo = None
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..e1a9c0c261 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@@ -25,13 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
- def __init__(self, hs):
- super(TagsStore, self).__init__(hs)
-
- self._account_data_id_gen = StreamIdGenerator(
- "account_data_max_stream_id", "stream_id"
- )
-
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@@ -87,6 +79,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ defer.returnValue({})
+
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
@@ -184,6 +182,11 @@ class TagsStore(SQLBaseStore):
next_id(int): The the revision to advance to.
"""
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id
+ )
+
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..5c522f4ab9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
- def __init__(self, table, column):
+ def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
- self._current_max = None
+ cur = db_conn.cursor()
+ self._current_max = self._get_or_compute_current_max(cur)
+ cur.close()
+
self._unfinished_ids = deque()
- @defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
self._current_max += 1
next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
if self._unfinished_ids:
- defer.returnValue(self._unfinished_ids[0] - 1)
+ return self._unfinished_ids[0] - 1
- defer.returnValue(self._current_max)
+ return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
- self.size = 0
+ self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
- self.size += 1
def move_node_to_front(node):
prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- self.size -= 1
@synchronized
def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE]
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
@synchronized
def cache_len():
- return self.size
+ return len(cache)
@synchronized
def cache_contains(key):
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
new file mode 100644
index 0000000000..c673b1bdfc
--- /dev/null
+++ b/synapse/util/caches/stream_change_cache.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class StreamChangeCache(object):
+ """Keeps track of the stream positions of the latest change in a set of entities.
+
+ Typically the entity will be a room or user id.
+
+ Given a list of entities and a stream position, it will give a subset of
+ entities that may have changed since that position. If position key is too
+ old then the cache will simply return all given entities.
+ """
+ def __init__(self, name, current_stream_pos, max_size=10000):
+ self._max_size = max_size
+ self._entity_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_known_stream_pos = current_stream_pos
+ self.name = name
+ caches_by_name[self.name] = self._cache
+
+ def has_entity_changed(self, entity, stream_pos):
+ """Returns True if the entity may have been updated since stream_pos
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos < self._earliest_known_stream_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ if stream_pos == self._earliest_known_stream_pos:
+ # If the same as the earliest key, assume nothing has changed.
+ cache_counter.inc_hits(self.name)
+ return False
+
+ latest_entity_change_pos = self._entity_to_key.get(entity, None)
+ if latest_entity_change_pos is None:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ if stream_pos < latest_entity_change_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ cache_counter.inc_hits(self.name)
+ return False
+
+ def get_entities_changed(self, entities, stream_pos):
+ """Returns subset of entities that have had new things since the
+ given position. If the position is too old it will just return the given list.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(entities)
+
+ cache_counter.inc_hits(self.name)
+ else:
+ result = entities
+ cache_counter.inc_misses(self.name)
+
+ return result
+
+ def entity_has_changed(self, entity, stream_pos):
+ """Informs the cache that the entity has been changed at the given
+ position.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos > self._earliest_known_stream_pos:
+ old_pos = self._entity_to_key.get(entity, None)
+ if old_pos:
+ stream_pos = max(stream_pos, old_pos)
+ self._cache.pop(old_pos, None)
+ self._cache[stream_pos] = entity
+ self._entity_to_key[entity] = stream_pos
+
+ while len(self._cache) > self._max_size:
+ k, r = self._cache.popitem()
+ self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 3b58860910..29d02f7e95 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples.
"""
def __init__(self):
+ self.size = 0
self.root = {}
def __setitem__(self, key, value):
@@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root
for k in key[:-1]:
node = node.setdefault(k, {})
- node[key[-1]] = value
+ node[key[-1]] = _Entry(value)
+ self.size += 1
def get(self, key, default=None):
node = self.root
@@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None)
if node is None:
return default
- return node.get(key[-1], default)
+ return node.get(key[-1], _Entry(default)).value
def clear(self):
+ self.size = 0
self.root = {}
def pop(self, key, default=None):
@@ -57,4 +60,33 @@ class TreeCache(object):
break
node_and_keys[i+1][0].pop(k)
+ popped, cnt = _strip_and_count_entires(popped)
+ self.size -= cnt
return popped
+
+ def __len__(self):
+ return self.size
+
+
+class _Entry(object):
+ __slots__ = ["value"]
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _strip_and_count_entires(d):
+ """Takes an _Entry or dict with leaves of _Entry's, and either returns the
+ value or a dictionary with _Entry's replaced by their values.
+
+ Also returns the count of _Entry's
+ """
+ if isinstance(d, dict):
+ cnt = 0
+ for key, value in d.items():
+ v, n = _strip_and_count_entires(value)
+ d[key] = v
+ cnt += n
+ return d, cnt
+ else:
+ return d.value, 1
|