diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index a2a0f364cf..253a6ef6c7 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
import argparse
import curses
@@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
+ "presence_stream": ["currently_active"],
}
@@ -292,7 +294,7 @@ class Porter(object):
}
)
- database_engine.prepare_database(db_conn)
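+        # config=None means prepare_database will refuse to run schema
+        # upgrades: the database must be empty or already at SCHEMA_VERSION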
+ prepare_database(db_conn, database_engine, config=None)
db_conn.commit()
@@ -309,8 +311,8 @@ class Porter(object):
**self.postgres_config["args"]
)
- sqlite_engine = create_engine(FakeConfig(sqlite_config))
- postgres_engine = create_engine(FakeConfig(postgres_config))
+ sqlite_engine = create_engine(sqlite_config)
+ postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
@@ -792,8 +794,3 @@ if __name__ == "__main__":
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)
-
-
-class FakeConfig:
- def __init__(self, database_config):
- self.database_config = database_config
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index fcdc8e6e10..2b4473b9ac 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -33,7 +33,7 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
-from synapse.storage.prepare_database import UpgradeDatabaseException
+from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer
@@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
- def get_db_conn(self):
+ def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
@@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer):
}
db_conn = self.database_engine.module.connect(**db_params)
- self.database_engine.on_new_connection(db_conn)
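+        # setup() needs to run prepare_database before the engine's
+        # on_new_connection hook, so it passes run_new_connection=False and
+        # calls the hook itself afterwards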
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
return db_conn
@@ -386,7 +387,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config)
- database_engine = create_engine(config)
+ database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
@@ -402,8 +403,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
- db_conn = hs.get_db_conn()
- database_engine.prepare_database(db_conn)
+ db_conn = hs.get_db_conn(run_new_connection=False)
+ prepare_database(db_conn, database_engine, config=config)
+ database_engine.on_new_connection(db_conn)
+
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 5eeb7042c6..88d8b9ba54 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = (
)
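+# Membership states, ordered from "most joined" to "least joined". Used below
+# to pick the more-joined of a membership event's old and new membership when
+# deciding whether the affected user may see the event.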
+MEMBERSHIP_PRIORITY = (
+ Membership.JOIN,
+ Membership.INVITE,
+ Membership.KNOCK,
+ Membership.LEAVE,
+ Membership.BAN,
+)
+
+
class BaseHandler(object):
"""
Common base class for the event handlers.
@@ -72,6 +81,7 @@ class BaseHandler(object):
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the
given events
+ events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
@@ -86,6 +96,12 @@ class BaseHandler(object):
)
def allowed(event, user_id, is_peeking):
+ """
+ Args:
+ event (synapse.events.EventBase): event to check
+            user_id (str): user whose visibility is being checked
+            is_peeking (bool): whether the user is accessing the room without
+                being joined to it
+ """
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
@@ -117,17 +133,30 @@ class BaseHandler(object):
if old_priority < new_priority:
visibility = prev_visibility
- # get the user's membership at the time of the event. (or rather,
- # just *after* the event. Which means that people can see their
- # own join events, but not (currently) their own leave events.)
- membership_event = state.get((EventTypes.Member, user_id), None)
- if membership_event:
- if membership_event.event_id in event_id_forgotten:
- membership = None
- else:
- membership = membership_event.membership
- else:
- membership = None
+ # likewise, if the event is the user's own membership event, use
+ # the 'most joined' membership
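+            # (the pre-event membership wins if it was more joined, which
+            # fixes the old behaviour where users could see their own join
+            # events but not their own leave events)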
+ membership = None
+ if event.type == EventTypes.Member and event.state_key == user_id:
+ membership = event.content.get("membership", None)
+ if membership not in MEMBERSHIP_PRIORITY:
+ membership = "leave"
+
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership not in MEMBERSHIP_PRIORITY:
+ prev_membership = "leave"
+
+ new_priority = MEMBERSHIP_PRIORITY.index(membership)
+ old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
+ if old_priority < new_priority:
+ membership = prev_membership
+
+ # otherwise, get the user's membership at the time of the event.
+ if membership is None:
+ membership_event = state.get((EventTypes.Member, user_id), None)
+ if membership_event:
+ if membership_event.event_id not in event_id_forgotten:
+ membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
@@ -204,20 +233,25 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
- def _create_new_client_event(self, builder):
- latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
- builder.room_id,
- )
-
- if latest_ret:
- depth = max([d for _, _, d in latest_ret]) + 1
+ def _create_new_client_event(self, builder, prev_event_ids=None):
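+        # if the caller supplies explicit prev_event_ids, hang the new event
+        # off those rather than the room's current forward extremities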
+ if prev_event_ids:
+ prev_events = yield self.store.add_event_hashes(prev_event_ids)
+ prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
+ depth = prev_max_depth + 1
else:
- depth = 1
+ latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
+ builder.room_id,
+ )
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in latest_ret
- ]
+ if latest_ret:
+ depth = max([d for _, _, d in latest_ret]) + 1
+ else:
+ depth = 1
+
+ prev_events = [
+ (event_id, prev_hashes)
+ for event_id, prev_hashes, _ in latest_ret
+ ]
builder.prev_events = prev_events
builder.depth = depth
@@ -226,49 +260,6 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
- # If we've received an invite over federation, there are no latest
- # events in the room, because we don't know enough about the graph
- # fragment we received to treat it like a graph, so the above returned
- # no relevant events. It may have returned some events (if we have
- # joined and left the room), but not useful ones, like the invite.
- if (
- not self.is_host_in_room(context.current_state) and
- builder.type == EventTypes.Member
- ):
- prev_member_event = yield self.store.get_room_member(
- builder.sender, builder.room_id
- )
-
- # The prev_member_event may already be in context.current_state,
- # despite us not being present in the room; in particular, if
- # inviting user, and all other local users, have already left.
- #
- # In that case, we have all the information we need, and we don't
- # want to drop "context" - not least because we may need to handle
- # the invite locally, which will require us to have the whole
- # context (not just prev_member_event) to auth it.
- #
- context_event_ids = (
- e.event_id for e in context.current_state.values()
- )
-
- if (
- prev_member_event and
- prev_member_event.event_id not in context_event_ids
- ):
- # The prev_member_event is missing from context, so it must
- # have arrived over federation and is an outlier. We forcibly
- # set our context to the invite we received over federation
- builder.prev_events = (
- prev_member_event.event_id,
- prev_member_event.prev_events
- )
-
- context = yield state_handler.compute_event_context(
- builder,
- old_state=(prev_member_event,)
- )
-
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index adafd06b24..eb02f0e000 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -284,6 +284,9 @@ class FederationHandler(BaseHandler):
def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id`
"""
+ if dest == self.server_name:
+ raise SynapseError(400, "Can't backfill from self.")
+
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
@@ -450,7 +453,7 @@ class FederationHandler(BaseHandler):
likely_domains = [
domain for domain, depth in curr_domains
- if domain is not self.server_name
+ if domain != self.server_name
]
@defer.inlineCallbacks
@@ -784,13 +787,19 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
- origin, event = yield self._make_and_verify_event(
- target_hosts,
- room_id,
- user_id,
- "leave"
- )
- signed_event = self._sign_event(event)
+ try:
+ origin, event = yield self._make_and_verify_event(
+ target_hosts,
+ room_id,
+ user_id,
+ "leave"
+ )
+ signed_event = self._sign_event(event)
+ except SynapseError:
+ raise
+ except CodeMessageException as e:
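+            # the remote failed in some arbitrary way; log the detail but
+            # give our client a generic failure rather than leaking it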
+ logger.warn("Failed to reject invite: %s", e)
+ raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/
# request first.
@@ -800,10 +809,16 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
- yield self.replication_layer.send_leave(
- target_hosts,
- signed_event
- )
+ try:
+ yield self.replication_layer.send_leave(
+ target_hosts,
+ signed_event
+ )
+ except SynapseError:
+ raise
+ except CodeMessageException as e:
+ logger.warn("Failed to reject invite: %s", e)
+ raise SynapseError(500, "Failed to reject invite")
context = yield self.state_handler.compute_event_context(event)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0bb111d047..10608c0dd9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -176,7 +176,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_event(self, event_dict, token_id=None, txn_id=None):
+ def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
"""
Given a dict from a client, create a new event.
@@ -187,6 +187,9 @@ class MessageHandler(BaseHandler):
Args:
event_dict (dict): An entire event
+            token_id (str): the id of the access token used to send the event
+            txn_id (str): the transaction id supplied by the client
+ prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
@@ -225,6 +228,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
+ prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 01f833c371..b6ef3c91af 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.async import Linearizer
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
@@ -60,6 +61,8 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
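+        # serialises membership updates per (target, room_id) pair, so that
+        # two concurrent changes for the same member cannot race; see
+        # update_membership below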
+ self.member_linearizer = Linearizer()
+
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
@@ -96,6 +99,82 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain)
@defer.inlineCallbacks
+ def _local_membership_update(
+ self, requester, target, room_id, membership,
+ prev_event_ids,
+ txn_id=None,
+ ratelimit=True,
+ ):
+ msg_handler = self.hs.get_handlers().message_handler
+
+ content = {"membership": membership}
+ if requester.is_guest:
+ content["kind"] = "guest"
+
+ event, context = yield msg_handler.create_event(
+ {
+ "type": EventTypes.Member,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "state_key": target.to_string(),
+
+ # For backwards compatibility:
+ "membership": membership,
+ },
+ token_id=requester.access_token_id,
+ txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ )
+
+ yield self.handle_new_client_event(
+ requester,
+ event,
+ context,
+ extra_users=[target],
+ ratelimit=ratelimit,
+ )
+
+ prev_member_event = context.current_state.get(
+ (EventTypes.Member, target.to_string()),
+ None
+ )
+
+ if event.membership == Membership.JOIN:
+ if not prev_member_event or prev_member_event.membership != Membership.JOIN:
+                # Only fire user_joined_room if the user has actually joined the
+ # room. Don't bother if the user is just changing their profile
+ # info.
+ yield user_joined_room(self.distributor, target, room_id)
+ elif event.membership == Membership.LEAVE:
+ if prev_member_event and prev_member_event.membership == Membership.JOIN:
+                yield user_left_room(self.distributor, target, room_id)
+
+ @defer.inlineCallbacks
+ def remote_join(self, remote_room_hosts, room_id, user, content):
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ # We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+ yield self.hs.get_handlers().federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user.to_string(),
+ content,
+ )
+ yield user_joined_room(self.distributor, user, room_id)
+
+ def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
+ return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ user_id
+ )
+
+ @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -107,6 +186,34 @@ class RoomMemberHandler(BaseHandler):
third_party_signed=None,
ratelimit=True,
):
+ key = (target, room_id,)
+
+ with (yield self.member_linearizer.queue(key)):
+ result = yield self._update_membership(
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=txn_id,
+ remote_room_hosts=remote_room_hosts,
+ third_party_signed=third_party_signed,
+ ratelimit=ratelimit,
+ )
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _update_membership(
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ ):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -120,28 +227,15 @@ class RoomMemberHandler(BaseHandler):
third_party_signed,
)
- msg_handler = self.hs.get_handlers().message_handler
+ if not remote_room_hosts:
+ remote_room_hosts = []
- content = {"membership": effective_membership_state}
- if requester.is_guest:
- content["kind"] = "guest"
-
- event, context = yield msg_handler.create_event(
- {
- "type": EventTypes.Member,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "state_key": target.to_string(),
-
- # For backwards compatibility:
- "membership": effective_membership_state,
- },
- token_id=requester.access_token_id,
- txn_id=txn_id,
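+        # pin the forward extremities up front: the same latest_event_ids are
+        # used both to compute the current state here and, below, as the
+        # prev_events of the membership event we create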
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ current_state = yield self.state_handler.get_current_state(
+ room_id, latest_event_ids=latest_event_ids,
)
- old_state = context.current_state.get((EventTypes.Member, event.state_key))
+ old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@@ -156,13 +250,73 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE
)
- member_handler = self.hs.get_handlers().room_member_handler
- yield member_handler.send_membership_event(
- requester,
- event,
- context,
+ is_host_in_room = self.is_host_in_room(current_state)
+
+ if effective_membership_state == Membership.JOIN:
+ if requester.is_guest and not self._can_guest_join(current_state):
+ # This should be an auth check, but guests are a local concept,
+ # so don't really fit into the general auth process.
+ raise AuthError(403, "Guest access not allowed")
+
+ if not is_host_in_room:
+ inviter = yield self.get_inviter(target.to_string(), room_id)
+ if inviter and not self.hs.is_mine(inviter):
+ remote_room_hosts.append(inviter.domain)
+
+ content = {"membership": Membership.JOIN}
+
+ profile = self.hs.get_handlers().profile_handler
+ content["displayname"] = yield profile.get_displayname(target)
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+
+ if requester.is_guest:
+ content["kind"] = "guest"
+
+ ret = yield self.remote_join(
+ remote_room_hosts, room_id, target, content
+ )
+ defer.returnValue(ret)
+
+ elif effective_membership_state == Membership.LEAVE:
+ if not is_host_in_room:
+ # perhaps we've been invited
+ inviter = yield self.get_inviter(target.to_string(), room_id)
+ if not inviter:
+ raise SynapseError(404, "Not a known room")
+
+ if self.hs.is_mine(inviter):
+ # the inviter was on our server, but has now left. Carry on
+ # with the normal rejection codepath.
+ #
+ # This is a bit of a hack, because the room might still be
+ # active on other servers.
+ pass
+ else:
+ # send the rejection to the inviter's HS.
+ remote_room_hosts = remote_room_hosts + [inviter.domain]
+
+ try:
+ ret = yield self.reject_remote_invite(
+ target.to_string(), room_id, remote_room_hosts
+ )
+ defer.returnValue(ret)
+ except SynapseError as e:
+ logger.warn("Failed to reject invite: %s", e)
+
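+                    # we couldn't get the remote to reject the invite, so
+                    # mark it as rejected on our side alone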
+ yield self.store.locally_reject_invite(
+ target.to_string(), room_id
+ )
+
+ defer.returnValue({})
+
+ yield self._local_membership_update(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ membership=effective_membership_state,
+ txn_id=txn_id,
ratelimit=ratelimit,
- remote_room_hosts=remote_room_hosts,
+ prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
@@ -211,73 +365,19 @@ class RoomMemberHandler(BaseHandler):
if prev_event is not None:
return
- action = "send"
-
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- do_remote_join_dance, remote_room_hosts = self._should_do_dance(
- context,
- (self.get_inviter(event.state_key, context.current_state)),
- remote_room_hosts,
- )
- if do_remote_join_dance:
- action = "remote_join"
- elif event.membership == Membership.LEAVE:
- is_host_in_room = self.is_host_in_room(context.current_state)
- if not is_host_in_room:
- # perhaps we've been invited
- inviter = self.get_inviter(
- target_user.to_string(), context.current_state
- )
- if not inviter:
- raise SynapseError(404, "Not a known room")
-
- if self.hs.is_mine(inviter):
- # the inviter was on our server, but has now left. Carry on
- # with the normal rejection codepath.
- #
- # This is a bit of a hack, because the room might still be
- # active on other servers.
- pass
- else:
- # send the rejection to the inviter's HS.
- remote_room_hosts = remote_room_hosts + [inviter.domain]
- action = "remote_reject"
-
- federation_handler = self.hs.get_handlers().federation_handler
-
- if action == "remote_join":
- if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
-
- # We don't do an auth check if we are doing an invite
- # join dance for now, since we're kinda implicitly checking
- # that we are allowed to join when we decide whether or not we
- # need to do the invite/join dance.
- yield federation_handler.do_invite_join(
- remote_room_hosts,
- event.room_id,
- event.user_id,
- event.content,
- )
- elif action == "remote_reject":
- yield federation_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- event.user_id
- )
- else:
- yield self.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target_user],
- ratelimit=ratelimit,
- )
+ yield self.handle_new_client_event(
+ requester,
+ event,
+ context,
+ extra_users=[target_user],
+ ratelimit=ratelimit,
+ )
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
@@ -306,11 +406,11 @@ class RoomMemberHandler(BaseHandler):
and guest_access.content["guest_access"] == "can_join"
)
- def _should_do_dance(self, context, inviter, room_hosts=None):
+ def _should_do_dance(self, current_state, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
- is_host_in_room = self.is_host_in_room(context.current_state)
+ is_host_in_room = self.is_host_in_room(current_state)
if is_host_in_room:
return False, room_hosts
@@ -344,11 +444,14 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers))
- def get_inviter(self, user_id, current_state):
- prev_state = current_state.get((EventTypes.Member, user_id))
- if prev_state and prev_state.membership == Membership.INVITE:
- return UserID.from_string(prev_state.user_id)
- return None
+ @defer.inlineCallbacks
+ def get_inviter(self, user_id, room_id):
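+        # invites are tracked in the local_invites table, so we can find the
+        # inviter even when we have no state for the room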
+ invite = yield self.store.get_invite_for_user_in_room(
+ user_id=user_id,
+ room_id=room_id,
+ )
+ if invite:
+ defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index c51a6fa103..a543af68f8 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -145,32 +145,43 @@ class ReplicationResource(Resource):
timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json")
- writer = _Writer(request)
- @defer.inlineCallbacks
- def replicate():
- current_token = yield self.current_replication_token()
- logger.info("Replicating up to %r", current_token)
-
- yield self.account_data(writer, current_token, limit)
- yield self.events(writer, current_token, limit)
- yield self.presence(writer, current_token) # TODO: implement limit
- yield self.typing(writer, current_token) # TODO: implement limit
- yield self.receipts(writer, current_token, limit)
- yield self.push_rules(writer, current_token, limit)
- yield self.pushers(writer, current_token, limit)
- yield self.state(writer, current_token, limit)
- self.streams(writer, current_token)
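+        # parse the requested stream positions up front, so that replicate()
+        # can be retried by the notifier without re-reading the request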
+ request_streams = {
+ name: parse_integer(request, name)
+ for names in STREAM_NAMES for name in names
+ }
+ request_streams["streams"] = parse_string(request, "streams")
- logger.info("Replicated %d rows", writer.total)
- defer.returnValue(writer.total)
+ def replicate():
+ return self.replicate(request_streams, limit)
- yield self.notifier.wait_for_replication(replicate, timeout)
+ result = yield self.notifier.wait_for_replication(replicate, timeout)
- writer.finish()
+ request.write(json.dumps(result, ensure_ascii=False))
+ finish_request(request)
- def streams(self, writer, current_token):
- request_token = parse_string(writer.request, "streams")
+ @defer.inlineCallbacks
+ def replicate(self, request_streams, limit):
+ writer = _Writer()
+ current_token = yield self.current_replication_token()
+ logger.info("Replicating up to %r", current_token)
+
+ yield self.account_data(writer, current_token, limit, request_streams)
+ yield self.events(writer, current_token, limit, request_streams)
+ # TODO: implement limit
+ yield self.presence(writer, current_token, request_streams)
+ yield self.typing(writer, current_token, request_streams)
+ yield self.receipts(writer, current_token, limit, request_streams)
+ yield self.push_rules(writer, current_token, limit, request_streams)
+ yield self.pushers(writer, current_token, limit, request_streams)
+ yield self.state(writer, current_token, limit, request_streams)
+ self.streams(writer, current_token, request_streams)
+
+ logger.info("Replicated %d rows", writer.total)
+ defer.returnValue(writer.finish())
+
+ def streams(self, writer, current_token, request_streams):
+ request_token = request_streams.get("streams")
streams = []
@@ -195,9 +206,9 @@ class ReplicationResource(Resource):
)
@defer.inlineCallbacks
- def events(self, writer, current_token, limit):
- request_events = parse_integer(writer.request, "events")
- request_backfill = parse_integer(writer.request, "backfill")
+ def events(self, writer, current_token, limit, request_streams):
+ request_events = request_streams.get("events")
+ request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None:
if request_events is None:
@@ -228,10 +239,10 @@ class ReplicationResource(Resource):
)
@defer.inlineCallbacks
- def presence(self, writer, current_token):
+ def presence(self, writer, current_token, request_streams):
current_position = current_token.presence
- request_presence = parse_integer(writer.request, "presence")
+ request_presence = request_streams.get("presence")
if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates(
@@ -244,10 +255,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def typing(self, writer, current_token):
+ def typing(self, writer, current_token, request_streams):
current_position = current_token.presence
- request_typing = parse_integer(writer.request, "typing")
+ request_typing = request_streams.get("typing")
if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates(
@@ -258,10 +269,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def receipts(self, writer, current_token, limit):
+ def receipts(self, writer, current_token, limit, request_streams):
current_position = current_token.receipts
- request_receipts = parse_integer(writer.request, "receipts")
+ request_receipts = request_streams.get("receipts")
if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts(
@@ -272,12 +283,12 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def account_data(self, writer, current_token, limit):
+ def account_data(self, writer, current_token, limit, request_streams):
current_position = current_token.account_data
- user_account_data = parse_integer(writer.request, "user_account_data")
- room_account_data = parse_integer(writer.request, "room_account_data")
- tag_account_data = parse_integer(writer.request, "tag_account_data")
+ user_account_data = request_streams.get("user_account_data")
+ room_account_data = request_streams.get("room_account_data")
+ tag_account_data = request_streams.get("tag_account_data")
if user_account_data is not None or room_account_data is not None:
if user_account_data is None:
@@ -303,10 +314,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def push_rules(self, writer, current_token, limit):
+ def push_rules(self, writer, current_token, limit, request_streams):
current_position = current_token.push_rules
- push_rules = parse_integer(writer.request, "push_rules")
+ push_rules = request_streams.get("push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
@@ -318,10 +329,11 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def pushers(self, writer, current_token, limit):
+ def pushers(self, writer, current_token, limit, request_streams):
current_position = current_token.pushers
- pushers = parse_integer(writer.request, "pushers")
+ pushers = request_streams.get("pushers")
+
if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit
@@ -336,10 +348,11 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
- def state(self, writer, current_token, limit):
+ def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
- state = parse_integer(writer.request, "state")
+ state = request_streams.get("state")
+
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
@@ -356,9 +369,8 @@ class ReplicationResource(Resource):
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
- def __init__(self, request):
+ def __init__(self):
self.streams = {}
- self.request = request
self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None):
@@ -377,8 +389,7 @@ class _Writer(object):
self.total += len(rows)
def finish(self):
- self.request.write(json.dumps(self.streams, ensure_ascii=False))
- finish_request(self.request)
+ return self.streams
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/synapse/replication/slave/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/synapse/replication/slave/storage/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
new file mode 100644
index 0000000000..46e43ce1c7
--- /dev/null
+++ b/synapse/replication/slave/storage/_base.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage._base import SQLBaseStore
+from twisted.internet import defer
+
+
+class BaseSlavedStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(BaseSlavedStore, self).__init__(hs)
+
+ def stream_positions(self):
+ return {}
+
+ def process_replication(self, result):
+ return defer.succeed(None)
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
new file mode 100644
index 0000000000..24b5c79d4a
--- /dev/null
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.util.id_generators import _load_current_id
+
+
+class SlavedIdTracker(object):
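+    """Tracks the current token of a replication stream on behalf of a
+    slaved store. It is seeded from the database and advanced as replication
+    results arrive; unlike StreamIdGenerator it never allocates new ids.
+    """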
+ def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ self.step = step
+ self._current = _load_current_id(db_conn, table, column, step)
+ for table, column in extra_tables:
+ self.advance(_load_current_id(db_conn, table, column))
+
+ def advance(self, new_id):
+ self._current = (max if self.step > 0 else min)(self._current, new_id)
+
+ def get_current_token(self):
+ return self._current
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
new file mode 100644
index 0000000000..707ddd248a
--- /dev/null
+++ b/synapse/replication/slave/storage/events.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+from synapse.api.constants import EventTypes
+from synapse.events import FrozenEvent
+from synapse.storage import DataStore
+from synapse.storage.room import RoomStore
+from synapse.storage.roommember import RoomMemberStore
+from synapse.storage.event_federation import EventFederationStore
+from synapse.storage.state import StateStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+import ujson as json
+
+# So, um, we want to borrow a load of functions intended for reading from
+# a DataStore, but we don't want to take functions that either write to the
+# DataStore or are cached and don't have cache invalidation logic.
+#
+# Rather than write duplicate versions of those functions, or lift them to
+# a common base class, we're going to grab the underlying __func__ object from
+# the method descriptor on the DataStore and chuck them into our class.
+
+
+class SlavedEventStore(BaseSlavedStore):
+
+ def __init__(self, db_conn, hs):
+ super(SlavedEventStore, self).__init__(db_conn, hs)
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
+ events_max = self._stream_id_gen.get_current_token()
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+
+ # Cached functions can't be accessed through a class instance so we need
+ # to reach inside the __dict__ to extract them.
+ get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
+ get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
+ get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
+ get_latest_event_ids_in_room = EventFederationStore.__dict__[
+ "get_latest_event_ids_in_room"
+ ]
+ _get_current_state_for_key = StateStore.__dict__[
+ "_get_current_state_for_key"
+ ]
+
+ get_current_state = DataStore.get_current_state.__func__
+ get_current_state_for_key = DataStore.get_current_state_for_key.__func__
+ get_rooms_for_user_where_membership_is = (
+ DataStore.get_rooms_for_user_where_membership_is.__func__
+ )
+ get_membership_changes_for_user = (
+ DataStore.get_membership_changes_for_user.__func__
+ )
+ get_room_events_max_id = DataStore.get_room_events_max_id.__func__
+ get_room_events_stream_for_room = (
+ DataStore.get_room_events_stream_for_room.__func__
+ )
+ _set_before_and_after = DataStore._set_before_and_after
+
+ _get_events = DataStore._get_events.__func__
+ _get_events_from_cache = DataStore._get_events_from_cache.__func__
+
+ _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
+ _parse_events_txn = DataStore._parse_events_txn.__func__
+ _get_events_txn = DataStore._get_events_txn.__func__
+ _enqueue_events = DataStore._enqueue_events.__func__
+ _do_fetch = DataStore._do_fetch.__func__
+ _fetch_events_txn = DataStore._fetch_events_txn.__func__
+ _fetch_event_rows = DataStore._fetch_event_rows.__func__
+ _get_event_from_row = DataStore._get_event_from_row.__func__
+ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
+ _get_rooms_for_user_where_membership_is_txn = (
+ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
+ )
+ _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
+
+ def stream_positions(self):
+ result = super(SlavedEventStore, self).stream_positions()
+ result["events"] = self._stream_id_gen.get_current_token()
+ result["backfilled"] = self._backfill_id_gen.get_current_token()
+ return result
+
+ def process_replication(self, result):
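+        # stream positions at which a room's "current state" was reset;
+        # rows at those positions need their state caches invalidated
+        # wholesale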
+ state_resets = set(
+ r[0] for r in result.get("state_resets", {"rows": []})["rows"]
+ )
+
+ stream = result.get("events")
+ if stream:
+ self._stream_id_gen.advance(stream["position"])
+ for row in stream["rows"]:
+ self._process_replication_row(
+ row, backfilled=False, state_resets=state_resets
+ )
+
+ stream = result.get("backfill")
+ if stream:
+ self._backfill_id_gen.advance(stream["position"])
+ for row in stream["rows"]:
+ self._process_replication_row(
+ row, backfilled=True, state_resets=state_resets
+ )
+
+ stream = result.get("forward_ex_outliers")
+ if stream:
+ for row in stream["rows"]:
+ event_id = row[1]
+ self._invalidate_get_event_cache(event_id)
+
+ stream = result.get("backward_ex_outliers")
+ if stream:
+ for row in stream["rows"]:
+ event_id = row[1]
+ self._invalidate_get_event_cache(event_id)
+
+ return super(SlavedEventStore, self).process_replication(result)
+
+ def _process_replication_row(self, row, backfilled, state_resets):
+ position = row[0]
+ internal = json.loads(row[1])
+ event_json = json.loads(row[2])
+
+ event = FrozenEvent(event_json, internal_metadata_dict=internal)
+ self._invalidate_caches_for_event(
+ event, backfilled, reset_state=position in state_resets
+ )
+
+ def _invalidate_caches_for_event(self, event, backfilled, reset_state):
+ if reset_state:
+ self._get_current_state_for_key.invalidate_all()
+ self.get_rooms_for_user.invalidate_all()
+ self.get_users_in_room.invalidate((event.room_id,))
+ # self.get_joined_hosts_for_room.invalidate((event.room_id,))
+ self.get_room_name_and_aliases.invalidate((event.room_id,))
+
+ self._invalidate_get_event_cache(event.event_id)
+
+ self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+
+ if not backfilled:
+ self._events_stream_cache.entity_has_changed(
+ event.room_id, event.internal_metadata.stream_ordering
+ )
+
+ # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
+ # (event.room_id,)
+ # )
+
+ if event.type == EventTypes.Redaction:
+ self._invalidate_get_event_cache(event.redacts)
+
+ if event.type == EventTypes.Member:
+ self.get_rooms_for_user.invalidate((event.state_key,))
+ # self.get_joined_hosts_for_room.invalidate((event.room_id,))
+ self.get_users_in_room.invalidate((event.room_id,))
+ # self._membership_stream_cache.entity_has_changed(
+ # event.state_key, event.internal_metadata.stream_ordering
+ # )
+
+ if not event.is_state():
+ return
+
+ if backfilled:
+ return
+
+ if (not event.internal_metadata.is_invite_from_remote()
+ and event.internal_metadata.is_outlier()):
+ return
+
+ self._get_current_state_for_key.invalidate((
+ event.room_id, event.type, event.state_key
+ ))
+
+ if event.type in [EventTypes.Name, EventTypes.Aliases]:
+ self.get_room_name_and_aliases.invalidate(
+ (event.room_id,)
+ )
diff --git a/synapse/state.py b/synapse/state.py
index 1bca0f8f78..58211f5feb 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -75,7 +75,8 @@ class StateHandler(object):
self._state_cache.start()
@defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
+ def get_current_state(self, room_id, event_type=None, state_key="",
+ latest_event_ids=None):
""" Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@@ -89,9 +90,10 @@ class StateHandler(object):
Returns:
map from (type, state_key) to event
"""
- event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- res = yield self.resolve_state_groups(room_id, event_ids)
+ res = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1]
if event_type:
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 57863bba4d..045ae6c03f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -94,7 +94,8 @@ class DataStore(RoomMemberStore, RoomStore,
)
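+        # the "events" and "local_invites" tables share a single stream id
+        # space: locally_reject_invite allocates stream_ids for local_invites
+        # from this generator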
self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering"
+ db_conn, "events", "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
@@ -176,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
self.__presence_on_startup = None
return active_on_startup
- def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - 100000"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- }
-
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (int(max_value),))
- rows = txn.fetchall()
- txn.close()
-
- cache = {
- row[0]: int(row[1])
- for row in rows
- }
-
- if cache:
- min_val = min(cache.values())
- else:
- min_val = max_value
-
- return cache, min_val
-
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b75b79df36..04d7fcf6d6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -816,6 +816,40 @@ class SQLBaseStore(object):
self._next_stream_id += 1
return i
+ def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
+ max_value):
+ # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+ # It doesn't really matter how many we get, the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - 100000"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ }
+
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+ rows = txn.fetchall()
+ txn.close()
+
+ cache = {
+ row[0]: int(row[1])
+ for row in rows
+ }
+
+ if cache:
+ min_val = min(cache.values())
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index a48230b93f..7bb5de1fe7 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
}
-def create_engine(config):
- name = config.database_config["name"]
+def create_engine(database_config):
+ name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
module = importlib.import_module(name)
- return engine_class(module, config=config)
+ return engine_class(module)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a09685b4df..c2290943b4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import prepare_database
-
from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
- def __init__(self, database_module, config):
+ def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
- self.config = config
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@@ -44,9 +41,6 @@ class PostgresEngine(object):
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
- def prepare_database(self, db_conn):
- prepare_database(db_conn, self, config=self.config)
-
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 522b905949..14203aa500 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import (
- prepare_database, prepare_sqlite3_database
-)
+from synapse.storage.prepare_database import prepare_database
import struct
@@ -23,9 +21,8 @@ import struct
class Sqlite3Engine(object):
single_threaded = True
- def __init__(self, database_module, config):
+ def __init__(self, database_module):
self.module = database_module
- self.config = config
def check_database(self, txn):
pass
@@ -34,13 +31,9 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
- self.prepare_database(db_conn)
+ prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
- def prepare_database(self, db_conn):
- prepare_sqlite3_database(db_conn)
- prepare_database(db_conn, self, config=self.config)
-
def is_deadlock(self, error):
return False
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3489315e0d..0827946207 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
+ @defer.inlineCallbacks
+ def get_max_depth_of_events(self, event_ids):
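+        """Returns the maximum depth of the given events, or 1 if none of
+        them are known."""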
+ sql = (
+ "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
+ ) % (",".join(["?"] * len(event_ids)),)
+
+ rows = yield self._execute(
+ "get_max_depth_of_events", None,
+ sql, *event_ids
+ )
+
+        if rows and rows[0][0] is not None:
+ defer.returnValue(rows[0][0])
+ else:
+ defer.returnValue(1)
+
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c4dc3b3d51..ee87a71719 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -367,7 +367,8 @@ class EventsStore(SQLBaseStore):
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
- ]
+ ],
+ backfilled=backfilled,
)
def event_dict(event):
@@ -485,14 +486,8 @@ class EventsStore(SQLBaseStore):
return
for event, _ in state_events_and_contexts:
- if (not event.internal_metadata.is_invite_from_remote()
- and event.internal_metadata.is_outlier()):
- # Outlier events generally shouldn't clobber the current state.
- # However invites from remote severs for rooms we aren't in
- # are a bit special: they don't come with any associated
- # state so are technically an outlier, however all the
- # client-facing code assumes that they are in the current
- # state table so we insert the event anyway.
+ if event.internal_metadata.is_outlier():
+ # Outlier events shouldn't clobber the current state.
continue
if context.rejected:
@@ -1139,7 +1134,7 @@ class EventsStore(SQLBaseStore):
upper_bound = current_forward_id
sql = (
- "SELECT -event_stream_ordering FROM current_state_resets"
+ "SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
@@ -1148,7 +1143,7 @@ class EventsStore(SQLBaseStore):
state_resets = txn.fetchall()
sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
+ "SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3f29aad1e8..00833422af 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 30
+SCHEMA_VERSION = 31
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
+
+ If `config` is None then prepare_database will assert that no upgrade is
+ necessary, *or* will create a fresh database if the database is empty.
"""
try:
cur = db_conn.cursor()
@@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info:
user_version, delta_files, upgraded = version_info
- _upgrade_existing_database(
- cur, user_version, delta_files, upgraded, database_engine, config
- )
- else:
- _setup_new_database(cur, database_engine, config)
- # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+ if config is None:
+ if user_version != SCHEMA_VERSION:
+ # If we don't pass in a config file then we are expecting to
+ # have already upgraded the DB.
+ raise UpgradeDatabaseException("Database needs to be upgraded")
+ else:
+ _upgrade_existing_database(
+ cur, user_version, delta_files, upgraded, database_engine, config
+ )
+ else:
+ _setup_new_database(cur, database_engine)
cur.close()
db_conn.commit()
@@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config):
raise
-def _setup_new_database(cur, database_engine, config):
+def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
- config=config,
+ config=None,
+ is_empty=True,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
- upgraded, database_engine, config):
+ upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
- module.run_upgrade(cur, database_engine, config=config)
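+                # run_create always runs; run_upgrade (the data-migration
+                # step) is skipped when we are populating an empty database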
+ module.run_create(cur, database_engine)
+ if not is_empty:
+ module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
-
-
-def prepare_sqlite3_database(db_conn):
- """This function should be called before `prepare_database` on sqlite3
- databases.
-
- Since we changed the way we store the current schema version and handle
- updates to schemas, we need a way to upgrade from the old method to the
- new. This only affects sqlite databases since they were the only ones
- supported at the time.
- """
- with db_conn:
- schema_path = os.path.join(
- dir_path, "schema", "schema_version.sql",
- )
- create_schema = read_schema(schema_path)
- db_conn.executescript(create_schema)
-
- c = db_conn.execute("SELECT * FROM schema_version")
- rows = c.fetchall()
- c.close()
-
- if not rows:
- c = db_conn.execute("PRAGMA user_version")
- row = c.fetchone()
- c.close()
-
- if row and row[0]:
- db_conn.execute(
- "REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
- (row[0], False)
- )
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 4befebc8e2..7fdd84bbdc 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
- @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
- num_args=3, inlineCallbacks=True)
+ @cachedList(cached_method_name="get_linearized_receipts_for_room",
+ list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index d46a963bb8..1f71773aaa 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
- @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+ @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 430b49c12e..66e7a40e3c 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def _store_room_members_txn(self, txn, events):
+ def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
@@ -62,6 +62,64 @@ class RoomMemberStore(SQLBaseStore):
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+        # i.e., it's something that has just happened.
+        # The only current event that can also be an outlier is an invite
+        # that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -127,6 +185,24 @@ class RoomMemberStore(SQLBaseStore):
user_id, [Membership.INVITE]
)
+ @defer.inlineCallbacks
+ def get_invite_for_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given user and room
+
+ Args:
+ user_id (str)
+ room_id (str)
+
+ Returns:
+ Deferred: Resolves to either a RoomsForUser or None if no invite was
+ found.
+ """
+ invites = yield self.get_invited_rooms_for_user(user_id)
+ for invite in invites:
+ if invite.room_id == room_id:
+ defer.returnValue(invite)
+ defer.returnValue(None)
+
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
Args:
@@ -163,29 +239,55 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
- where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
- )
- args = [user_id]
- args.extend(membership_list)
+ do_invite = Membership.INVITE in membership_list
+ membership_list = [m for m in membership_list if m != Membership.INVITE]
- sql = (
- "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
- " FROM current_state_events as c"
- " INNER JOIN room_memberships as m"
- " ON m.event_id = c.event_id"
- " INNER JOIN events as e"
- " ON e.event_id = c.event_id"
- " AND m.room_id = c.room_id"
- " AND m.user_id = c.state_key"
- " WHERE %s"
- ) % (where_clause,)
+ results = []
+ if membership_list:
+ where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
+ " OR ".join(["membership = ?" for _ in membership_list]),
+ )
+
+ args = [user_id]
+ args.extend(membership_list)
+
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
+ " FROM current_state_events as c"
+ " INNER JOIN room_memberships as m"
+ " ON m.event_id = c.event_id"
+ " INNER JOIN events as e"
+ " ON e.event_id = c.event_id"
+ " AND m.room_id = c.room_id"
+ " AND m.user_id = c.state_key"
+ " WHERE %s"
+ ) % (where_clause,)
+
+ txn.execute(sql, args)
+ results = [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
+
+ if do_invite:
+ sql = (
+ "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
+ " FROM local_invites as i"
+ " INNER JOIN events as e USING (event_id)"
+ " WHERE invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (user_id,))
+ results.extend(RoomsForUser(
+ room_id=r["room_id"],
+ sender=r["inviter"],
+ event_id=r["event_id"],
+ stream_ordering=r["stream_ordering"],
+ membership=Membership.INVITE,
+ ) for r in self.cursor_to_dict(txn))
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ return results
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 5c40a77757..8755bb2e49 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
-def run_upgrade(cur, *args, **kwargs):
+def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
@@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
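
The next four delta scripts get the same treatment as this one: the schema work moves into a new run_create hook and run_upgrade becomes a stub, except for delta/30/as_users.py, which keeps a real run_upgrade for its data pass. The apparent convention is that run_create always runs, while run_upgrade (which also receives the homeserver config) only runs when there is existing data to migrate. A minimal sketch of a delta module following that convention (the table and logic are illustrative only):

    def run_create(cur, database_engine, *args, **kwargs):
        # Schema changes go here: create or alter tables.
        cur.execute("CREATE TABLE IF NOT EXISTS example_table (id TEXT NOT NULL)")


    def run_upgrade(cur, database_engine, *args, **kwargs):
        # Data migration for pre-existing databases; deltas with nothing
        # to migrate reduce this to a bare `pass`, as above.
        cur.execute("SELECT id FROM example_table")
        for (row_id,) in cur.fetchall():
            pass  # transform existing rows here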
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
index 29164732af..147496a38b 100644
--- a/synapse/storage/schema/delta/20/pushers.py
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__)
-def run_upgrade(cur, database_engine, *args, **kwargs):
+def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
@@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index d3ff2b1779..4269ac69ad 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -43,7 +43,7 @@ SQLITE_TABLE = (
)
-def run_upgrade(cur, database_engine, *args, **kwargs):
+def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
@@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index f8c16391a2..71b12a2731 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -27,7 +27,7 @@ ALTER_TABLE = (
)
-def run_upgrade(cur, database_engine, *args, **kwargs):
+def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
@@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 4f6e9dd540..b417e3ac08 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__)
-def run_upgrade(cur, database_engine, config, *args, **kwargs):
+def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so...
pass
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users")
rows = cur.fetchall()
diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql
new file mode 100644
index 0000000000..2c57846d5a
--- /dev/null
+++ b/synapse/storage/schema/delta/31/invites.sql
@@ -0,0 +1,42 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE local_invites(
+ stream_id BIGINT NOT NULL,
+ inviter TEXT NOT NULL,
+ invitee TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ locally_rejected TEXT,
+ replaced_by TEXT
+);
+
+-- Insert all invites for local users into the new `local_invites` table
+INSERT INTO local_invites SELECT
+ stream_ordering as stream_id,
+ sender as inviter,
+ state_key as invitee,
+ event_id,
+ room_id,
+ NULL as locally_rejected,
+ NULL as replaced_by
+ FROM events
+ NATURAL JOIN current_state_events
+ NATURAL JOIN room_memberships
+ WHERE membership = 'invite' AND state_key IN (SELECT name FROM users);
+
+CREATE INDEX local_invites_id ON local_invites(stream_id);
+CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id);
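
In this schema the locally_rejected and replaced_by columns act as tombstones: an invite is pending only while both are still NULL, which is exactly the shape local_invites_for_user_idx covers. A hedged sketch of the pending-invite lookup against this table (given a sqlite3 connection `conn`; the real query lives in roommember.py above):

    def pending_invites(conn, invitee):
        # A pending invite is a row whose tombstone columns are both NULL.
        cur = conn.execute(
            "SELECT room_id, inviter, event_id, stream_id"
            " FROM local_invites"
            " WHERE invitee = ? AND locally_rejected IS NULL"
            " AND replaced_by IS NULL",
            (invitee,),
        )
        return cur.fetchall()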
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e9f9406014..c5d2a3a6df 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -273,8 +273,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event",
)
- @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
- num_args=1, inlineCallbacks=True)
+ @cachedList(cached_method_name="_get_state_group_for_event",
+ list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
diff --git a/synapse/util/async.py b/synapse/util/async.py
index cd4d90f3cf..0d6f48e2d8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,9 +16,13 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext, preserve_fn
+from .logcontext import (
+ PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
from synapse.util import unwrapFirstError
+from contextlib import contextmanager
+
@defer.inlineCallbacks
def sleep(seconds):
@@ -137,3 +141,47 @@ def concurrently_execute(func, args, limit):
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError)
+
+
+class Linearizer(object):
+ """Linearizes access to resources based on a key. Useful to ensure only one
+ thing is happening at a time on a given resource.
+
+ Example:
+
+ with (yield linearizer.queue("test_key")):
+ # do some work.
+
+ """
+ def __init__(self):
+ self.key_to_defer = {}
+
+ @defer.inlineCallbacks
+ def queue(self, key):
+ # If there is already a deferred in the queue, we pull it out so that
+ # we can wait on it later.
+ # Then we replace it with a deferred that we resolve *after* the
+ # context manager has exited.
+ # We only return the context manager after the previous deferred has
+ # resolved.
+ # This all has the net effect of creating a chain of deferreds that
+ # wait for the previous deferred before starting their work.
+ current_defer = self.key_to_defer.get(key)
+
+ new_defer = defer.Deferred()
+ self.key_to_defer[key] = new_defer
+
+ if current_defer:
+ yield preserve_context_over_deferred(current_defer)
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ current_d = self.key_to_defer.get(key)
+ if current_d is new_defer:
+ self.key_to_defer.pop(key, None)
+
+ defer.returnValue(_ctx_manager())
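
The net effect of the deferred chain above is that callers queueing on the same key run strictly one at a time, and the try/finally in _ctx_manager releases the next caller even if the body raises. Typical usage might look like the following sketch (a hypothetical handler; do_work is illustrative):

    from twisted.internet import defer

    from synapse.util.async import Linearizer

    linearizer = Linearizer()

    @defer.inlineCallbacks
    def handle_room_event(room_id):
        # Serialise on the room ID: at most one block runs per room,
        # while work on different rooms proceeds concurrently.
        with (yield linearizer.queue(room_id)):
            yield do_work(room_id)  # illustrative helper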
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..758f5982b0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -167,7 +167,8 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
- self.cache = Cache(
+ def __get__(self, obj, objtype=None):
+ cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
@@ -175,14 +176,12 @@ class CacheDescriptor(object):
tree=self.tree,
)
- def __get__(self, obj, objtype=None):
-
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result_d = self.cache.get(cache_key)
+ cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -204,7 +203,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = self.cache.sequence
+ sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
@@ -213,20 +212,21 @@ class CacheDescriptor(object):
)
def onErr(f):
- self.cache.invalidate(cache_key)
+ cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- self.cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe())
- wrapped.invalidate = self.cache.invalidate
- wrapped.invalidate_all = self.cache.invalidate_all
- wrapped.invalidate_many = self.cache.invalidate_many
- wrapped.prefill = self.cache.prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
+ wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
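
Constructing the Cache inside __get__ and then writing `wrapped` into obj.__dict__ makes this a lazily-initialised, per-instance cache: the (non-data) descriptor fires once per object, after which ordinary attribute lookup finds the instance's own copy. The same pattern in isolation (illustrative names; a plain dict stands in for synapse's Cache):

    import functools

    class per_instance_cached(object):
        def __init__(self, orig):
            self.orig = orig

        def __get__(self, obj, objtype=None):
            cache = {}  # stands in for the Cache built above

            @functools.wraps(self.orig)
            def wrapped(*args):
                if args not in cache:
                    cache[args] = self.orig(obj, *args)
                return cache[args]

            wrapped.cache = cache  # exposed so batch lookups can find it
            # Shadow the descriptor: later lookups hit obj.__dict__ directly.
            obj.__dict__[self.orig.__name__] = wrapped
            return wrapped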
@@ -240,11 +240,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped function.
"""
- def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+ def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ inlineCallbacks=False):
"""
Args:
orig (function)
- cache (Cache)
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
@@ -263,7 +264,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
- self.cache = cache
+ self.cached_method_name = cached_method_name
self.sentinel = object()
@@ -277,11 +278,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cache.name,)
+ % (self.list_name, cached_method_name,)
)
def __get__(self, obj, objtype=None):
+ cache = getattr(obj, self.cached_method_name).cache
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@@ -297,14 +300,14 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
- res = self.cache.get(tuple(key)).observe()
+ res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
- sequence = self.cache.sequence
+ sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@@ -327,10 +330,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- self.cache.update(sequence, tuple(key), observer)
+ cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
- self.cache.invalidate(key)
+ cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
@@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
)
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""
return lambda orig: CacheListDescriptor(
orig,
- cache=cache,
+ cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
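
Because the cache now only exists per instance, cachedList can no longer be handed a cache object at class-definition time; instead it records the name of the cached method and resolves getattr(obj, name).cache late, inside __get__. Continuing the per_instance_cached sketch above (illustrative only):

    class batch_lookup(object):
        def __init__(self, orig, cached_method_name):
            self.orig = orig
            self.cached_method_name = cached_method_name

        def __get__(self, obj, objtype=None):
            # Late binding: accessing the cached method here triggers its
            # own __get__, creating the per-instance cache on first use.
            cache = getattr(obj, self.cached_method_name).cache

            def wrapped(keys):
                results = {k: cache[(k,)] for k in keys if (k,) in cache}
                missing = [k for k in keys if (k,) not in cache]
                if missing:
                    fetched = self.orig(obj, missing)  # batch fetch, dict result
                    for k, v in fetched.items():
                        cache[(k,)] = v
                    results.update(fetched)
                return results

            return wrapped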
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index be310ba320..36686b479e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,7 +35,7 @@ class ResponseCache(object):
return None
def set(self, key, deferred):
- result = ObservableDeferred(deferred)
+ result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
def remove(r):
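
With several requests observing one cached response, consumeErrors=True delivers a failure to every observer without also leaving an unconsumed failure on the original deferred, which twisted would otherwise log as an unhandled error when it is garbage-collected. A small sketch of the behaviour being relied on (assuming the ObservableDeferred API imported by this module):

    from twisted.internet import defer

    from synapse.util.async import ObservableDeferred

    d = defer.Deferred()
    result = ObservableDeferred(d, consumeErrors=True)

    # Each observer still sees (and can handle) the failure...
    result.observe().addErrback(lambda f: "handled: %s" % (f.value,))

    # ...but the failure on `d` itself is consumed rather than logged
    # as "Unhandled error in Deferred" at GC time.
    d.errback(Exception("request failed"))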
diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/replication/slave/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/replication/slave/storage/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
new file mode 100644
index 0000000000..0f525a8943
--- /dev/null
+++ b/tests/replication/slave/storage/_base.py
@@ -0,0 +1,57 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from tests import unittest
+
+from synapse.replication.slave.storage.events import SlavedEventStore
+
+from mock import Mock, NonCallableMock
+from tests.utils import setup_test_homeserver
+from synapse.replication.resource import ReplicationResource
+
+
+class BaseSlavedStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(
+ "blue",
+ http_client=None,
+ replication_layer=Mock(),
+ ratelimiter=NonCallableMock(spec_set=[
+ "send_message",
+ ]),
+ )
+ self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ self.replication = ReplicationResource(self.hs)
+
+ self.master_store = self.hs.get_datastore()
+ self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
+ self.event_id = 0
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ streams = self.slaved_store.stream_positions()
+ result = yield self.replication.replicate(streams, 100)
+ yield self.slaved_store.process_replication(result)
+
+ @defer.inlineCallbacks
+ def check(self, method, args, expected_result=None):
+ master_result = yield getattr(self.master_store, method)(*args)
+ slaved_result = yield getattr(self.slaved_store, method)(*args)
+ self.assertEqual(master_result, slaved_result)
+ if expected_result is not None:
+ self.assertEqual(master_result, expected_result)
+ self.assertEqual(slaved_result, expected_result)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
new file mode 100644
index 0000000000..9af62702b3
--- /dev/null
+++ b/tests/replication/slave/storage/test_events.py
@@ -0,0 +1,258 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.events import FrozenEvent, _EventInternalMetadata
+from synapse.events.snapshot import EventContext
+from synapse.storage.roommember import RoomsForUser
+
+from twisted.internet import defer
+
+
+USER_ID = "@feeling:blue"
+USER_ID_2 = "@bright:blue"
+OUTLIER = {"outlier": True}
+ROOM_ID = "!room:blue"
+
+
+def dict_equals(self, other):
+ return self.__dict__ == other.__dict__
+
+
+def patch__eq__(cls):
+ eq = getattr(cls, "__eq__", None)
+ cls.__eq__ = dict_equals
+
+ def unpatch():
+ if eq is not None:
+ cls.__eq__ = eq
+ return unpatch
+
+
+class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
+
+ def setUp(self):
+ # Patch up the equality operator for events so that we can check
+ # whether lists of events match using assertEquals
+ self.unpatches = [
+ patch__eq__(_EventInternalMetadata),
+ patch__eq__(FrozenEvent),
+ ]
+ return super(SlavedEventStoreTestCase, self).setUp()
+
+ def tearDown(self):
+ [unpatch() for unpatch in self.unpatches]
+
+ @defer.inlineCallbacks
+ def test_room_name_and_aliases(self):
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ yield self.persist(type="m.room.name", key="", name="name1")
+ yield self.persist(
+ type="m.room.aliases", key="blue", aliases=["#1:blue"]
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
+ )
+
+ # Set the room name.
+ yield self.persist(type="m.room.name", key="", name="name2")
+ yield self.replicate()
+ yield self.check(
+ "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
+ )
+
+ # Set the room aliases.
+ yield self.persist(
+ type="m.room.aliases", key="blue", aliases=["#2:blue"]
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
+ )
+
+ # Leave and join the room clobbering the state.
+ yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+ yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ reset_state=[create]
+ )
+ yield self.replicate()
+
+ yield self.check(
+ "get_room_name_and_aliases", (ROOM_ID,), (None, [])
+ )
+
+ @defer.inlineCallbacks
+ def test_room_members(self):
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [])
+ yield self.check("get_users_in_room", (ROOM_ID,), [])
+
+ # Join the room.
+ join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
+ room_id=ROOM_ID,
+ sender=USER_ID,
+ membership="join",
+ event_id=join.event_id,
+ stream_ordering=join.internal_metadata.stream_ordering,
+ )])
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
+
+ # Leave the room.
+ yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [])
+ yield self.check("get_users_in_room", (ROOM_ID,), [])
+
+ # Add some other user to the room.
+ join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
+ room_id=ROOM_ID,
+ sender=USER_ID,
+ membership="join",
+ event_id=join.event_id,
+ stream_ordering=join.internal_metadata.stream_ordering,
+ )])
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
+
+ # Join the room clobbering the state.
+ # This should remove any evidence of the other user being in the room.
+ yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ reset_state=[create]
+ )
+ yield self.replicate()
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
+ yield self.check("get_rooms_for_user", (USER_ID_2,), [])
+
+ @defer.inlineCallbacks
+ def test_get_latest_event_ids_in_room(self):
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check(
+ "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
+ )
+
+ join = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ prev_events=[(create.event_id, {})],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
+ )
+
+ @defer.inlineCallbacks
+ def test_get_current_state(self):
+ # Create the room.
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
+ )
+
+ # Join the room.
+ join1 = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+ [join1]
+ )
+
+ # Add some other user to the room.
+ join2 = yield self.persist(
+ type="m.room.member", key=USER_ID_2, membership="join",
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+ [join2]
+ )
+
+ # Leave the room, then rejoin the room clobbering state.
+ yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+ join3 = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ reset_state=[create]
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+ []
+ )
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+ [join3]
+ )
+
+ event_id = 0
+
+ @defer.inlineCallbacks
+ def persist(
+ self, sender=USER_ID, room_id=ROOM_ID, type="m.room.message", key=None, internal={},
+ state=None, reset_state=False, backfill=False,
+ depth=None, prev_events=[], auth_events=[], prev_state=[],
+ **content
+ ):
+ """
+ Returns:
+ synapse.events.FrozenEvent: The event that was persisted.
+ """
+ if depth is None:
+ depth = self.event_id
+
+ event_dict = {
+ "sender": sender,
+ "type": type,
+ "content": content,
+ "event_id": "$%d:blue" % (self.event_id,),
+ "room_id": room_id,
+ "depth": depth,
+ "origin_server_ts": self.event_id,
+ "prev_events": prev_events,
+ "auth_events": auth_events,
+ }
+ if key is not None:
+ event_dict["state_key"] = key
+ event_dict["prev_state"] = prev_state
+
+ event = FrozenEvent(event_dict, internal_metadata_dict=internal)
+
+ self.event_id += 1
+
+ context = EventContext(current_state=state)
+
+ ordering = None
+ if backfill:
+ yield self.master_store.persist_events(
+ [(event, context)], backfilled=True
+ )
+ else:
+ ordering, _ = yield self.master_store.persist_event(
+ event, context, current_state=reset_state
+ )
+
+ if ordering:
+ event.internal_metadata.stream_ordering = ordering
+
+ defer.returnValue(event)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4ab8b35e6b..8853cbb5fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
- yield self.join(room=room, user=usr, expect_code=404)
- yield self.leave(room=room, user=usr, expect_code=404)
+ yield self.join(room=room, user=usr, expect_code=403)
+ yield self.leave(room=room, user=usr, expect_code=403)
@defer.inlineCallbacks
def test_membership_private_room_perms(self):
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 2e33beb07c..afbefb2e2d 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test",
db_pool=self.db_pool,
config=config,
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
)
self.datastore = SQLBaseStore(hs)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
new file mode 100644
index 0000000000..afcba482f9
--- /dev/null
+++ b/tests/util/test_linearizer.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from twisted.internet import defer
+
+from synapse.util.async import Linearizer
+
+
+class LinearizerTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_linearizer(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ self.assertTrue(d2.called)
+
+ with (yield d2):
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index 52405502e9..c179df31ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn,
**kargs
)
@@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
**kargs
)
@@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn
def create_engine(self):
- return create_engine(self.config)
+ return create_engine(self.config.database_config)
class MemoryDataStore(object):
|