summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/hash_password18
-rwxr-xr-xscripts/register_new_matrix_user32
-rwxr-xr-xscripts/synapse_port_db175
3 files changed, 175 insertions, 50 deletions
diff --git a/scripts/hash_password b/scripts/hash_password
index e784600989..215ab25cfe 100755
--- a/scripts/hash_password
+++ b/scripts/hash_password
@@ -1,10 +1,16 @@
 #!/usr/bin/env python
 
 import argparse
+
+import sys
+
 import bcrypt
 import getpass
 
+import yaml
+
 bcrypt_rounds=12
+password_pepper = ""
 
 def prompt_for_pass():
     password = getpass.getpass("Password: ")
@@ -28,12 +34,22 @@ if __name__ == "__main__":
         default=None,
         help="New password for user. Will prompt if omitted.",
     )
+    parser.add_argument(
+        "-c", "--config",
+        type=argparse.FileType('r'),
+        help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
+    )
 
     args = parser.parse_args()
+    if "config" in args and args.config:
+        config = yaml.safe_load(args.config)
+        bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+        password_config = config.get("password_config", {})
+        password_pepper = password_config.get("pepper", password_pepper)
     password = args.password
 
     if not password:
         password = prompt_for_pass()
 
-    print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+    print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
 
diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user
index 27a6250b14..12ed20d623 100755
--- a/scripts/register_new_matrix_user
+++ b/scripts/register_new_matrix_user
@@ -25,18 +25,26 @@ import urllib2
 import yaml
 
 
-def request_registration(user, password, server_location, shared_secret):
+def request_registration(user, password, server_location, shared_secret, admin=False):
     mac = hmac.new(
         key=shared_secret,
-        msg=user,
         digestmod=hashlib.sha1,
-    ).hexdigest()
+    )
+
+    mac.update(user)
+    mac.update("\x00")
+    mac.update(password)
+    mac.update("\x00")
+    mac.update("admin" if admin else "notadmin")
+
+    mac = mac.hexdigest()
 
     data = {
         "user": user,
         "password": password,
         "mac": mac,
         "type": "org.matrix.login.shared_secret",
+        "admin": admin,
     }
 
     server_location = server_location.rstrip("/")
@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
         sys.exit(1)
 
 
-def register_new_user(user, password, server_location, shared_secret):
+def register_new_user(user, password, server_location, shared_secret, admin):
     if not user:
         try:
             default_user = getpass.getuser()
@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
             print "Passwords do not match"
             sys.exit(1)
 
-    request_registration(user, password, server_location, shared_secret)
+    if not admin:
+        admin = raw_input("Make admin [no]: ")
+        if admin in ("y", "yes", "true"):
+            admin = True
+        else:
+            admin = False
+
+    request_registration(user, password, server_location, shared_secret, bool(admin))
 
 
 if __name__ == "__main__":
@@ -119,6 +134,11 @@ if __name__ == "__main__":
         default=None,
         help="New password for user. Will prompt if omitted.",
     )
+    parser.add_argument(
+        "-a", "--admin",
+        action="store_true",
+        help="Register new user as an admin. Will prompt if omitted.",
+    )
 
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument(
@@ -151,4 +171,4 @@ if __name__ == "__main__":
     else:
         secret = args.shared_secret
 
-    register_new_user(args.user, args.password, args.server_url, secret)
+    register_new_user(args.user, args.password, args.server_url, secret, args.admin)
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index efd04da2d6..66c61b0198 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
 
 
 BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier"],
+    "events": ["processed", "outlier", "contains_url"],
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
@@ -92,8 +92,12 @@ class Store(object):
 
     _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
     _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
+    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
+    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
     _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
+        "_simple_select_one_onecol_txn"
+    ]
 
     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@@ -158,31 +162,40 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            next_chunk = yield self.postgres_store._simple_select_one_onecol(
+            row = yield self.postgres_store._simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
-                retcol="rowid",
+                retcols=("forward_rowid", "backward_rowid"),
                 allow_none=True,
             )
 
             total_to_port = None
-            if next_chunk is None:
+            if row is None:
                 if table == "sent_transactions":
-                    next_chunk, already_ported, total_to_port = (
+                    forward_chunk, already_ported, total_to_port = (
                         yield self._setup_sent_transactions()
                     )
+                    backward_chunk = 0
                 else:
                     yield self.postgres_store._simple_insert(
                         table="port_from_sqlite3",
-                        values={"table_name": table, "rowid": 1}
+                        values={
+                            "table_name": table,
+                            "forward_rowid": 1,
+                            "backward_rowid": 0,
+                        }
                     )
 
-                    next_chunk = 1
+                    forward_chunk = 1
+                    backward_chunk = 0
                     already_ported = 0
+            else:
+                forward_chunk = row["forward_rowid"]
+                backward_chunk = row["backward_rowid"]
 
             if total_to_port is None:
                 already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, next_chunk
+                    table, forward_chunk, backward_chunk
                 )
         else:
             def delete_all(txn):
@@ -196,46 +209,85 @@ class Porter(object):
 
             yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
-                values={"table_name": table, "rowid": 0}
+                values={
+                    "table_name": table,
+                    "forward_rowid": 1,
+                    "backward_rowid": 0,
+                }
             )
 
-            next_chunk = 1
+            forward_chunk = 1
+            backward_chunk = 0
 
             already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, next_chunk
+                table, forward_chunk, backward_chunk
             )
 
-        defer.returnValue((table, already_ported, total_to_port, next_chunk))
+        defer.returnValue(
+            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
+        )
 
     @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, next_chunk):
+    def handle_table(self, table, postgres_size, table_size, forward_chunk,
+                     backward_chunk):
         if not table_size:
             return
 
         self.progress.add_table(table, postgres_size, table_size)
 
         if table == "event_search":
-            yield self.handle_search_table(postgres_size, table_size, next_chunk)
+            yield self.handle_search_table(
+                postgres_size, table_size, forward_chunk, backward_chunk
+            )
             return
 
-        select = (
+        forward_select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
         )
 
+        backward_select = (
+            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
+            % (table,)
+        )
+
+        do_forward = [True]
+        do_backward = [True]
+
         while True:
             def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
-                rows = txn.fetchall()
-                headers = [column[0] for column in txn.description]
+                forward_rows = []
+                backward_rows = []
+                if do_forward[0]:
+                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
+                    forward_rows = txn.fetchall()
+                    if not forward_rows:
+                        do_forward[0] = False
+
+                if do_backward[0]:
+                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
+                    backward_rows = txn.fetchall()
+                    if not backward_rows:
+                        do_backward[0] = False
+
+                if forward_rows or backward_rows:
+                    headers = [column[0] for column in txn.description]
+                else:
+                    headers = None
 
-                return headers, rows
+                return headers, forward_rows, backward_rows
 
-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = yield self.sqlite_store.runInteraction(
+                "select", r
+            )
 
-            if rows:
-                next_chunk = rows[-1][0] + 1
+            if frows or brows:
+                if frows:
+                    forward_chunk = max(row[0] for row in frows) + 1
+                if brows:
+                    backward_chunk = min(row[0] for row in brows) - 1
 
+                rows = frows + brows
                 self._convert_rows(table, headers, rows)
 
                 def insert(txn):
@@ -247,7 +299,10 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                     )
 
                 yield self.postgres_store.execute(insert)
@@ -259,7 +314,8 @@ class Porter(object):
                 return
 
     @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, next_chunk):
+    def handle_search_table(self, postgres_size, table_size, forward_chunk,
+                            backward_chunk):
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"
@@ -270,7 +326,7 @@ class Porter(object):
 
         while True:
             def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
+                txn.execute(select, (forward_chunk, self.batch_size,))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]
 
@@ -279,7 +335,7 @@ class Porter(object):
             headers, rows = yield self.sqlite_store.runInteraction("select", r)
 
             if rows:
-                next_chunk = rows[-1][0] + 1
+                forward_chunk = rows[-1][0] + 1
 
                 # We have to treat event_search differently since it has a
                 # different structure in the two different databases.
@@ -312,7 +368,10 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                     )
 
                 yield self.postgres_store.execute(insert)
@@ -324,7 +383,6 @@ class Porter(object):
             else:
                 return
 
-
     def setup_db(self, db_config, database_engine):
         db_conn = database_engine.module.connect(
             **{
@@ -395,10 +453,32 @@ class Porter(object):
                 txn.execute(
                     "CREATE TABLE port_from_sqlite3 ("
                     " table_name varchar(100) NOT NULL UNIQUE,"
-                    " rowid bigint NOT NULL"
+                    " forward_rowid bigint NOT NULL,"
+                    " backward_rowid bigint NOT NULL"
                     ")"
                 )
 
+            # The old port script created a table with just a "rowid" column.
+            # We want people to be able to rerun this script from an old port
+            # so that they can pick up any missing events that were not
+            # ported across.
+            def alter_table(txn):
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " RENAME rowid TO forward_rowid"
+                )
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " ADD backward_rowid bigint NOT NULL DEFAULT 0"
+                )
+
+            try:
+                yield self.postgres_store.runInteraction(
+                    "alter_table", alter_table
+                )
+            except Exception as e:
+                logger.info("Failed to create port table: %s", e)
+
             try:
                 yield self.postgres_store.runInteraction(
                     "create_port_table", create_port_table
@@ -458,7 +538,7 @@ class Porter(object):
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
         # Only save things from the last day
-        yesterday = int(time.time()*1000) - 86400000
+        yesterday = int(time.time() * 1000) - 86400000
 
         # And save the max transaction id from each destination
         select = (
@@ -514,7 +594,11 @@ class Porter(object):
 
         yield self.postgres_store._simple_insert(
             table="port_from_sqlite3",
-            values={"table_name": "sent_transactions", "rowid": next_chunk}
+            values={
+                "table_name": "sent_transactions",
+                "forward_rowid": next_chunk,
+                "backward_rowid": 0,
+            }
         )
 
         def get_sent_table_size(txn):
@@ -535,13 +619,18 @@ class Porter(object):
         defer.returnValue((next_chunk, inserted_rows, total_count))
 
     @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, next_chunk):
-        rows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+        frows = yield self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            next_chunk,
+            forward_chunk,
         )
 
-        defer.returnValue(rows[0][0])
+        brows = yield self.sqlite_store.execute_sql(
+            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
+            backward_chunk,
+        )
+
+        defer.returnValue(frows[0][0] + brows[0][0])
 
     @defer.inlineCallbacks
     def _get_already_ported_count(self, table):
@@ -552,10 +641,10 @@ class Porter(object):
         defer.returnValue(rows[0][0])
 
     @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, next_chunk):
+    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
         remaining, done = yield defer.gatherResults(
             [
-                self._get_remaining_count_to_port(table, next_chunk),
+                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
                 self._get_already_ported_count(table),
             ],
             consumeErrors=True,
@@ -686,7 +775,7 @@ class CursesProgress(Progress):
             color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
 
             self.stdscr.addstr(
-                i+2, left_margin + max_len - len(table),
+                i + 2, left_margin + max_len - len(table),
                 table,
                 curses.A_BOLD | color,
             )
@@ -694,18 +783,18 @@ class CursesProgress(Progress):
             size = 20
 
             progress = "[%s%s]" % (
-                "#" * int(perc*size/100),
-                " " * (size - int(perc*size/100)),
+                "#" * int(perc * size / 100),
+                " " * (size - int(perc * size / 100)),
             )
 
             self.stdscr.addstr(
-                i+2, left_margin + max_len + middle_space,
+                i + 2, left_margin + max_len + middle_space,
                 "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
             )
 
         if self.finished:
             self.stdscr.addstr(
-                rows-1, 0,
+                rows - 1, 0,
                 "Press any key to exit...",
             )