diff options
24 files changed, 367 insertions, 92 deletions
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index c8dde0fcb8..8d755a4b33 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -80,11 +80,6 @@ class PusherSlaveStore( DataStore.get_profile_displayname.__func__ ) - # XXX: This is a bit broken because we don't persist forgotten rooms - # in a way that they can be streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) @@ -168,7 +163,6 @@ class PusherServer(HomeServer): store = self.get_datastore() replication_url = self.config.worker_replication_url pusher_pool = self.get_pusherpool() - clock = self.get_clock() def stop_pusher(user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) @@ -220,21 +214,11 @@ class PusherServer(HomeServer): min_stream_id, max_stream_id, affected_room_ids ) - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) poke_pushers(result) except: diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 48bc97636c..e3173533e2 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite from synapse.http.server import JsonResource from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import events from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -74,11 +75,6 @@ class SynchrotronSlavedStore( BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different ): - # XXX: This is a bit broken because we don't persist forgotten rooms - # in a way that they can be streamed. This means that we don't have a - # way to invalidate the forgotten rooms cache correctly. - # For now we expire the cache every 10 minutes. - BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 who_forgot_in_room = ( RoomMemberStore.__dict__["who_forgot_in_room"] ) @@ -89,17 +85,23 @@ class SynchrotronSlavedStore( get_presence_list_accepted = PresenceStore.__dict__[ "get_presence_list_accepted" ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() + self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = { @@ -124,6 +126,8 @@ class SynchrotronPresence(object): pass get_states = PresenceHandler.get_states.__func__ + get_state = PresenceHandler.get_state.__func__ + _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ @defer.inlineCallbacks @@ -194,19 +198,39 @@ class SynchrotronPresence(object): self._need_to_send_sync = False yield self._send_syncing_users_now() + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield self._get_interested_parties( + states, calculate_remote_hosts=False + ) + room_ids_to_states, users_to_states, _ = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=users_to_states.keys() + ) + + @defer.inlineCallbacks def process_replication(self, result): stream = result.get("presence", {"rows": []}) + states = [] for row in stream["rows"]: ( position, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) = row - self.user_to_current_state[user_id] = UserPresenceState( + state = UserPresenceState( user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active ) + self.user_to_current_state[user_id] = state + states.append(state) + + if states and "position" in stream: + stream_id = int(stream["position"]) + yield self.notify_from_replication(states, stream_id) class SynchrotronTyping(object): @@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) sync.register_servlets(self, resource) + events.register_servlets(self, resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, }) root_resource = create_resource_tree(resources, Resource()) @@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer): http_client = self.get_simple_http_client() store = self.get_datastore() replication_url = self.config.worker_replication_url - clock = self.get_clock() notifier = self.get_notifier() presence_handler = self.get_presence_handler() typing_handler = self.get_typing_handler() - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - store.get_presence_list_accepted.invalidate_all() - def notify_from_stream( result, stream_name, stream_key, room=None, user=None ): @@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer): result, "typing", "typing_key", room="room_id" ) - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args.update(typing_handler.stream_positions()) args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) typing_handler.process_replication(result) - presence_handler.process_replication(result) + yield presence_handler.process_replication(result) notify(result) except: logger.exception("Error replicating from %r", replication_url) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 1a50a2ec98..63d05f2531 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -19,7 +19,6 @@ from .room import ( ) from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler @@ -53,8 +52,6 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2293b5fdf7..6a1fe76c88 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -503,7 +503,7 @@ class PresenceHandler(object): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -526,14 +526,15 @@ class PresenceHandler(object): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): local_states = filter(lambda s: self.is_mine_id(s.user_id), states) @@ -565,6 +566,16 @@ class PresenceHandler(object): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c2d487ff4..84993b33b3 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -41,6 +41,7 @@ STREAM_NAMES = ( ("push_rules",), ("pushers",), ("state",), + ("caches",), ) @@ -70,6 +71,7 @@ class ReplicationResource(Resource): * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. * "pushers": Per user changes to their pushers. + * "caches": Cache invalidations. The API takes two additional query parameters: @@ -129,6 +131,7 @@ class ReplicationResource(Resource): push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() state_token = self.store.get_state_stream_token() + caches_token = self.store.get_cache_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -140,6 +143,7 @@ class ReplicationResource(Resource): push_rules_token, pushers_token, state_token, + caches_token, )) @request_handler() @@ -188,6 +192,7 @@ class ReplicationResource(Resource): yield self.push_rules(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams) + yield self.caches(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) logger.info("Replicated %d rows", writer.total) @@ -379,6 +384,20 @@ class ReplicationResource(Resource): "position", "type", "state_key", "event_id" )) + @defer.inlineCallbacks + def caches(self, writer, current_token, limit, request_streams): + current_position = current_token.caches + + caches = request_streams.get("caches") + + if caches is not None: + updated_caches = yield self.store.get_all_updated_caches( + caches, current_position, limit + ) + writer.write_header_and_rows("caches", updated_caches, ( + "position", "cache_func", "keys", "invalidation_ts" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -407,7 +426,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state" + "push_rules", "pushers", "state", "caches", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 46e43ce1c7..d839d169ab 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,15 +14,43 @@ # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from ._slaved_id_tracker import SlavedIdTracker + +import logging + +logger = logging.getLogger(__name__) + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(hs) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = SlavedIdTracker( + db_conn, "cache_invalidation_stream", "stream_id", + ) + else: + self._cache_id_gen = None def stream_positions(self): - return {} + pos = {} + if self._cache_id_gen: + pos["caches"] = self._cache_id_gen.get_current_token() + return pos def process_replication(self, result): + stream = result.get("caches") + if stream: + for row in stream["rows"]: + ( + position, cache_func, keys, invalidation_ts, + ) = row + + try: + getattr(self, cache_func).invalidate(tuple(keys)) + except AttributeError: + logger.warn("Got unexpected cache_func: %r", cache_func) + self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 5fbe3a303a..7301d885f2 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore class DirectoryStore(BaseSlavedStore): get_aliases_for_room = DirectoryStore.__dict__[ "get_aliases_for_room" - ].orig + ] diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index b0cb31a448..af21661d7c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(WhoisRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" ) + def __init__(self, hs): + super(PurgeHistoryRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 96b49b01f2..c2a8447860 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet): hs (synapse.server.HomeServer): """ self.hs = hs - self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8ac09419dc..09d0831594 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -36,6 +36,10 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") + def __init__(self, hs): + super(ClientDirectoryServer, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) @@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryListServer, self).__init__(hs) self.store = hs.get_datastore() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 498bb9e18a..701b6f549b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 + def __init__(self, hs): + super(EventStreamRestServlet, self).__init__(hs) + self.event_stream_handler = hs.get_event_stream_handler() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( @@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet): if "room_id" in request.args: room_id = request.args["room_id"][0] - handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if "timeout" in request.args: @@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args - chunk = yield handler.get_stream( + chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.event_handler - event = yield handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 36c3520567..113a49e539 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns class InitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/initialSync$") + def __init__(self, hs): + super(InitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b31e27f7b3..6c0eec8fb3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): flows = [] @@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet): def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) self.sp_config = hs.config.saml2_config_path + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): @@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 65c4e2ebef..355e82474b 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") + def __init__(self, hs): + super(ProfileDisplaynameRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") + def __init__(self, hs): + super(ProfileAvatarURLRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) @@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(ProfileRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2383b9df86..71d58c8e8d 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet): self.sessions = {} self.enable_registration = hs.config.enable_registration self.auth_handler = hs.get_auth_handler() + self.handlers = hs.get_handlers() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 866a1e9120..89c3895118 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -35,6 +35,10 @@ logger = logging.getLogger(__name__) class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here + def __init__(self, hs): + super(RoomCreateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) @@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomStateEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomSendEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") @@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ClientV1RestServlet): + def __init__(self, hs): + super(JoinRoomAliasServlet, self).__init__(hs) + self.handlers = hs.get_handlers() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") + def __init__(self, hs): + super(RoomMemberListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") + def __init__(self, hs): + super(RoomMessageListRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") + def __init__(self, hs): + super(RoomStateRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") + def __init__(self, hs): + super(RoomInitialSyncRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) @@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet): def __init__(self, hs): super(RoomEventContext, self).__init__(hs) self.clock = hs.get_clock() + self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomForgetRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") register_txn_path(self, PATTERNS, http_server) @@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomMembershipRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" @@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): + def __init__(self, hs): + super(RoomRedactEventRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") register_txn_path(self, PATTERNS, http_server) @@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet): "/search$" ) + def __init__(self, hs): + super(SearchRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) diff --git a/synapse/server.py b/synapse/server.py index 6bb4988309..af3246504b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.room import RoomListHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler +from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -94,6 +95,8 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'event_handler', + 'event_stream_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -214,6 +217,12 @@ class HomeServer(object): def build_application_service_handler(self): return ApplicationServicesHandler(self) + def build_event_handler(self): + return EventHandler(self) + + def build_event_stream_handler(self): + return EventStreamHandler(self) + def build_event_sources(self): return EventSources(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 73fb334dd6..7efc5bfeef 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -50,6 +50,7 @@ from .openid import OpenIdStore from .client_ips import ClientIpStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .engines import PostgresEngine from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_invalidation_stream", "stream_id", + ) + else: + self._cache_id_gen = None + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0117fdc639..b0923a9cad 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache from synapse.util.caches import intern_dict +from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -305,13 +306,14 @@ class SQLBaseStore(object): func, *args, **kwargs ) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - - for after_callback, after_args in after_callbacks: - after_callback(*after_args) + try: + with PreserveLoggingContext(): + result = yield self._db_pool.runWithConnection( + inner_func, *args, **kwargs + ) + finally: + for after_callback, after_args in after_callbacks: + after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks @@ -860,6 +862,58 @@ class SQLBaseStore(object): return cache, min_val + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_after(ctx.__exit__, None, None, None) + + self._simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_func.__name__, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + } + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit,)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index ef231a04dc..9caaf81f2c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore): Returns: Deferred """ - try: - yield self._simple_insert( + def alias_txn(txn): + self._simple_insert_txn( + txn, "room_aliases", { "room_alias": room_alias.to_string(), "room_id": room_id, "creator": creator, }, - desc="create_room_alias_association", - ) - except self.database_engine.module.IntegrityError: - raise SynapseError( - 409, "Room alias %s already exists" % room_alias.to_string() ) - for server in servers: - # TODO(erikj): Fix this to bulk insert - yield self._simple_insert( - "room_alias_servers", - { + self._simple_insert_many_txn( + txn, + table="room_alias_servers", + values=[{ "room_alias": room_alias.to_string(), "server": server, - }, - desc="create_room_alias_association", + } for server in servers], ) - self.get_aliases_for_room.invalidate((room_id,)) + + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + + try: + ret = yield self.runInteraction( + "create_room_alias_association", alias_txn + ) + except self.database_engine.module.IntegrityError: + raise SynapseError( + 409, "Room alias %s already exists" % room_alias.to_string() + ) + defer.returnValue(ret) def get_room_alias_creator(self, room_alias): return self._simple_select_one_onecol( diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 8801669a6b..b94ce7bea1 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 33 +SCHEMA_VERSION = 34 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index d03f7c541e..21d0696640 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore): desc="add_presence_list_pending", ) - @defer.inlineCallbacks def set_presence_list_accepted(self, observer_localpart, observed_userid): - result = yield self._simple_update_one( - table="presence_list", - keyvalues={"user_id": observer_localpart, - "observed_user_id": observed_userid}, - updatevalues={"accepted": True}, - desc="set_presence_list_accepted", + def update_presence_list_txn(txn): + result = self._simple_update_one_txn( + txn, + table="presence_list", + keyvalues={ + "user_id": observer_localpart, + "observed_user_id": observed_userid + }, + updatevalues={"accepted": True}, + ) + + self._invalidate_cache_and_stream( + txn, self.get_presence_list_accepted, (observer_localpart,) + ) + self._invalidate_cache_and_stream( + txn, self.get_presence_list_observers_accepted, (observed_userid,) + ) + + return result + + return self.runInteraction( + "set_presence_list_accepted", update_presence_list_txn, ) - self.get_presence_list_accepted.invalidate((observer_localpart,)) - self.get_presence_list_observers_accepted.invalidate((observed_userid,)) - defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): if accepted: diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8bd693be72..a422ddf633 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore): user_id, membership_list=[Membership.JOIN], ) - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore): " room_id = ?" ) txn.execute(sql, (user_id, room_id)) - yield self.runInteraction("forget_membership", f) - self.was_forgotten_at.invalidate_all() - self.who_forgot_in_room.invalidate_all() - self.did_forget.invalidate((user_id, room_id)) + + txn.call_after(self.was_forgotten_at.invalidate_all) + txn.call_after(self.did_forget.invalidate, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.who_forgot_in_room, (room_id,) + ) + return self.runInteraction("forget_membership", f) @cachedInlineCallbacks(num_args=2) def did_forget(self, user_id, room_id): diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py new file mode 100644 index 0000000000..3b63a1562d --- /dev/null +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -0,0 +1,46 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +# This stream is used to notify replication slaves that some caches have +# been invalidated that they cannot infer from the other streams. +CREATE_TABLE = """ +CREATE TABLE cache_invalidation_stream ( + stream_id BIGINT, + cache_func TEXT, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass |