Browse Source

Improve ensure_gsutil reliability

The current gsutil download code silently fails when the connection
drops mid-download, as read() returns an empty buffer instead of raising
an exception. This may lead to errors such as "zipfile.BadZipFile: File
is not a zip file" on Chromium sync with freshly-bootstrapped
depot_tools when downloading gcs deps.

This change solves this by hardening the process:
- Use a retry mechanism with exponential backoff for the gsutil download
- Switch to urlretrieve, which validates the download against Content-Length
- Compare the MD5 of the downloaded file with the value from the API
- Move exponential_backoff_retry from git_cache.py to gclient_utils.py

Change-Id: I25242948399e01373eb2afd9352e5c78a889051d
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/6485485
Reviewed-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
Auto-Submit: Aleksei Khoroshilov <akhoroshilov@brave.com>
Reviewed-by: Scott Lee <ddoman@chromium.org>
Aleksei Khoroshilov 3 months ago
parent
commit
998f7bfaf2
4 changed files with 117 additions and 93 deletions
  1. 39 0
      gclient_utils.py
  2. 9 46
      git_cache.py
  3. 28 25
      gsutil.py
  4. 41 22
      tests/gsutil_test.py

+ 39 - 0
gclient_utils.py

@@ -1418,3 +1418,42 @@ def merge_conditions(*conditions):
             continue
         condition = f'({condition}) and ({current_condition})'
     return condition
+
+
+def exponential_backoff_retry(fn,
+                              excs=(Exception, ),
+                              name=None,
+                              count=10,
+                              sleep_time=0.25,
+                              printerr=None):
+    """Executes |fn| up to |count| times, backing off exponentially.
+
+    Args:
+        fn (callable): The function to execute. If this raises a handled
+            exception, the function will retry with exponential backoff.
+        excs (tuple): A tuple of Exception types to handle. If one of these is
+            raised by |fn|, a retry will be attempted. If |fn| raises an
+            Exception that is not in this list, it will immediately pass
+            through. If |excs| is empty, the Exception base class will be used.
+        name (str): Optional operation name to print in the retry string.
+        count (int): The number of times to try before allowing the exception
+            to pass through.
+        sleep_time (float): The initial number of seconds to sleep in between
+            retries. This will be doubled each retry.
+        printerr (callable): Function that will be called with the error string
+            upon failures. If None, |logging.warning| will be used.
+
+    Returns: The return value of the successful fn.
+    """
+    printerr = printerr or logging.warning
+    for i in range(count):
+        try:
+            return fn()
+        except excs as e:
+            if (i + 1) >= count:
+                raise
+
+            printerr('Retrying %s in %.2f second(s) (%d / %d attempts): %s' %
+                     ((name or 'operation'), sleep_time, (i + 1), count, e))
+            time.sleep(sleep_time)
+            sleep_time *= 2

+ 9 - 46
git_cache.py

@@ -47,45 +47,6 @@ class ClobberNeeded(Exception):
     pass
 
 
-def exponential_backoff_retry(fn,
-                              excs=(Exception, ),
-                              name=None,
-                              count=10,
-                              sleep_time=0.25,
-                              printerr=None):
-    """Executes |fn| up to |count| times, backing off exponentially.
-
-    Args:
-        fn (callable): The function to execute. If this raises a handled
-            exception, the function will retry with exponential backoff.
-        excs (tuple): A tuple of Exception types to handle. If one of these is
-            raised by |fn|, a retry will be attempted. If |fn| raises an
-            Exception that is not in this list, it will immediately pass
-            through. If |excs| is empty, the Exception base class will be used.
-        name (str): Optional operation name to print in the retry string.
-        count (int): The number of times to try before allowing the exception
-            to pass through.
-        sleep_time (float): The initial number of seconds to sleep in between
-            retries. This will be doubled each retry.
-        printerr (callable): Function that will be called with the error string
-            upon failures. If None, |logging.warning| will be used.
-
-    Returns: The return value of the successful fn.
-    """
-    printerr = printerr or logging.warning
-    for i in range(count):
-        try:
-            return fn()
-        except excs as e:
-            if (i + 1) >= count:
-                raise
-
-            printerr('Retrying %s in %.2f second(s) (%d / %d attempts): %s' %
-                     ((name or 'operation'), sleep_time, (i + 1), count, e))
-            time.sleep(sleep_time)
-            sleep_time *= 2
-
-
 class Mirror(object):
 
     git_exe = 'git.bat' if sys.platform.startswith('win') else 'git'
@@ -239,10 +200,11 @@ class Mirror(object):
         # This is somehow racy on Windows.
         # Catching OSError because WindowsError isn't portable and
         # pylint complains.
-        exponential_backoff_retry(lambda: os.rename(src, dst),
-                                  excs=(OSError, ),
-                                  name='rename [%s] => [%s]' % (src, dst),
-                                  printerr=self.print)
+        gclient_utils.exponential_backoff_retry(lambda: os.rename(src, dst),
+                                                excs=(OSError, ),
+                                                name='rename [%s] => [%s]' %
+                                                (src, dst),
+                                                printerr=self.print)
 
     def RunGit(self, cmd, print_stdout=True, **kwargs):
         """Run git in a subprocess."""
@@ -485,12 +447,13 @@ class Mirror(object):
                     'repository.' % len(pack_files))
 
     def _set_symbolic_ref(self):
-        remote_info = exponential_backoff_retry(lambda: subprocess.check_output(
-            [
+        remote_info = gclient_utils.exponential_backoff_retry(
+            lambda: subprocess.check_output([
                 self.git_exe, '--git-dir',
                 os.path.abspath(self.mirror_path), 'remote', 'show', self.url
             ],
-            cwd=self.mirror_path).decode('utf-8', 'ignore').strip())
+                                            cwd=self.mirror_path).decode(
+                                                'utf-8', 'ignore').strip())
         default_branch_regexp = re.compile(r'HEAD branch: (.*)')
         m = default_branch_regexp.search(remote_info, re.MULTILINE)
         if m:

+ 28 - 25
gsutil.py

@@ -14,9 +14,10 @@ import shutil
 import subprocess
 import sys
 import tempfile
-import time
 import urllib.request
 
+import gclient_utils
+
 GSUTIL_URL = 'https://storage.googleapis.com/pub/'
 API_URL = 'https://www.googleapis.com/storage/v1/b/pub/o/'
 
@@ -49,34 +50,34 @@ def download_gsutil(version, target_dir):
     filename = 'gsutil_%s.zip' % version
     target_filename = os.path.join(target_dir, filename)
 
-    # Check if the target exists already.
-    if os.path.exists(target_filename):
-        md5_calc = hashlib.md5()
+    # Get md5 hash of the remote file from the metadata.
+    metadata_url = '%s%s' % (API_URL, filename)
+    metadata = json.load(urllib.request.urlopen(metadata_url))
+    remote_md5 = base64.b64decode(metadata['md5Hash'])
+
+    # Calculate the md5 hash of the local file.
+    def calc_local_md5():
+        assert os.path.exists(target_filename)
+        md5 = hashlib.md5()
         with open(target_filename, 'rb') as f:
-            while True:
-                buf = f.read(4096)
-                if not buf:
-                    break
-                md5_calc.update(buf)
-        local_md5 = md5_calc.hexdigest()
-
-        metadata_url = '%s%s' % (API_URL, filename)
-        metadata = json.load(urllib.request.urlopen(metadata_url))
-        remote_md5 = base64.b64decode(metadata['md5Hash']).decode('utf-8')
-
-        if local_md5 == remote_md5:
+            while chunk := f.read(1024 * 1024):
+                md5.update(chunk)
+        return md5.digest()
+
+    # Use the existing file if it has the correct md5 hash.
+    if os.path.exists(target_filename):
+        if calc_local_md5() == remote_md5:
             return target_filename
         os.remove(target_filename)
 
-    # Do the download.
+    # Download the file.
     url = '%s%s' % (GSUTIL_URL, filename)
-    u = urllib.request.urlopen(url)
-    with open(target_filename, 'wb') as f:
-        while True:
-            buf = u.read(4096)
-            if not buf:
-                break
-            f.write(buf)
+    urllib.request.urlretrieve(url, target_filename)
+
+    # Check if the file was downloaded correctly.
+    if calc_local_md5() != remote_md5:
+        raise InvalidGsutilError(f'Downloaded gsutil from {url} has wrong md5')
+
     return target_filename
 
 
@@ -112,7 +113,9 @@ def ensure_gsutil(version, target, clean):
 
         with temporary_directory(target) as instance_dir:
             download_dir = os.path.join(instance_dir, 'd')
-            target_zip_filename = download_gsutil(version, instance_dir)
+            target_zip_filename = gclient_utils.exponential_backoff_retry(
+                lambda: download_gsutil(version, instance_dir),
+                name='download_gsutil')
             with zipfile.ZipFile(target_zip_filename, 'r') as target_zip:
                 target_zip.extractall(download_dir)
 

+ 41 - 22
tests/gsutil_test.py

@@ -50,6 +50,13 @@ class FakeCall(object):
         return exp_returns
 
 
+class MockResponse(io.BytesIO):
+
+    def info(self):
+        # urlretrieve expects info() to return a dictionary.
+        return {}
+
+
 class GsutilUnitTests(unittest.TestCase):
     def setUp(self):
         self.fake = FakeCall()
@@ -65,44 +72,53 @@ class GsutilUnitTests(unittest.TestCase):
         setattr(urllib.request, 'urlopen', self.old_urlopen)
         setattr(subprocess, 'call', self.old_call)
 
+    def add_md5_expectation(self, url, data):
+        md5_calc = hashlib.md5()
+        md5_calc.update(data)
+        b64_md5 = base64.b64encode(md5_calc.digest()).decode('utf-8')
+        response_data = json.dumps({'md5Hash': b64_md5}).encode('utf-8')
+        self.fake.add_expectation(url, _returns=MockResponse(response_data))
+
+    def add_file_expectation(self, url, data):
+        self.fake.add_expectation(url, None, _returns=MockResponse(data))
+
     def test_download_gsutil(self):
         version = gsutil.VERSION
         filename = 'gsutil_%s.zip' % version
         full_filename = os.path.join(self.tempdir, filename)
         fake_file = b'This is gsutil.zip'
         fake_file2 = b'This is other gsutil.zip'
-        url = '%s%s' % (gsutil.GSUTIL_URL, filename)
-        self.fake.add_expectation(url, _returns=io.BytesIO(fake_file))
+        metadata_url = gsutil.API_URL + filename
+        url = gsutil.GSUTIL_URL + filename
 
+        # The md5 is valid, so download_gsutil should download the file.
+        self.add_md5_expectation(metadata_url, fake_file)
+        self.add_file_expectation(url, fake_file)
         self.assertEqual(gsutil.download_gsutil(version, self.tempdir),
                          full_filename)
         with open(full_filename, 'rb') as f:
             self.assertEqual(fake_file, f.read())
+        self.assertEqual(self.fake.expectations, [])
 
-        metadata_url = gsutil.API_URL + filename
-        md5_calc = hashlib.md5()
-        md5_calc.update(fake_file)
-        b64_md5 = base64.b64encode(md5_calc.hexdigest().encode('utf-8'))
-        self.fake.add_expectation(metadata_url,
-                                  _returns=io.BytesIO(
-                                      json.dumps({
-                                          'md5Hash':
-                                          b64_md5.decode('utf-8')
-                                      }).encode('utf-8')))
+        # The md5 is valid, so download_gsutil should use the existing file.
+        self.add_md5_expectation(metadata_url, fake_file)
         self.assertEqual(gsutil.download_gsutil(version, self.tempdir),
                          full_filename)
         with open(full_filename, 'rb') as f:
             self.assertEqual(fake_file, f.read())
         self.assertEqual(self.fake.expectations, [])
 
-        self.fake.add_expectation(
-            metadata_url,
-            _returns=io.BytesIO(
-                json.dumps({
-                    'md5Hash':
-                    base64.b64encode(b'aaaaaaa').decode('utf-8')  # Bad MD5
-                }).encode('utf-8')))
-        self.fake.add_expectation(url, _returns=io.BytesIO(fake_file2))
+        # The md5 is invalid for a new file, so download_gsutil should raise an
+        # error.
+        self.add_md5_expectation(metadata_url, b'aaaaaaa')
+        self.add_file_expectation(url, fake_file2)
+        self.assertRaises(gsutil.InvalidGsutilError, gsutil.download_gsutil,
+                          version, self.tempdir)
+        self.assertEqual(self.fake.expectations, [])
+
+        # The md5 is valid and the new file is already downloaded, so it should
+        # be used without downloading again.
+        self.add_md5_expectation(metadata_url, fake_file2)
         self.assertEqual(gsutil.download_gsutil(version, self.tempdir),
                          full_filename)
         with open(full_filename, 'rb') as f:
@@ -117,13 +133,16 @@ class GsutilUnitTests(unittest.TestCase):
         os.makedirs(gsutil_dir)
 
         zip_filename = 'gsutil_%s.zip' % version
-        url = '%s%s' % (gsutil.GSUTIL_URL, zip_filename)
+        metadata_url = gsutil.API_URL + zip_filename
+        url = gsutil.GSUTIL_URL + zip_filename
         _, tempzip = tempfile.mkstemp()
         fake_gsutil = 'Fake gsutil'
         with zipfile.ZipFile(tempzip, 'w') as zf:
             zf.writestr('gsutil/gsutil', fake_gsutil)
         with open(tempzip, 'rb') as f:
-            self.fake.add_expectation(url, _returns=io.BytesIO(f.read()))
+            fake_file = f.read()
+            self.add_md5_expectation(metadata_url, fake_file)
+            self.add_file_expectation(url, fake_file)
 
         # This should write the gsutil_bin with 'Fake gsutil'
         gsutil.ensure_gsutil(version, self.tempdir, False)