diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/__init__.py | 2 | ||||
-rw-r--r-- | synapse/federation/units.py | 1 | ||||
-rw-r--r-- | synapse/handlers/room.py | 6 | ||||
-rw-r--r-- | synapse/state.py | 149 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 40 | ||||
-rw-r--r-- | synapse/storage/_base.py | 8 | ||||
-rw-r--r-- | synapse/storage/pdu.py | 92 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 57 |
8 files changed, 192 insertions, 163 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 1ed9cdcdf3..d60267ebe4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a synapse home server. """ -__version__ = "0.2.2" +__version__ = "0.2.3" diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 9740431279..622fe66a8f 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject): "prev_state_id", "prev_state_origin", "required_power_level", + "user_id", ] internal_keys = [ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a0d0f2af16..310cb46fe7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -593,6 +593,12 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def get_public_room_list(self): chunk = yield self.store.get_rooms(is_public=True) + for room in chunk: + joined_members = yield self.store.get_room_members( + room_id=room["room_id"], + membership=Membership.JOIN + ) + room["num_joined_members"] = len(joined_members) # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) diff --git a/synapse/state.py b/synapse/state.py index 5dcff27367..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -115,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -134,7 +136,9 @@ class StateHandler(object): @defer.inlineCallbacks @log_function def _handle_new_state(self, new_pdu): - tree = yield self.store.get_unresolved_state_tree(new_pdu) + tree, missing_branch = yield self.store.get_unresolved_state_tree( + new_pdu + ) new_branch, current_branch = tree logger.debug( @@ -142,6 +146,28 @@ class StateHandler(object): new_branch, current_branch ) + if missing_branch is not None: + # We're missing some PDUs. Fetch them. + # TODO (erikj): Limit this. + missing_prev = tree[missing_branch][-1] + + pdu_id = missing_prev.prev_state_id + origin = missing_prev.prev_state_origin + + is_missing = yield self.store.get_pdu(pdu_id, origin) is None + if not is_missing: + raise Exception("Conflict resolution failed") + + yield self._replication.get_pdu( + destination=missing_prev.origin, + pdu_origin=origin, + pdu_id=pdu_id, + outlier=True + ) + + updated_current = yield self._handle_new_state(new_pdu) + defer.returnValue(updated_current) + if not current_branch: # There is no current state defer.returnValue(True) @@ -150,84 +176,85 @@ class StateHandler(object): n = new_branch[-1] c = current_branch[-1] - if n.pdu_id == c.pdu_id and n.origin == c.origin: - # We have all the PDUs we need, so we can just do the conflict - # resolution. + common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + + if common_ancestor: + # We found a common ancestor! if len(current_branch) == 1: # This is a direct clobber so we can just... defer.returnValue(True) - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - else: - # We need to ask for PDUs. - missing_prev = max( - new_branch[-1], current_branch[-1], - key=lambda x: x.depth - ) + # We didn't find a common ancestor. This is probably fine. + pass - if not hasattr(missing_prev, "prev_state_id"): - # FIXME Hmm - # temporary fallback - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) + result = yield self._do_conflict_res( + new_branch, current_branch, common_ancestor + ) + defer.returnValue(result) - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - return + @defer.inlineCallbacks + def _do_conflict_res(self, new_branch, current_branch, common_ancestor): + conflict_res = [ + self._do_power_level_conflict_res, + self._do_chain_length_conflict_res, + self._do_hash_conflict_res, + ] - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin + for algo in conflict_res: + new_res, curr_res = yield defer.maybeDeferred( + algo, + new_branch, current_branch, common_ancestor + ) - is_missing = yield self.store.get_pdu(pdu_id, origin) is None + if new_res < curr_res: + defer.returnValue(False) + elif new_res > curr_res: + defer.returnValue(True) - if not is_missing: - raise Exception("Conflict resolution failed.") + raise Exception("Conflict resolution failed.") - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) - - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + @defer.inlineCallbacks + def _do_power_level_conflict_res(self, new_branch, current_branch, + common_ancestor): + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - def _do_power_level_conflict_res(self, new_branch, current_branch): - max_power_new = max( - new_branch[:-1], - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1], - key=lambda t: t.power_level - ).power_level + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) - def _do_chain_length_conflict_res(self, new_branch, current_branch): + def _do_chain_length_conflict_res(self, new_branch, current_branch, + common_ancestor): return (len(new_branch), len(current_branch)) - def _do_hash_conflict_res(self, new_branch, current_branch): + def _do_hash_conflict_res(self, new_branch, current_branch, + common_ancestor): new_str = "".join([p.pdu_id + p.origin for p in new_branch]) c_str = "".join([p.pdu_id + p.origin for p in current_branch]) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 81c3c94b2e..9201a377b6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,7 +36,7 @@ from .registration import RegistrationStore from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore +from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore @@ -47,6 +47,11 @@ import os logger = logging.getLogger(__name__) +class _RollbackButIsFineException(Exception): + """ This exception is used to rollback a transaction without implying + something went wrong. + """ + pass class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, @@ -71,14 +76,16 @@ class DataStore(RoomMemberStore, RoomStore, self.min_token -= 1 stream_ordering = self.min_token - latest = yield self._db_pool.runInteraction( - self._persist_pdu_event_txn, - pdu=pdu, - event=event, - backfilled=backfilled, - stream_ordering=stream_ordering, - ) - defer.returnValue(latest) + try: + yield self._db_pool.runInteraction( + self._persist_pdu_event_txn, + pdu=pdu, + event=event, + backfilled=backfilled, + stream_ordering=stream_ordering, + ) + except _RollbackButIsFineException as e: + pass @defer.inlineCallbacks def get_event(self, event_id, allow_none=False): @@ -116,6 +123,12 @@ class DataStore(RoomMemberStore, RoomStore, del cols["content"] del cols["prev_pdus"] cols["content_json"] = json.dumps(pdu.content) + + unrec_keys.update({ + k: v for k, v in cols.items() + if k not in PdusTable.fields + }) + cols["unrecognized_keys"] = json.dumps(unrec_keys) logger.debug("Persisting: %s", repr(cols)) @@ -175,11 +188,12 @@ class DataStore(RoomMemberStore, RoomStore, try: self._simple_insert_txn(txn, "events", vals) except: - logger.exception( + logger.warn( "Failed to persist, probably duplicate: %s", - event.event_id + event.event_id, + exc_info=True, ) - return + raise _RollbackButIsFineException("_persist_event") if not backfilled and hasattr(event, "state_key"): vals = { @@ -205,8 +219,6 @@ class DataStore(RoomMemberStore, RoomStore, } ) - return self._get_room_events_max_id_txn(txn) - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): sql = ( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8037225079..8deaaf93bd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.logutils import log_function import collections import copy @@ -91,6 +92,7 @@ class SQLBaseStore(object): self._simple_insert_txn, table, values, or_replace=or_replace ) + @log_function def _simple_insert_txn(self, txn, table, values, or_replace=False): sql = "%s INTO %s (%s) VALUES(%s)" % ( ("INSERT OR REPLACE" if or_replace else "INSERT"), @@ -98,6 +100,12 @@ class SQLBaseStore(object): ", ".join(k for k in values), ", ".join("?" for k in values) ) + + logger.debug( + "[SQL] %s Args=%s Func=%s", + sql, values.values(), + ) + txn.execute(sql, values.values()) return txn.lastrowid diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 0bf97e37ee..3c859fdeac 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -17,6 +17,7 @@ from twisted.internet import defer from ._base import SQLBaseStore, Table, JoinHelper +from synapse.federation.units import Pdu from synapse.util.logutils import log_function from collections import namedtuple @@ -308,8 +309,8 @@ class PduStore(SQLBaseStore): @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and haven't - seen). This list is used when we want to backfill backwards and is the + """Get a list of Pdus that we haven't backfilled beyond yet (and havent + seen). This list is used when we want to backfill backwards and is the list we send to the remote server. Args: @@ -516,7 +517,7 @@ class StatePduStore(SQLBaseStore): if not current: logger.debug("get_unresolved_state_tree No current state.") - return return_value + return (return_value, None) return_value.current_branch.append(current) @@ -524,13 +525,16 @@ class StatePduStore(SQLBaseStore): txn, new_pdu, current ) + missing_branch = None for branch, prev_state, state in enum_branches: if state: return_value[branch].append(state) else: + # We don't have prev_state :( + missing_branch = branch break - return return_value + return (return_value, missing_branch) def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): @@ -622,53 +626,6 @@ class StatePduStore(SQLBaseStore): return result - def get_next_missing_pdu(self, new_pdu): - """When we get a new state pdu we need to check whether we need to do - any conflict resolution, if we do then we need to check if we need - to go back and request some more state pdus that we haven't seen yet. - - Args: - txn - new_pdu - - Returns: - PduIdTuple: A pdu that we are missing, or None if we have all the - pdus required to do the conflict resolution. - """ - return self._db_pool.runInteraction( - self._get_next_missing_pdu, new_pdu - ) - - def _get_next_missing_pdu(self, txn, new_pdu): - logger.debug( - "get_next_missing_pdu %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - return None - - # Oh look, it's a straight clobber, so wooooo almost no-op. - if (new_pdu.prev_state_id == current.pdu_id - and new_pdu.prev_state_origin == current.origin): - return None - - enum_branches = self._enumerate_state_branches(txn, new_pdu, current) - for branch, prev_state, state in enum_branches: - if not state: - return PduIdTuple( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - - return None - def handle_new_state(self, new_pdu): """Actually perform conflict resolution on the new_pdu on the assumption we have all the pdus required to perform it. @@ -752,24 +709,11 @@ class StatePduStore(SQLBaseStore): return is_current - @classmethod @log_function - def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + def _enumerate_state_branches(self, txn, pdu_a, pdu_b): branch_a = pdu_a branch_b = pdu_b - get_query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - while True: if (branch_a.pdu_id == branch_b.pdu_id and branch_a.origin == branch_b.origin): @@ -801,13 +745,12 @@ class StatePduStore(SQLBaseStore): branch_a.prev_state_origin ) - logger.debug("getting branch_a prev %s", pdu_tuple) - txn.execute(get_query, pdu_tuple) - prev_branch = branch_a - res = txn.fetchone() - branch_a = PduEntry(*res) if res else None + logger.debug("getting branch_a prev %s", pdu_tuple) + branch_a = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_a: + branch_a = Pdu.from_pdu_tuple(branch_a) logger.debug("branch_a=%s", branch_a) @@ -820,14 +763,13 @@ class StatePduStore(SQLBaseStore): branch_b.prev_state_id, branch_b.prev_state_origin ) - txn.execute(get_query, pdu_tuple) - - logger.debug("getting branch_b prev %s", pdu_tuple) prev_branch = branch_b - res = txn.fetchone() - branch_b = PduEntry(*res) if res else None + logger.debug("getting branch_b prev %s", pdu_tuple) + branch_b = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_b: + branch_b = Pdu.from_pdu_tuple(branch_b) logger.debug("branch_b=%s", branch_b) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9a393e2568..676b2f2653 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.api.constants import Membership +from synapse.util.logutils import log_function import logging @@ -29,8 +30,18 @@ class RoomMemberStore(SQLBaseStore): def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ - target_user_id = event.state_key - domain = self.hs.parse_userid(target_user_id).domain + try: + target_user_id = event.state_key + domain = self.hs.parse_userid(target_user_id).domain + except: + logger.exception("Failed to parse target_user_id=%s", target_user_id) + raise + + logger.debug( + "_store_room_member_txn: target_user_id=%s, membership=%s", + target_user_id, + event.membership, + ) self._simple_insert_txn( txn, @@ -51,12 +62,30 @@ class RoomMemberStore(SQLBaseStore): "VALUES (?, ?)" ) txn.execute(sql, (event.room_id, domain)) - else: - sql = ( - "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + elif event.membership != Membership.INVITE: + # Check if this was the last person to have left. + member_events = self._get_members_query_txn( + txn, + where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?", + where_values=(event.room_id, Membership.JOIN, target_user_id,) ) - txn.execute(sql, (event.room_id, domain)) + joined_domains = set() + for e in member_events: + try: + joined_domains.add( + self.hs.parse_userid(e.state_key).domain + ) + except: + # FIXME: How do we deal with invalid user ids in the db? + logger.exception("Invalid user_id: %s", event.state_key) + + if domain not in joined_domains: + sql = ( + "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + ) + + txn.execute(sql, (event.room_id, domain)) @defer.inlineCallbacks def get_room_member(self, user_id, room_id): @@ -146,8 +175,13 @@ class RoomMemberStore(SQLBaseStore): vals = where_dict.values() return self._get_members_query(clause, vals) - @defer.inlineCallbacks def _get_members_query(self, where_clause, where_values): + return self._db_pool.runInteraction( + self._get_members_query_txn, + where_clause, where_values + ) + + def _get_members_query_txn(self, txn, where_clause, where_values): sql = ( "SELECT e.* FROM events as e " "INNER JOIN room_memberships as m " @@ -157,12 +191,11 @@ class RoomMemberStore(SQLBaseStore): "WHERE %s " ) % (where_clause,) - rows = yield self._execute_and_decode(sql, *where_values) - - # logger.debug("_get_members_query Got rows %s", rows) + txn.execute(sql, where_values) + rows = self.cursor_to_dict(txn) - results = yield self._parse_events(rows) - defer.returnValue(results) + results = self._parse_events_txn(txn, rows) + return results @defer.inlineCallbacks def user_rooms_intersect(self, user_list): |