diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/auth.py | 25 | ||||
-rw-r--r-- | synapse/api/errors.py | 5 | ||||
-rwxr-xr-x | synapse/app/homeserver.py | 7 | ||||
-rwxr-xr-x | synapse/app/synctl.py | 52 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 256 | ||||
-rw-r--r-- | synapse/handlers/room.py | 44 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/receipts.py | 4 | ||||
-rw-r--r-- | synapse/state.py | 4 | ||||
-rw-r--r-- | synapse/storage/_base.py | 190 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 92 | ||||
-rw-r--r-- | synapse/storage/events.py | 11 | ||||
-rw-r--r-- | synapse/storage/pusher.py | 4 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 6 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 112 | ||||
-rw-r--r-- | synapse/storage/state.py | 6 | ||||
-rw-r--r-- | synapse/streams/events.py | 28 | ||||
-rw-r--r-- | synapse/util/__init__.py | 28 |
17 files changed, 205 insertions, 669 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 847ff60671..e3b8c3099a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import UserID, EventID +from synapse.types import RoomID, UserID, EventID import logging import pymacaroons @@ -80,6 +80,15 @@ class Auth(object): "Room %r does not exist" % (event.room_id,) ) + creating_domain = RoomID.from_string(event.room_id).domain + originating_domain = UserID.from_string(event.sender).domain + if creating_domain != originating_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # FIXME: Temp hack if event.type == EventTypes.Aliases: return True @@ -219,6 +228,11 @@ class Auth(object): user_id, room_id, repr(member) )) + def can_federate(self, event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + @log_function def is_membership_change_allowed(self, event, auth_events): membership = event.content["membership"] @@ -234,6 +248,15 @@ class Auth(object): target_user_id = event.state_key + creating_domain = RoomID.from_string(event.room_id).domain + target_domain = UserID.from_string(target_user_id).domain + if creating_domain != target_domain: + if not self.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + # get info about the caller key = (EventTypes.Member, event.user_id, ) caller = auth_events.get(key) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c3b4d971a8..ee3045268f 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -77,11 +77,6 @@ class SynapseError(CodeMessageException): ) -class RoomError(SynapseError): - """An error raised when a room event fails.""" - pass - - class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 21840e4a28..190b03e2f7 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -85,12 +85,6 @@ import time logger = logging.getLogger("synapse.app.homeserver") -class GzipFile(File): - def getChild(self, path, request): - child = File.getChild(self, path, request) - return EncodingResourceWrapper(child, [GzipEncoderFactory()]) - - def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) @@ -134,6 +128,7 @@ class SynapseHomeServer(HomeServer): # (It can stay enabled for the API resources: they call # write() with the whole body and then finish() straight # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 # return GzipFile(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable? diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 6bcc437591..5d82beed0e 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -16,36 +16,23 @@ import sys import os +import os.path import subprocess import signal import yaml SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] -CONFIGFILE = "homeserver.yaml" - GREEN = "\x1b[1;32m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" -if not os.path.exists(CONFIGFILE): - sys.stderr.write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), CONFIGFILE - ) - ) - sys.exit(1) - -CONFIG = yaml.load(open(CONFIGFILE)) -PIDFILE = CONFIG["pid_file"] - -def start(): +def start(configfile): print "Starting ...", args = SYNAPSE - args.extend(["--daemonize", "-c", CONFIGFILE]) + args.extend(["--daemonize", "-c", configfile]) + try: subprocess.check_call(args) print GREEN + "started" + NORMAL @@ -57,24 +44,39 @@ def start(): ) -def stop(): - if os.path.exists(PIDFILE): - pid = int(open(PIDFILE).read()) +def stop(pidfile): + if os.path.exists(pidfile): + pid = int(open(pidfile).read()) os.kill(pid, signal.SIGTERM) print GREEN + "stopped" + NORMAL def main(): + configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml" + + if not os.path.exists(configfile): + sys.stderr.write( + "No config file found\n" + "To generate a config file, run '%s -c %s --generate-config" + " --server-name=<server name>'\n" % ( + " ".join(SYNAPSE), configfile + ) + ) + sys.exit(1) + + config = yaml.load(open(configfile)) + pidfile = config["pid_file"] + action = sys.argv[1] if sys.argv[1:] else "usage" if action == "start": - start() + start(configfile) elif action == "stop": - stop() + stop(pidfile) elif action == "restart": - stop() - start() + stop(pidfile) + start(configfile) else: - sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) + sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],)) sys.exit(1) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4ff20599d6..3aff80bf59 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,60 +125,72 @@ class FederationHandler(BaseHandler): ) if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") - current_state = state - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} + try: + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + else: + event_ids = set() + if state: + event_ids |= {e.event_id for e in state} + if auth_chain: + event_ids |= {e.event_id for e in auth_chain} + + seen_ids = set( + (yield self.store.have_events(event_ids)).keys() + ) - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) + if state and auth_chain is not None: + # If we have any state or auth_chain given to us by the replication + # layer, then we should handle them (if we haven't before.) - event_infos = [] + event_infos = [] - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - event_infos.append({ - "event": e, - "auth_events": auth, - }) - seen_ids.add(e.event_id) + for e in itertools.chain(auth_chain, state): + if e.event_id in seen_ids: + continue + e.internal_metadata.outlier = True + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + event_infos.append({ + "event": e, + "auth_events": auth, + }) + seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events( + origin, + event_infos, + outliers=True + ) - try: - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - event, - state=state, - backfilled=backfilled, - current_state=current_state, - ) - except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + try: + _, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, + event, + state=state, + backfilled=backfilled, + current_state=current_state, + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) # if we're receiving valid events from an origin, # it's probably a good idea to mark it as not in retry-state @@ -649,35 +661,8 @@ class FederationHandler(BaseHandler): # FIXME pass - ev_infos = [] - for e in itertools.chain(state, auth_chain): - if e.event_id == event.event_id: - continue - - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - ev_infos.append({ - "event": e, - "auth_events": { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - }) - - yield self._handle_new_events(origin, ev_infos, outliers=True) - - auth_ids = [e_id for e_id, _ in event.auth_events] - auth_events = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - new_event, - state=state, - current_state=state, - auth_events=auth_events, + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event ) with PreserveLoggingContext(): @@ -1027,6 +1012,76 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks + def _persist_auth_tree(self, auth_events, state, event): + """Checks the auth chain is valid (and passes auth checks) for the + state and event. Then persists the auth chain and state atomically. + Persists the event seperately. + + Returns: + 2-tuple of (event_stream_id, max_stream_id) from the persist_event + call for `event` + """ + events_to_context = {} + for e in itertools.chain(auth_events, state): + ctx = yield self.state_handler.compute_event_context( + e, outlier=True, + ) + events_to_context[e.event_id] = ctx + e.internal_metadata.outlier = True + + event_map = { + e.event_id: e + for e in auth_events + } + + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + for e in itertools.chain(auth_events, state, [event]): + auth_for_e = { + (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] + for e_id, _ in e.auth_events + } + if create_event: + auth_for_e[(EventTypes.Create, "")] = create_event + + try: + self.auth.check(e, auth_events=auth_for_e) + except AuthError as err: + logger.warn( + "Rejecting %s because %s", + e.event_id, err.msg + ) + + if e == event: + raise + events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + + yield self.store.persist_events( + [ + (e, events_to_context[e.event_id]) + for e in itertools.chain(auth_events, state) + ], + is_new_state=False, + ) + + new_event_context = yield self.state_handler.compute_event_context( + event, old_state=state, outlier=False, + ) + + event_stream_id, max_stream_id = yield self.store.persist_event( + event, new_event_context, + backfilled=False, + is_new_state=True, + current_state=state, + ) + + defer.returnValue((event_stream_id, max_stream_id)) + + @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, backfilled=False, current_state=None, auth_events=None): outlier = event.internal_metadata.is_outlier() @@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler): }, "missing": [e.event_id for e in missing_locals], }) - - @defer.inlineCallbacks - def _handle_auth_events(self, origin, auth_events): - auth_ids_to_deferred = {} - - def process_auth_ev(ev): - auth_ids = [e_id for e_id, _ in ev.auth_events] - - prev_ds = [ - auth_ids_to_deferred[i] - for i in auth_ids - if i in auth_ids_to_deferred - ] - - d = defer.Deferred() - - auth_ids_to_deferred[ev.event_id] = d - - @defer.inlineCallbacks - def f(*_): - ev.internal_metadata.outlier = True - - try: - auth = { - (e.type, e.state_key): e for e in auth_events - if e.event_id in auth_ids - } - - yield self._handle_new_event( - origin, ev, auth_events=auth - ) - except: - logger.exception( - "Failed to handle auth event %s", - ev.event_id, - ) - - d.callback(None) - - if prev_ds: - dx = defer.DeferredList(prev_ds) - dx.addBoth(f) - else: - f() - - for e in auth_events: - process_auth_ev(e) - - yield defer.DeferredList(auth_ids_to_deferred.values()) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bb5eef6bbd..ac636255c2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -149,12 +149,16 @@ class RoomCreationHandler(BaseHandler): for val in raw_initial_state: initial_state[(val["type"], val.get("state_key", ""))] = val["content"] + creation_content = config.get("creation_content", {}) + user = UserID.from_string(user_id) creation_events = self._create_events_for_new_room( user, room_id, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, + creation_content=creation_content, + room_alias=room_alias, ) msg_handler = self.hs.get_handlers().message_handler @@ -202,7 +206,8 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) def _create_events_for_new_room(self, creator, room_id, preset_config, - invite_list, initial_state): + invite_list, initial_state, creation_content, + room_alias): config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.to_string() @@ -224,9 +229,10 @@ class RoomCreationHandler(BaseHandler): return e + creation_content.update({"creator": creator.to_string()}) creation_event = create( etype=EventTypes.Create, - content={"creator": creator.to_string()}, + content=creation_content, ) join_event = create( @@ -271,6 +277,14 @@ class RoomCreationHandler(BaseHandler): returned_events.append(power_levels_event) + if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + room_alias_event = create( + etype=EventTypes.CanonicalAlias, + content={"alias": room_alias.to_string()}, + ) + + returned_events.append(room_alias_event) + if (EventTypes.JoinRules, '') not in initial_state: join_rules_event = create( etype=EventTypes.JoinRules, @@ -493,32 +507,6 @@ class RoomMemberHandler(BaseHandler): ) @defer.inlineCallbacks - def _should_invite_join(self, room_id, prev_state, do_auth): - logger.debug("_should_invite_join: room_id: %s", room_id) - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - # Only do an invite join dance if a) we were invited, - # b) the person inviting was from a differnt HS and c) we are - # not currently in the room - room_host = None - if prev_state and prev_state.membership == Membership.INVITE: - room = yield self.store.get_room(room_id) - inviter = UserID.from_string( - prev_state.sender - ) - - is_remote_invite_join = not self.hs.is_mine(inviter) and not room - room_host = inviter.domain - else: - is_remote_invite_join = False - - defer.returnValue((is_remote_invite_join, room_host)) - - @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given membership states in.""" diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 52e99f54d5..b107b7ce17 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern @@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet): def on_POST(self, request, room_id, receipt_type, event_id): user, _ = yield self.auth.get_user_by_req(request) + if receipt_type != "m.read": + raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/state.py b/synapse/state.py index ed36f844cb..bb225c39cf 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -31,10 +31,6 @@ import hashlib logger = logging.getLogger(__name__) -def _get_state_key_from_event(event): - return event.state_key - - KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 495ef087c9..693784ad38 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple - import sys import time import threading @@ -376,9 +374,6 @@ class SQLBaseStore(object): return self.runInteraction(desc, interaction) - def _execute_and_decode(self, desc, query, *args): - return self._execute(desc, self.cursor_to_dict, query, *args) - # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -691,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -743,16 +707,6 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete(self, table, keyvalues, desc="_simple_delete"): - """Executes a DELETE query on the named table. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - - return self.runInteraction(desc, self._simple_delete_txn) - def _simple_delete_txn(self, txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, @@ -761,24 +715,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def _simple_max_id(self, table): - """Executes a SELECT query on the named table, expecting to return the - max value for the column "id". - - Args: - table : string giving the table name - """ - sql = "SELECT MAX(id) AS id FROM %s" % table - - def func(txn): - txn.execute(sql) - max_id = self.cursor_to_dict(txn)[0]["id"] - if max_id is None: - return 0 - return max_id - - return self.runInteraction("_simple_max_id", func) - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id @@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass - - -class Table(object): - """ A base class used to store information about a particular table. - """ - - table_name = None - """ str: The name of the table """ - - fields = None - """ list: The field names """ - - EntryType = None - """ Type: A tuple type used to decode the results """ - - _select_where_clause = "SELECT %s FROM %s WHERE %s" - _select_clause = "SELECT %s FROM %s" - _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" - - @classmethod - def select_statement(cls, where_clause=None): - """ - Args: - where_clause (str): The WHERE clause to use. - - Returns: - str: An SQL statement to select rows from the table with the given - WHERE clause. - """ - if where_clause: - return cls._select_where_clause % ( - ", ".join(cls.fields), - cls.table_name, - where_clause - ) - else: - return cls._select_clause % ( - ", ".join(cls.fields), - cls.table_name, - ) - - @classmethod - def insert_statement(cls): - return cls._insert_clause % ( - cls.table_name, - ", ".join(cls.fields), - ", ".join(["?"] * len(cls.fields)), - ) - - @classmethod - def decode_single_result(cls, results): - """ Given an iterable of tuples, return a single instance of - `EntryType` or None if the iterable is empty - Args: - results (list): The results list to convert to `EntryType` - Returns: - EntryType: An instance of `EntryType` - """ - results = list(results) - if results: - return cls.EntryType(*results[0]) - else: - return None - - @classmethod - def decode_results(cls, results): - """ Given an iterable of tuples, return a list of `EntryType` - Args: - results (list): The results list to convert to `EntryType` - - Returns: - list: A list of `EntryType` - """ - return [cls.EntryType(*row) for row in results] - - @classmethod - def get_fields_string(cls, prefix=None): - if prefix: - to_join = ("%s.%s" % (prefix, f) for f in cls.fields) - else: - to_join = cls.fields - - return ", ".join(to_join) - - -class JoinHelper(object): - """ Used to help do joins on tables by looking at the tables' fields and - creating a list of unique fields to use with SELECTs and a namedtuple - to dump the results into. - - Attributes: - tables (list): List of `Table` classes - EntryType (type) - """ - - def __init__(self, *tables): - self.tables = tables - - res = [] - for table in self.tables: - res += [f for f in table.fields if f not in res] - - self.EntryType = namedtuple("JoinHelperEntry", res) - - def get_fields(self, **prefixes): - """Get a string representing a list of fields for use in SELECT - statements with the given prefixes applied to each. - - For example:: - - JoinHelper(PdusTable, StateTable).get_fields( - PdusTable="pdus", - StateTable="state" - ) - """ - res = [] - for field in self.EntryType._fields: - for table in self.tables: - if field in table.fields: - res.append("%s.%s" % (prefixes[table.__name__], field)) - break - - return ", ".join(res) - - def decode_results(self, rows): - return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index c1cabbaa60..6d4421dd8f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_latest_state_in_room(self, txn, room_id, type, state_key): - event_ids = self._simple_select_onecol_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "room_id": room_id, - "type": type, - "state_key": state_key, - }, - retcol="event_id", - ) - - results = [] - for event_id in event_ids: - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((event_id, prev_hashes)) - - return results - - def _get_prev_events(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=0, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_state(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=True, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_events_and_state(self, txn, event_id, is_state=None): - keyvalues = { - "event_id": event_id, - } - - if is_state is not None: - keyvalues["is_state"] = bool(is_state) - - res = self._simple_select_list_txn( - txn, - table="event_edges", - keyvalues=keyvalues, - retcols=["prev_event_id", "is_state"], - ) - - hashes = self._get_prev_event_hashes_txn(txn, event_id) - - results = [] - for d in res: - edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) - edge_hash.update(hashes.get(d["prev_event_id"], {})) - prev_hashes = { - k: encode_base64(v) - for k, v in edge_hash.items() - if k == "sha256" - } - results.append((d["prev_event_id"], prev_hashes, d["is_state"])) - - return results - - def _get_auth_events(self, txn, event_id): - auth_ids = self._simple_select_onecol_txn( - txn, - table="event_auth", - keyvalues={ - "event_id": event_id, - }, - retcol="auth_id", - ) - - results = [] - for auth_id in auth_ids: - hashes = self._get_event_reference_hashes_txn(txn, auth_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((auth_id, prev_hashes)) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 46df6b4d6d..416ef6af93 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -890,22 +890,11 @@ class EventsStore(SQLBaseStore): return ev - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - def _parse_events_txn(self, txn, rows): event_ids = [r["event_id"] for r in rows] return self._get_events_txn(txn, event_ids) - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 00b748f131..345c4e1104 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore from twisted.internet import defer from synapse.api.errors import StoreError @@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore): ) -class PushersTable(Table): +class PushersTable(object): table_name = "pushers" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 41c939efb1..8c40d9a8a6 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -178,12 +178,6 @@ class RoomMemberStore(SQLBaseStore): return joined_domains - def _get_members_query(self, where_clause, where_values): - return self.runInteraction( - "get_members_query", self._get_members_events_txn, - where_clause, where_values - ).addCallbacks(self._get_events) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ab57b92174..b070be504d 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - def _get_event_content_hashes_txn(self, txn, event_id): - """Get all the hashes for a given Event. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_content_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - return dict(txn.fetchall()) - - def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): - """Store a hash for a Event - Args: - txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. - """ - self._simple_insert_txn( - txn, - "event_content_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) - def get_event_reference_hashes(self, event_ids): def f(txn): return [ @@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore): table="event_reference_hashes", values=vals, ) - - def _get_event_signatures_txn(self, txn, event_id): - """Get all the signatures for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of sig name -> dict(key_id -> signature_bytes) - """ - query = ( - "SELECT signature_name, key_id, signature" - " FROM event_signatures" - " WHERE event_id = ? " - ) - txn.execute(query, (event_id, )) - rows = txn.fetchall() - - res = {} - - for name, key, sig in rows: - res.setdefault(name, {})[key] = sig - - return res - - def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): - """Store a signature from the origin server for a PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - origin (str): origin of the Event. - key_id (str): Id for the signing key. - signature (bytes): The signature. - """ - self._simple_insert_txn( - txn, - "event_signatures", - { - "event_id": event_id, - "signature_name": signature_name, - "key_id": key_id, - "signature": buffer(signature_bytes), - }, - ) - - def _get_prev_event_hashes_txn(self, txn, event_id): - """Get all the hashes for previous PDUs of a PDU - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. - """ - query = ( - "SELECT prev_event_id, algorithm, hash" - " FROM event_edge_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - results = {} - for prev_event_id, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault(prev_event_id, {}) - hashes[algorithm] = hash_bytes - return results - - def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): - self._simple_insert_txn( - txn, - "event_edge_hashes", - { - "event_id": event_id, - "prev_event_id": prev_event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9630efcfcc..e935b9443b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import ( from twisted.internet import defer -from synapse.util.stringutils import random_string - import logging logger = logging.getLogger(__name__) @@ -428,7 +426,3 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) - - -def _make_group_id(clock): - return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index aaa3609aa5..699083ae12 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events_for_user(self, user, from_key, limit): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - class EventSources(object): SOURCE_TYPES = { "room": RoomEventSource, @@ -70,15 +54,3 @@ class EventSources(object): ), ) defer.returnValue(token) - - -class StreamSource(object): - def get_new_events_for_user(self, user, from_key, limit): - """from_key is the key within this event source.""" - raise NotImplementedError("get_new_events_for_user") - - def get_current_key(self): - raise NotImplementedError("get_current_key") - - def get_pagination_rows(self, user, pagination_config, key): - raise NotImplementedError("get_rows") diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 07ff25cef3..1d123ccefc 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -29,34 +29,6 @@ def unwrapFirstError(failure): return failure.value.subFailure -def unwrap_deferred(d): - """Given a deferred that we know has completed, return its value or raise - the failure as an exception - """ - if not d.called: - raise RuntimeError("deferred has not finished") - - res = [] - - def f(r): - res.append(r) - return r - d.addCallback(f) - - if res: - return res[0] - - def f(r): - res.append(r) - return r - d.addErrback(f) - - if res: - res[0].raiseException() - else: - raise RuntimeError("deferred did not call callbacks") - - class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. |