diff options
-rwxr-xr-x | scripts-dev/definitions.py | 142 | ||||
-rwxr-xr-x | scripts/synapse_port_db | 2 | ||||
-rw-r--r-- | synapse/api/auth.py | 25 | ||||
-rw-r--r-- | synapse/api/errors.py | 5 | ||||
-rwxr-xr-x | synapse/app/homeserver.py | 7 | ||||
-rwxr-xr-x | synapse/app/synctl.py | 52 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 256 | ||||
-rw-r--r-- | synapse/handlers/room.py | 44 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/receipts.py | 4 | ||||
-rw-r--r-- | synapse/state.py | 4 | ||||
-rw-r--r-- | synapse/storage/_base.py | 190 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 92 | ||||
-rw-r--r-- | synapse/storage/events.py | 11 | ||||
-rw-r--r-- | synapse/storage/pusher.py | 4 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 6 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 112 | ||||
-rw-r--r-- | synapse/storage/state.py | 6 | ||||
-rw-r--r-- | synapse/streams/events.py | 28 | ||||
-rw-r--r-- | synapse/util/__init__.py | 28 | ||||
-rw-r--r-- | tests/rest/client/v1/test_presence.py | 18 | ||||
-rw-r--r-- | tests/storage/test_base.py | 20 | ||||
-rw-r--r-- | tests/test_state.py | 2 |
22 files changed, 365 insertions, 693 deletions
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py new file mode 100755 index 0000000000..f0d0cd8a3f --- /dev/null +++ b/scripts-dev/definitions.py @@ -0,0 +1,142 @@ +#! /usr/bin/python + +import ast +import yaml + +class DefinitionVisitor(ast.NodeVisitor): + def __init__(self): + super(DefinitionVisitor, self).__init__() + self.functions = {} + self.classes = {} + self.names = {} + self.attrs = set() + self.definitions = { + 'def': self.functions, + 'class': self.classes, + 'names': self.names, + 'attrs': self.attrs, + } + + def visit_Name(self, node): + self.names.setdefault(type(node.ctx).__name__, set()).add(node.id) + + def visit_Attribute(self, node): + self.attrs.add(node.attr) + for child in ast.iter_child_nodes(node): + self.visit(child) + + def visit_ClassDef(self, node): + visitor = DefinitionVisitor() + self.classes[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + def visit_FunctionDef(self, node): + visitor = DefinitionVisitor() + self.functions[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + +def non_empty(defs): + functions = {name: non_empty(f) for name, f in defs['def'].items()} + classes = {name: non_empty(f) for name, f in defs['class'].items()} + result = {} + if functions: result['def'] = functions + if classes: result['class'] = classes + names = defs['names'] + uses = [] + for name in names.get('Load', ()): + if name not in names.get('Param', ()) and name not in names.get('Store', ()): + uses.append(name) + uses.extend(defs['attrs']) + if uses: result['uses'] = uses + result['names'] = names + result['attrs'] = defs['attrs'] + return result + + +def definitions_in_code(input_code): + input_ast = ast.parse(input_code) + visitor = DefinitionVisitor() + visitor.visit(input_ast) + definitions = non_empty(visitor.definitions) + return definitions + + +def definitions_in_file(filepath): + with open(filepath) as f: + return definitions_in_code(f.read()) + + +def defined_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + +def used_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for used in defs.get('uses', ()): + if used in names: + names[used].setdefault('used', []).append(prefix.rstrip('.')) + + +if __name__ == '__main__': + import sys, os, argparse, re + + parser = argparse.ArgumentParser(description='Find definitions.') + parser.add_argument( + "--unused", action="store_true", help="Only list unused definitions" + ) + parser.add_argument( + "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" + ) + parser.add_argument( + "--pattern", action="append", metavar="REGEXP", + help="Search for a pattern" + ) + parser.add_argument( + "directories", nargs='+', metavar="DIR", + help="Directories to search for definitions" + ) + args = parser.parse_args() + + definitions = {} + for directory in args.directories: + for root, dirs, files in os.walk(directory): + for filename in files: + if filename.endswith(".py"): + filepath = os.path.join(root, filename) + definitions[filepath] = definitions_in_file(filepath) + + names = {} + for filepath, defs in definitions.items(): + defined_names(filepath + ":", defs, names) + + for filepath, defs in definitions.items(): + used_names(filepath + ":", defs, names) + + patterns = [re.compile(pattern) for pattern in args.pattern or ()] + ignore = [re.compile(pattern) for pattern in args.ignore or ()] + + result = {} + for name, definition in names.items(): + if patterns and not any(pattern.match(name) for pattern in patterns): + continue + if ignore and any(pattern.match(name) for pattern in ignore): + continue + if args.unused and definition.get('used'): + continue + result[name] = definition + + yaml.dump(result, sys.stdout, default_flow_style=False) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 6aba72e459..62515997b1 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -95,8 +95,6 @@ class Store(object): _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] - _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"] - def runInteraction(self, desc, func, *args, **kwargs): def r(conn): try: diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 847ff60671..e3b8c3099a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import UserID, EventID +from synapse.types import RoomID, UserID, EventID import logging import pymacaroons @@ -80,6 +80,15 @@ class Auth(object): "Room %r does not exist" % (event.room_id,) ) + creating_domain = RoomID.from_string(event.room_id).domain + originating_domain = UserID.from_string(event.sender).domain + if creating_domain != originating_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # FIXME: Temp hack if event.type == EventTypes.Aliases: return True @@ -219,6 +228,11 @@ class Auth(object): user_id, room_id, repr(member) )) + def can_federate(self, event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + @log_function def is_membership_change_allowed(self, event, auth_events): membership = event.content["membership"] @@ -234,6 +248,15 @@ class Auth(object): target_user_id = event.state_key + creating_domain = RoomID.from_string(event.room_id).domain + target_domain = UserID.from_string(target_user_id).domain + if creating_domain != target_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # get info about the caller key = (EventTypes.Member, event.user_id, ) caller = auth_events.get(key) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c3b4d971a8..ee3045268f 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -77,11 +77,6 @@ class SynapseError(CodeMessageException): ) -class RoomError(SynapseError): - """An error raised when a room event fails.""" - pass - - class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 21840e4a28..190b03e2f7 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -85,12 +85,6 @@ import time logger = logging.getLogger("synapse.app.homeserver") -class GzipFile(File): - def getChild(self, path, request): - child = File.getChild(self, path, request) - return EncodingResourceWrapper(child, [GzipEncoderFactory()]) - - def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) @@ -134,6 +128,7 @@ class SynapseHomeServer(HomeServer): # (It can stay enabled for the API resources: they call # write() with the whole body and then finish() straight # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 # return GzipFile(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable? diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 6bcc437591..5d82beed0e 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -16,36 +16,23 @@ import sys import os +import os.path import subprocess import signal import yaml SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] -CONFIGFILE = "homeserver.yaml" - GREEN = "\x1b[1;32m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" -if not os.path.exists(CONFIGFILE): - sys.stderr.write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), CONFIGFILE - ) - ) - sys.exit(1) - -CONFIG = yaml.load(open(CONFIGFILE)) -PIDFILE = CONFIG["pid_file"] - -def start(): +def start(configfile): print "Starting ...", args = SYNAPSE - args.extend(["--daemonize", "-c", CONFIGFILE]) + args.extend(["--daemonize", "-c", configfile]) + try: subprocess.check_call(args) print GREEN + "started" + NORMAL @@ -57,24 +44,39 @@ def start(): ) -def stop(): - if os.path.exists(PIDFILE): - pid = int(open(PIDFILE).read()) +def stop(pidfile): + if os.path.exists(pidfile): + pid = int(open(pidfile).read()) os.kill(pid, signal.SIGTERM) print GREEN + "stopped" + NORMAL def main(): + configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml" + + if not os.path.exists(configfile): + sys.stderr.write( + "No config file found\n" + "To generate a config file, run '%s -c %s --generate-config" + " --server-name=<server name>'\n" % ( + " ".join(SYNAPSE), configfile + ) + ) + sys.exit(1) + + config = yaml.load(open(configfile)) + pidfile = config["pid_file"] + action = sys.argv[1] if sys.argv[1:] else "usage" if action == "start": - start() + start(configfile) elif action == "stop": - stop() + stop(pidfile) elif action == "restart": - stop() - start() + stop(pidfile) + start(configfile) else: - sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) + sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],)) sys.exit(1) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4ff20599d6..3aff80bf59 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,60 +125,72 @@ class FederationHandler(BaseHandler): ) if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") - current_state = state - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} + try: + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + else: + event_ids = set() + if state: + event_ids |= {e.event_id for e in state} + if auth_chain: + event_ids |= {e.event_id for e in auth_chain} + + seen_ids = set( + (yield self.store.have_events(event_ids)).keys() + ) - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) + if state and auth_chain is not None: + # If we have any state or auth_chain given to us by the replication + # layer, then we should handle them (if we haven't before.) - event_infos = [] + event_infos = [] - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - event_infos.append({ - "event": e, - "auth_events": auth, - }) - seen_ids.add(e.event_id) + for e in itertools.chain(auth_chain, state): + if e.event_id in seen_ids: + continue + e.internal_metadata.outlier = True + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + event_infos.append({ + "event": e, + "auth_events": auth, + }) + seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events( + origin, + event_infos, + outliers=True + ) - try: - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - event, - state=state, - backfilled=backfilled, - current_state=current_state, - ) - except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + try: + _, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, + event, + state=state, + backfilled=backfilled, + current_state=current_state, + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) # if we're receiving valid events from an origin, # it's probably a good idea to mark it as not in retry-state @@ -649,35 +661,8 @@ class FederationHandler(BaseHandler): # FIXME pass - ev_infos = [] - for e in itertools.chain(state, auth_chain): - if e.event_id == event.event_id: - continue - - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - ev_infos.append({ - "event": e, - "auth_events": { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - }) - - yield self._handle_new_events(origin, ev_infos, outliers=True) - - auth_ids = [e_id for e_id, _ in event.auth_events] - auth_events = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - new_event, - state=state, - current_state=state, - auth_events=auth_events, + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event ) with PreserveLoggingContext(): @@ -1027,6 +1012,76 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks + def _persist_auth_tree(self, auth_events, state, event): + """Checks the auth chain is valid (and passes auth checks) for the + state and event. Then persists the auth chain and state atomically. + Persists the event seperately. + + Returns: + 2-tuple of (event_stream_id, max_stream_id) from the persist_event + call for `event` + """ + events_to_context = {} + for e in itertools.chain(auth_events, state): + ctx = yield self.state_handler.compute_event_context( + e, outlier=True, + ) + events_to_context[e.event_id] = ctx + e.internal_metadata.outlier = True + + event_map = { + e.event_id: e + for e in auth_events + } + + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + for e in itertools.chain(auth_events, state, [event]): + auth_for_e = { + (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] + for e_id, _ in e.auth_events + } + if create_event: + auth_for_e[(EventTypes.Create, "")] = create_event + + try: + self.auth.check(e, auth_events=auth_for_e) + except AuthError as err: + logger.warn( + "Rejecting %s because %s", + e.event_id, err.msg + ) + + if e == event: + raise + events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + + yield self.store.persist_events( + [ + (e, events_to_context[e.event_id]) + for e in itertools.chain(auth_events, state) + ], + is_new_state=False, + ) + + new_event_context = yield self.state_handler.compute_event_context( + event, old_state=state, outlier=False, + ) + + event_stream_id, max_stream_id = yield self.store.persist_event( + event, new_event_context, + backfilled=False, + is_new_state=True, + current_state=state, + ) + + defer.returnValue((event_stream_id, max_stream_id)) + + @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, backfilled=False, current_state=None, auth_events=None): outlier = event.internal_metadata.is_outlier() @@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler): }, "missing": [e.event_id for e in missing_locals], }) - - @defer.inlineCallbacks - def _handle_auth_events(self, origin, auth_events): - auth_ids_to_deferred = {} - - def process_auth_ev(ev): - auth_ids = [e_id for e_id, _ in ev.auth_events] - - prev_ds = [ - auth_ids_to_deferred[i] - for i in auth_ids - if i in auth_ids_to_deferred - ] - - d = defer.Deferred() - - auth_ids_to_deferred[ev.event_id] = d - - @defer.inlineCallbacks - def f(*_): - ev.internal_metadata.outlier = True - - try: - auth = { - (e.type, e.state_key): e for e in auth_events - if e.event_id in auth_ids - } - - yield self._handle_new_event( - origin, ev, auth_events=auth - ) - except: - logger.exception( - "Failed to handle auth event %s", - ev.event_id, - ) - - d.callback(None) - - if prev_ds: - dx = defer.DeferredList(prev_ds) - dx.addBoth(f) - else: - f() - - for e in auth_events: - process_auth_ev(e) - - yield defer.DeferredList(auth_ids_to_deferred.values()) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bb5eef6bbd..ac636255c2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -149,12 +149,16 @@ class RoomCreationHandler(BaseHandler): for val in raw_initial_state: initial_state[(val["type"], val.get("state_key", ""))] = val["content"] + creation_content = config.get("creation_content", {}) + user = UserID.from_string(user_id) creation_events = self._create_events_for_new_room( user, room_id, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, + creation_content=creation_content, + room_alias=room_alias, ) msg_handler = self.hs.get_handlers().message_handler @@ -202,7 +206,8 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) def _create_events_for_new_room(self, creator, room_id, preset_config, - invite_list, initial_state): + invite_list, initial_state, creation_content, + room_alias): config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.to_string() @@ -224,9 +229,10 @@ class RoomCreationHandler(BaseHandler): return e + creation_content.update({"creator": creator.to_string()}) creation_event = create( etype=EventTypes.Create, - content={"creator": creator.to_string()}, + content=creation_content, ) join_event = create( @@ -271,6 +277,14 @@ class RoomCreationHandler(BaseHandler): returned_events.append(power_levels_event) + if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + room_alias_event = create( + etype=EventTypes.CanonicalAlias, + content={"alias": room_alias.to_string()}, + ) + + returned_events.append(room_alias_event) + if (EventTypes.JoinRules, '') not in initial_state: join_rules_event = create( etype=EventTypes.JoinRules, @@ -493,32 +507,6 @@ class RoomMemberHandler(BaseHandler): ) @defer.inlineCallbacks - def _should_invite_join(self, room_id, prev_state, do_auth): - logger.debug("_should_invite_join: room_id: %s", room_id) - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - # Only do an invite join dance if a) we were invited, - # b) the person inviting was from a differnt HS and c) we are - # not currently in the room - room_host = None - if prev_state and prev_state.membership == Membership.INVITE: - room = yield self.store.get_room(room_id) - inviter = UserID.from_string( - prev_state.sender - ) - - is_remote_invite_join = not self.hs.is_mine(inviter) and not room - room_host = inviter.domain - else: - is_remote_invite_join = False - - defer.returnValue((is_remote_invite_join, room_host)) - - @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given membership states in.""" diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 52e99f54d5..b107b7ce17 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern @@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet): def on_POST(self, request, room_id, receipt_type, event_id): user, _ = yield self.auth.get_user_by_req(request) + if receipt_type != "m.read": + raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/state.py b/synapse/state.py index ed36f844cb..bb225c39cf 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -31,10 +31,6 @@ import hashlib logger = logging.getLogger(__name__) -def _get_state_key_from_event(event): - return event.state_key - - KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 495ef087c9..693784ad38 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple - import sys import time import threading @@ -376,9 +374,6 @@ class SQLBaseStore(object): return self.runInteraction(desc, interaction) - def _execute_and_decode(self, desc, query, *args): - return self._execute(desc, self.cursor_to_dict, query, *args) - # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -691,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -743,16 +707,6 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete(self, table, keyvalues, desc="_simple_delete"): - """Executes a DELETE query on the named table. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - - return self.runInteraction(desc, self._simple_delete_txn) - def _simple_delete_txn(self, txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, @@ -761,24 +715,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def _simple_max_id(self, table): - """Executes a SELECT query on the named table, expecting to return the - max value for the column "id". - - Args: - table : string giving the table name - """ - sql = "SELECT MAX(id) AS id FROM %s" % table - - def func(txn): - txn.execute(sql) - max_id = self.cursor_to_dict(txn)[0]["id"] - if max_id is None: - return 0 - return max_id - - return self.runInteraction("_simple_max_id", func) - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id @@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass - - -class Table(object): - """ A base class used to store information about a particular table. - """ - - table_name = None - """ str: The name of the table """ - - fields = None - """ list: The field names """ - - EntryType = None - """ Type: A tuple type used to decode the results """ - - _select_where_clause = "SELECT %s FROM %s WHERE %s" - _select_clause = "SELECT %s FROM %s" - _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" - - @classmethod - def select_statement(cls, where_clause=None): - """ - Args: - where_clause (str): The WHERE clause to use. - - Returns: - str: An SQL statement to select rows from the table with the given - WHERE clause. - """ - if where_clause: - return cls._select_where_clause % ( - ", ".join(cls.fields), - cls.table_name, - where_clause - ) - else: - return cls._select_clause % ( - ", ".join(cls.fields), - cls.table_name, - ) - - @classmethod - def insert_statement(cls): - return cls._insert_clause % ( - cls.table_name, - ", ".join(cls.fields), - ", ".join(["?"] * len(cls.fields)), - ) - - @classmethod - def decode_single_result(cls, results): - """ Given an iterable of tuples, return a single instance of - `EntryType` or None if the iterable is empty - Args: - results (list): The results list to convert to `EntryType` - Returns: - EntryType: An instance of `EntryType` - """ - results = list(results) - if results: - return cls.EntryType(*results[0]) - else: - return None - - @classmethod - def decode_results(cls, results): - """ Given an iterable of tuples, return a list of `EntryType` - Args: - results (list): The results list to convert to `EntryType` - - Returns: - list: A list of `EntryType` - """ - return [cls.EntryType(*row) for row in results] - - @classmethod - def get_fields_string(cls, prefix=None): - if prefix: - to_join = ("%s.%s" % (prefix, f) for f in cls.fields) - else: - to_join = cls.fields - - return ", ".join(to_join) - - -class JoinHelper(object): - """ Used to help do joins on tables by looking at the tables' fields and - creating a list of unique fields to use with SELECTs and a namedtuple - to dump the results into. - - Attributes: - tables (list): List of `Table` classes - EntryType (type) - """ - - def __init__(self, *tables): - self.tables = tables - - res = [] - for table in self.tables: - res += [f for f in table.fields if f not in res] - - self.EntryType = namedtuple("JoinHelperEntry", res) - - def get_fields(self, **prefixes): - """Get a string representing a list of fields for use in SELECT - statements with the given prefixes applied to each. - - For example:: - - JoinHelper(PdusTable, StateTable).get_fields( - PdusTable="pdus", - StateTable="state" - ) - """ - res = [] - for field in self.EntryType._fields: - for table in self.tables: - if field in table.fields: - res.append("%s.%s" % (prefixes[table.__name__], field)) - break - - return ", ".join(res) - - def decode_results(self, rows): - return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index c1cabbaa60..6d4421dd8f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_latest_state_in_room(self, txn, room_id, type, state_key): - event_ids = self._simple_select_onecol_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "room_id": room_id, - "type": type, - "state_key": state_key, - }, - retcol="event_id", - ) - - results = [] - for event_id in event_ids: - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((event_id, prev_hashes)) - - return results - - def _get_prev_events(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=0, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_state(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=True, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_events_and_state(self, txn, event_id, is_state=None): - keyvalues = { - "event_id": event_id, - } - - if is_state is not None: - keyvalues["is_state"] = bool(is_state) - - res = self._simple_select_list_txn( - txn, - table="event_edges", - keyvalues=keyvalues, - retcols=["prev_event_id", "is_state"], - ) - - hashes = self._get_prev_event_hashes_txn(txn, event_id) - - results = [] - for d in res: - edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) - edge_hash.update(hashes.get(d["prev_event_id"], {})) - prev_hashes = { - k: encode_base64(v) - for k, v in edge_hash.items() - if k == "sha256" - } - results.append((d["prev_event_id"], prev_hashes, d["is_state"])) - - return results - - def _get_auth_events(self, txn, event_id): - auth_ids = self._simple_select_onecol_txn( - txn, - table="event_auth", - keyvalues={ - "event_id": event_id, - }, - retcol="auth_id", - ) - - results = [] - for auth_id in auth_ids: - hashes = self._get_event_reference_hashes_txn(txn, auth_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((auth_id, prev_hashes)) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 46df6b4d6d..416ef6af93 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -890,22 +890,11 @@ class EventsStore(SQLBaseStore): return ev - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - def _parse_events_txn(self, txn, rows): event_ids = [r["event_id"] for r in rows] return self._get_events_txn(txn, event_ids) - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 00b748f131..345c4e1104 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore from twisted.internet import defer from synapse.api.errors import StoreError @@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore): ) -class PushersTable(Table): +class PushersTable(object): table_name = "pushers" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 41c939efb1..8c40d9a8a6 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -178,12 +178,6 @@ class RoomMemberStore(SQLBaseStore): return joined_domains - def _get_members_query(self, where_clause, where_values): - return self.runInteraction( - "get_members_query", self._get_members_events_txn, - where_clause, where_values - ).addCallbacks(self._get_events) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ab57b92174..b070be504d 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - def _get_event_content_hashes_txn(self, txn, event_id): - """Get all the hashes for a given Event. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_content_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - return dict(txn.fetchall()) - - def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): - """Store a hash for a Event - Args: - txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. - """ - self._simple_insert_txn( - txn, - "event_content_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) - def get_event_reference_hashes(self, event_ids): def f(txn): return [ @@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore): table="event_reference_hashes", values=vals, ) - - def _get_event_signatures_txn(self, txn, event_id): - """Get all the signatures for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of sig name -> dict(key_id -> signature_bytes) - """ - query = ( - "SELECT signature_name, key_id, signature" - " FROM event_signatures" - " WHERE event_id = ? " - ) - txn.execute(query, (event_id, )) - rows = txn.fetchall() - - res = {} - - for name, key, sig in rows: - res.setdefault(name, {})[key] = sig - - return res - - def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): - """Store a signature from the origin server for a PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - origin (str): origin of the Event. - key_id (str): Id for the signing key. - signature (bytes): The signature. - """ - self._simple_insert_txn( - txn, - "event_signatures", - { - "event_id": event_id, - "signature_name": signature_name, - "key_id": key_id, - "signature": buffer(signature_bytes), - }, - ) - - def _get_prev_event_hashes_txn(self, txn, event_id): - """Get all the hashes for previous PDUs of a PDU - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. - """ - query = ( - "SELECT prev_event_id, algorithm, hash" - " FROM event_edge_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - results = {} - for prev_event_id, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault(prev_event_id, {}) - hashes[algorithm] = hash_bytes - return results - - def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): - self._simple_insert_txn( - txn, - "event_edge_hashes", - { - "event_id": event_id, - "prev_event_id": prev_event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9630efcfcc..e935b9443b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import ( from twisted.internet import defer -from synapse.util.stringutils import random_string - import logging logger = logging.getLogger(__name__) @@ -428,7 +426,3 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) - - -def _make_group_id(clock): - return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index aaa3609aa5..699083ae12 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events_for_user(self, user, from_key, limit): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - class EventSources(object): SOURCE_TYPES = { "room": RoomEventSource, @@ -70,15 +54,3 @@ class EventSources(object): ), ) defer.returnValue(token) - - -class StreamSource(object): - def get_new_events_for_user(self, user, from_key, limit): - """from_key is the key within this event source.""" - raise NotImplementedError("get_new_events_for_user") - - def get_current_key(self): - raise NotImplementedError("get_current_key") - - def get_pagination_rows(self, user, pagination_config, key): - raise NotImplementedError("get_rows") diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 07ff25cef3..1d123ccefc 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -29,34 +29,6 @@ def unwrapFirstError(failure): return failure.value.subFailure -def unwrap_deferred(d): - """Given a deferred that we know has completed, return its value or raise - the failure as an exception - """ - if not d.called: - raise RuntimeError("deferred has not finished") - - res = [] - - def f(r): - res.append(r) - return r - d.addCallback(f) - - if res: - return res[0] - - def f(r): - res.append(r) - return r - d.addErrback(f) - - if res: - res[0].raiseException() - else: - raise RuntimeError("deferred did not call callbacks") - - class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 2ee3da0b34..29d9bbaad4 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -41,6 +41,22 @@ myid = "@apple:test" PATH_PREFIX = "/_matrix/client/api/v1" +class NullSource(object): + """This event source never yields any events and its token remains at + zero. It may be useful for unit-testing.""" + def __init__(self, hs): + pass + + def get_new_events_for_user(self, user, from_key, limit): + return defer.succeed(([], from_key)) + + def get_current_key(self, direction='f'): + return defer.succeed(0) + + def get_pagination_rows(self, user, pagination_config, key): + return defer.succeed(([], pagination_config.from_key)) + + class JustPresenceHandlers(object): def __init__(self, hs): self.presence_handler = PresenceHandler(hs) @@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # HIDEOUS HACKERY # TODO(paul): This should be injected in via the HomeServer DI system from synapse.streams.events import ( - PresenceEventSource, NullSource, EventSources + PresenceEventSource, EventSources ) old_SOURCE_TYPES = EventSources.SOURCE_TYPES diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 8573f18b55..1ddca1da4c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -186,26 +186,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_one_with_return(self): - self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Old Value",) - - ret = yield self.datastore._simple_selectupdate_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columname": "New Value"}, - retcols=["columname"] - ) - - self.assertEquals({"columname": "Old Value"}, ret) - self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', - ['TheKey']), - call("UPDATE tablename SET columname = ? WHERE keycol = ?", - ["New Value", "TheKey"]) - ]) - - @defer.inlineCallbacks def test_delete_one(self): self.mock_txn.rowcount = 1 diff --git a/tests/test_state.py b/tests/test_state.py index 55f37c521f..0274c4bc18 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -35,7 +35,7 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, if not event_id: _next_event_id += 1 - event_id = str(_next_event_id) + event_id = "$%s:test" % (_next_event_id,) if not name: if state_key is not None: |