summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12556.misc1
-rwxr-xr-xscripts-dev/release.py63
2 files changed, 38 insertions, 26 deletions
diff --git a/changelog.d/12556.misc b/changelog.d/12556.misc
new file mode 100644
index 0000000000..dc245397fb
--- /dev/null
+++ b/changelog.d/12556.misc
@@ -0,0 +1 @@
+Release script: confirm the commit to be tagged before tagging.
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 14f3f3a45d..f4269e09bb 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -89,13 +89,7 @@ def prepare() -> None:
     """
 
     # Make sure we're in a git repo.
-    try:
-        repo = git.Repo()
-    except git.InvalidGitRepositoryError:
-        raise click.ClickException("Not in Synapse repo.")
-
-    if repo.is_dirty():
-        raise click.ClickException("Uncommitted changes exist.")
+    repo = get_repo_and_check_clean_checkout()
 
     click.secho("Updating git repo...")
     repo.remote().fetch()
@@ -171,9 +165,7 @@ def prepare() -> None:
     assert not parsed_new_version.is_devrelease
     assert not parsed_new_version.is_postrelease
 
-    release_branch_name = (
-        f"release-v{parsed_new_version.major}.{parsed_new_version.minor}"
-    )
+    release_branch_name = get_release_branch_name(parsed_new_version)
     release_branch = find_ref(repo, release_branch_name)
     if release_branch:
         if release_branch.is_remote():
@@ -274,13 +266,7 @@ def tag(gh_token: Optional[str]) -> None:
     """Tags the release and generates a draft GitHub release"""
 
     # Make sure we're in a git repo.
-    try:
-        repo = git.Repo()
-    except git.InvalidGitRepositoryError:
-        raise click.ClickException("Not in Synapse repo.")
-
-    if repo.is_dirty():
-        raise click.ClickException("Uncommitted changes exist.")
+    repo = get_repo_and_check_clean_checkout()
 
     click.secho("Updating git repo...")
     repo.remote().fetch()
@@ -293,6 +279,15 @@ def tag(gh_token: Optional[str]) -> None:
     if tag_name in repo.tags:
         raise click.ClickException(f"Tag {tag_name} already exists!\n")
 
+    # Check we're on the right release branch
+    release_branch = get_release_branch_name(current_version)
+    if repo.active_branch.name != release_branch:
+        click.echo(
+            f"Need to be on the release branch ({release_branch}) before tagging. "
+            f"Currently on ({repo.active_branch.name})."
+        )
+        click.get_current_context().abort()
+
     # Get the appropriate changelogs and tag.
     changes = get_changes_for_version(current_version)
 
@@ -358,21 +353,15 @@ def tag(gh_token: Optional[str]) -> None:
 @cli.command()
 @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True)
 def publish(gh_token: str) -> None:
-    """Publish release."""
+    """Publish release on GitHub."""
 
     # Make sure we're in a git repo.
-    try:
-        repo = git.Repo()
-    except git.InvalidGitRepositoryError:
-        raise click.ClickException("Not in Synapse repo.")
-
-    if repo.is_dirty():
-        raise click.ClickException("Uncommitted changes exist.")
+    get_repo_and_check_clean_checkout()
 
     current_version = get_package_version()
     tag_name = f"v{current_version}"
 
-    if not click.confirm(f"Publish {tag_name}?", default=True):
+    if not click.confirm(f"Publish release {tag_name} on GitHub?", default=True):
         return
 
     # Publish the draft release
@@ -406,6 +395,13 @@ def upload() -> None:
     current_version = get_package_version()
     tag_name = f"v{current_version}"
 
+    # Check we have the right tag checked out.
+    repo = get_repo_and_check_clean_checkout()
+    tag = repo.tag(f"refs/tags/{tag_name}")
+    if repo.head.commit != tag.commit:
+        click.echo("Tag {tag_name} (tag.commit) is not currently checked out!")
+        click.get_current_context().abort()
+
     pypi_asset_names = [
         f"matrix_synapse-{current_version}-py3-none-any.whl",
         f"matrix-synapse-{current_version}.tar.gz",
@@ -469,6 +465,21 @@ def get_package_version() -> version.Version:
     return version.Version(version_string)
 
 
+def get_release_branch_name(version_number: version.Version) -> str:
+    return f"release-v{version_number.major}.{version_number.minor}"
+
+
+def get_repo_and_check_clean_checkout() -> git.Repo:
+    """Get the project repo and check it's not got any uncommitted changes."""
+    try:
+        repo = git.Repo()
+    except git.InvalidGitRepositoryError:
+        raise click.ClickException("Not in Synapse repo.")
+    if repo.is_dirty():
+        raise click.ClickException("Uncommitted changes exist.")
+    return repo
+
+
 def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
     """Find the branch/ref, looking first locally then in the remote."""
     if ref_name in repo.references: