Эх сурвалжийг харах

Modify git squash-branch to perform reparenting

Bug: 40264739
Change-Id: I4ad7f4f8a670334b32c239458048e56c6af44098
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/6227541
Reviewed-by: Josip Sokcevic <sokcevic@chromium.org>
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
Commit-Queue: Alexander Cooper <alcooper@chromium.org>
Alexander Cooper 6 сар өмнө
parent
commit
380df04b62

+ 54 - 0
git_common.py

@@ -625,6 +625,60 @@ def get_branch_tree(use_limit=False):
     return skipped, branch_tree
 
 
+def get_diverged_branches(branch_tree=None):
+    """Gets the branches from the tree that have diverged from their upstream
+
+    Returns the list of branches that have diverged from their respective
+    upstream branch.
+    Expects to receive a tree as generated from `get_branch_tree`, which it will
+    call if not supplied, ignoring branches without upstreams.
+    """
+    if not branch_tree:
+        _, branch_tree = get_branch_tree()
+    diverged_branches = []
+    for branch, upstream_branch in branch_tree.items():
+        # If the merge base of a branch and its upstream is not equal to the
+        # upstream, then it means that both branch diverged.
+        upstream_branch_hash = hash_one(upstream_branch)
+        merge_base_hash = hash_one(get_or_create_merge_base(branch))
+        if upstream_branch_hash != merge_base_hash:
+            diverged_branches.append(branch)
+    return diverged_branches
+
+
+def get_hashes(branch_tree=None):
+    """Get the dictionary of {branch: hash}
+
+    Returns a dictionary that contains the hash of every branch. Suitable for
+    saving hashes before performing destructive operations to perform
+    appropriate rebases.
+    Expects to receive a tree as generated from `get_branch_tree`, which it will
+    call if not supplied, ignoring branches without upstreams.
+    """
+    if not branch_tree:
+        _, branch_tree = get_branch_tree()
+    hashes = {}
+    for branch, upstream_branch in branch_tree.items():
+        hashes[branch] = hash_one(branch)
+        hashes[upstream_branch] = hash_one(upstream_branch)
+    return hashes
+
+
+def get_downstream_branches(branch_tree=None):
+    """Get the dictionary of {branch: children}
+
+    Returns a dictionary that contains the list of downstream branches for every
+    branch.
+    Expects to receive a tree as generated from `get_branch_tree`, which it will
+    call if not supplied, ignoring branches without upstreams.
+    """
+    if not branch_tree:
+        _, branch_tree = get_branch_tree()
+    downstream_branches = collections.defaultdict(list)
+    for branch, upstream_branch in branch_tree.items():
+        downstream_branches[upstream_branch].append(branch)
+    return downstream_branches
+
 def get_or_create_merge_base(branch, parent=None) -> Optional[str]:
     """Finds the configured merge base for branch.
 

+ 69 - 0
git_squash_branch.py

@@ -10,6 +10,48 @@ import gclient_utils
 import git_common
 
 
+# Squash a branch, taking care to rebase the branch on top of the new commit
+# position of its upstream branch.
+def rebase_branch(branch, initial_hashes):
+    print('Re-parenting branch %s.' % branch)
+    assert initial_hashes[branch] == git_common.hash_one(branch)
+
+    upstream_branch = git_common.upstream(branch)
+    old_upstream_branch = initial_hashes[upstream_branch]
+
+    # Because the branch's upstream has potentially changed from squashing it,
+    # the current branch is rebased on top of the new upstream.
+    git_common.run('rebase', '--onto', upstream_branch, old_upstream_branch,
+                   branch, '--update-refs')
+
+
+# Squashes all branches that are part of the subtree starting at `branch`.
+def rebase_subtree(branch, initial_hashes, downstream_branches):
+    # Rebase us onto our parent
+    rebase_branch(branch, initial_hashes)
+
+    # Recurse on downstream branches, if any.
+    for downstream_branch in downstream_branches[branch]:
+        rebase_subtree(downstream_branch, initial_hashes, downstream_branches)
+
+
+def children_have_diverged(branch, downstream_branches, diverged_branches):
+    # If we have no diverged branches, then no children have diverged.
+    if not diverged_branches:
+        return False
+
+    # If we have diverged, then our children have diverged.
+    if branch in diverged_branches:
+        return True
+
+    # If any of our children have diverged, then we need to return true.
+    for downstream_branch in downstream_branches[branch]:
+        if children_have_diverged(downstream_branch, downstream_branches,
+                                  diverged_branches):
+            return True
+
+    return False
+
 def main(args):
     if gclient_utils.IsEnvCog():
         print('squash-branch command is not supported in non-git environment.',
@@ -25,7 +67,34 @@ def main(args):
     opts = parser.parse_args(args)
     if git_common.is_dirty_git_tree('squash-branch'):
         return 1
+
+    # Save off the current branch so we can return to it at the end.
+    return_branch = git_common.current_branch()
+
+    # Save the hashes before we mutate the tree so that we have all of the
+    # necessary rebasing information.
+    _, tree = git_common.get_branch_tree()
+    initial_hashes = git_common.get_hashes(tree)
+    downstream_branches = git_common.get_downstream_branches(tree)
+    diverged_branches = git_common.get_diverged_branches(tree)
+
+    # We won't be rebasing our squashed branch, so only check any potential
+    # children
+    for branch in downstream_branches[return_branch]:
+        if children_have_diverged(branch, downstream_branches,
+                                  diverged_branches):
+            print('Cannot use `git squash-branch` since some children have '
+                  'diverged from their upstream and could cause conflicts.')
+            return 1
+
     git_common.squash_current_branch(opts.message)
+
+    # Fixup our children with our new state.
+    for branch in downstream_branches[return_branch]:
+        rebase_subtree(branch, initial_hashes, downstream_branches)
+
+    git_common.run('checkout', return_branch)
+
     return 0
 
 

+ 3 - 35
git_squash_branch_tree.py

@@ -13,38 +13,6 @@ import git_common as git
 import sys
 
 
-# Returns the list of branches that have diverged from their respective upstream
-# branch.
-def get_diverged_branches(tree):
-    diverged_branches = []
-    for branch, upstream_branch in tree.items():
-        # If the merge base of a branch and its upstream is not equal to the
-        # upstream, then it means that both branch diverged.
-        upstream_branch_hash = git.hash_one(upstream_branch)
-        merge_base_hash = git.hash_one(git.get_or_create_merge_base(branch))
-        if upstream_branch_hash != merge_base_hash:
-            diverged_branches.append(branch)
-    return diverged_branches
-
-
-# Returns a dictionary that contains the hash of every branch before the
-# squashing started.
-def get_initial_hashes(tree):
-    initial_hashes = {}
-    for branch, upstream_branch in tree.items():
-        initial_hashes[branch] = git.hash_one(branch)
-        initial_hashes[upstream_branch] = git.hash_one(upstream_branch)
-    return initial_hashes
-
-
-# Returns a dictionary that contains the downstream branches of every branch.
-def get_downstream_branches(tree):
-    downstream_branches = collections.defaultdict(list)
-    for branch, upstream_branch in tree.items():
-        downstream_branches[upstream_branch].append(branch)
-    return downstream_branches
-
-
 # Squash a branch, taking care to rebase the branch on top of the new commit
 # position of its upstream branch.
 def squash_branch(branch, initial_hashes):
@@ -102,7 +70,7 @@ def main(args=None):
         print('Use --ignore-no-upstream to ignore this check and proceed.')
         return 1
 
-    diverged_branches = get_diverged_branches(tree)
+    diverged_branches = git.get_diverged_branches(tree)
     if diverged_branches:
         print('Cannot use `git squash-branch-tree` since the following\n'
               'branches have diverged from their upstream and could cause\n'
@@ -115,8 +83,8 @@ def main(args=None):
     # we can go back to it at the end.
     return_branch = git.current_branch()
 
-    initial_hashes = get_initial_hashes(tree)
-    downstream_branches = get_downstream_branches(tree)
+    initial_hashes = git.get_hashes(tree)
+    downstream_branches = git.get_downstream_branches(tree)
     squash_subtree(opts.branch, initial_hashes, downstream_branches)
 
     git.run('checkout', return_branch)

+ 28 - 0
tests/git_common_test.py

@@ -751,6 +751,34 @@ class GitMutableStructuredTest(git_test_utils.GitRepoReadWriteTestBase,
             ('root_A', 'root_X'),
         ])
 
+    def testGetHashes(self):
+        hashes = self.repo.run(self.gc.get_hashes)
+        for branch, branch_hash in hashes.items():
+            self.assertEqual(self.repo.run(self.gc.hash_one, branch),
+                             branch_hash)
+
+    def testGetDownstreamBranches(self):
+        downstream_branches = self.repo.run(self.gc.get_downstream_branches)
+        self.assertEqual(
+            downstream_branches, {
+                'root_A': ['branch_G'],
+                'branch_G': ['branch_K'],
+                'branch_K': ['branch_L'],
+                'root_X': ['branch_Z', 'root_A'],
+            })
+
+    def testGetDivergedBranches(self):
+        # root_X and root_A don't actually have a common base commit due to the
+        # test repo's structure, which causes get_diverged_branches to throw
+        # an error.
+        self.repo.git('branch', '--unset-upstream', 'root_A')
+
+        # K is setup with G as it's root, but it's branched at B.
+        # L is setup with K as it's root, but it's branched at J.
+        diverged_branches = self.repo.run(self.gc.get_diverged_branches)
+        self.assertEqual(diverged_branches, ['branch_K', 'branch_L'])
+
+
     def testIsGitTreeDirty(self):
         retval = []
         self.repo.capture_stdio(lambda: retval.append(

+ 146 - 0
tests/git_squash_branch_test.py

@@ -0,0 +1,146 @@
+#!/usr/bin/env vpython3
+# coding=utf-8
+# Copyright 2024 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Tests for git_squash_branch."""
+
+import os
+import sys
+import unittest
+
+DEPOT_TOOLS_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, DEPOT_TOOLS_ROOT)
+
+from testing_support import git_test_utils
+
+import git_squash_branch
+import git_common
+
+git_common.TEST_MODE = True
+
+
+class GitSquashBranchTest(git_test_utils.GitRepoReadWriteTestBase):
+    # Empty repo.
+    REPO_SCHEMA = """
+  """
+
+    def setUp(self):
+        super(GitSquashBranchTest, self).setUp()
+
+        # Note: Using the REPO_SCHEMA wouldn't simplify this test so it is not
+        #       used.
+        #
+        # Create a repo with the follow schema
+        #
+        # main <- branchA <- branchB <- branchC
+        #            ^
+        #            \ branchD
+        #
+        # where each branch has 2 commits.
+
+        # The repo is empty. Add the first commit or else most commands don't
+        # work, including `git branch`, which doesn't even show the main branch.
+        self.repo.git('commit', '-m', 'First commit', '--allow-empty')
+
+        # Create the first branch downstream from `main` with 2 commits.
+        self.repo.git('checkout', '-B', 'branchA', '--track', 'main')
+        self._createFileAndCommit('fileA1')
+        self._createFileAndCommit('fileA2')
+
+        # Create a branch downstream from `branchA` with 2 commits.
+        self.repo.git('checkout', '-B', 'branchB', '--track', 'branchA')
+        self._createFileAndCommit('fileB1')
+        self._createFileAndCommit('fileB2')
+
+        # Create another branch downstream from `branchB` with 2 commits.
+        self.repo.git('checkout', '-B', 'branchC', '--track', 'branchB')
+        self._createFileAndCommit('fileC1')
+        self._createFileAndCommit('fileC2')
+
+        # Create another branch downstream from `branchA` with 2 commits.
+        self.repo.git('checkout', '-B', 'branchD', '--track', 'branchA')
+        self._createFileAndCommit('fileD1')
+        self._createFileAndCommit('fileD2')
+
+    def testGitSquashBranchFailsWithDivergedBranch(self):
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+        self.repo.git('checkout', 'branchB')
+        self._createFileAndCommit('fileB3')
+        self.repo.git('checkout', 'branchA')
+
+        # We have now made a state where branchC has diverged from branchB.
+        output, _ = self.repo.capture_stdio(git_squash_branch.main, [])
+        self.assertIn('some children have diverged', output)
+
+    def testGitSquashBranchRootOnly(self):
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+        self.repo.git('checkout', 'branchA')
+        self.repo.run(git_squash_branch.main, [])
+
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+    def testGitSquashBranchLeaf(self):
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+        self.repo.git('checkout', 'branchD')
+        self.repo.run(git_squash_branch.main, [])
+
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 1)
+
+    def testGitSquashBranchSequential(self):
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+        self.repo.git('checkout', 'branchA')
+        self.repo.run(git_squash_branch.main, [])
+
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+        self.repo.git('checkout', 'branchB')
+        self.repo.run(git_squash_branch.main, [])
+
+        self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1)
+        self.assertEqual(self._getCountAheadOfUpstream('branchB'), 1)
+        self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2)
+        self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2)
+
+    # Creates a file with arbitrary contents and commit it to the current
+    # branch.
+    def _createFileAndCommit(self, filename):
+        with self.repo.open(filename, 'w') as f:
+            f.write('content')
+        self.repo.git('add', filename)
+        self.repo.git_commit('Added file ' + filename)
+
+    # Returns the count of how many commits `branch` is ahead of its upstream.
+    def _getCountAheadOfUpstream(self, branch):
+        upstream = branch + '@{u}'
+        output = self.repo.git('rev-list', '--count',
+                               upstream + '..' + branch).stdout
+        return int(output)
+
+
+if __name__ == '__main__':
+    unittest.main()