Explorar o código

[gerrit_util] Change Authenticator API to return proxy info.

This will be used with an upcoming SSOAuthenticator implementation
which will need to proxy all http requests for Googlers.

R=ayatane, gavinmak@google.com

Bug: 336351842
Change-Id: If8cbb8db51fce198e704f109232868421130b40c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/5582100
Commit-Queue: Gavin Mak <gavinmak@google.com>
Auto-Submit: Robbie Iannucci <iannucci@chromium.org>
Reviewed-by: Gavin Mak <gavinmak@google.com>
Robert Iannucci hai 1 ano
pai
achega
c57b7ed364
Modificáronse 4 ficheiros con 60 adicións e 44 borrados
  1. 28 18
      gerrit_util.py
  2. 3 3
      git_cl.py
  3. 28 22
      tests/gerrit_util_test.py
  4. 1 1
      tests/git_cl_test.py

+ 28 - 18
gerrit_util.py

@@ -9,7 +9,7 @@ https://gerrit-review.googlesource.com/Documentation/rest-api.html
 
 import base64
 import contextlib
-from typing import List, Type
+from typing import List, Optional, Tuple, Type
 import httplib2
 import json
 import logging
@@ -96,7 +96,17 @@ def _QueryString(params, first_param=None):
 
 class Authenticator(object):
     """Base authenticator class for authenticator implementations to subclass."""
-    def get_auth_header(self, host):
+    def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
+        """Returns the Authorization header value, plus an optional ProxyInfo.
+
+        TODO: Remove `host`. This is only needed for the deprecated
+        CookiesAuthenticator. If distinguishing between hosts is still needed
+        later, I would propose moving this parameter to
+        Authenticator.get/Authenticator.is_applicable/Authenticator.__init__
+        instead.
+
+        TODO: Make auth header non-optional.
+        """
         raise NotImplementedError()
 
     def debug_summary_state(self) -> str:
@@ -231,16 +241,16 @@ class CookiesAuthenticator(Authenticator):
                 return (creds[0], None, creds[1])
         return None
 
-    def get_auth_header(self, host):
+    def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
         a = self._get_auth_for_host(host)
         if a:
             if a[0]:
                 secret = base64.b64encode(
                     ('%s:%s' % (a[0], a[2])).encode('utf-8'))
-                return 'Basic %s' % secret.decode('utf-8')
+                return 'Basic %s' % secret.decode('utf-8'), None
 
-            return 'Bearer %s' % a[2]
-        return None
+            return 'Bearer %s' % a[2], None
+        return None, None
 
     # Used to redact the cookies from the gitcookies file.
     GITCOOKIES_REDACT_RE = re.compile(r'1/.*')
@@ -333,11 +343,11 @@ class GceAuthenticator(Authenticator):
         cls._token_expiration = cls._token_cache['expires_in'] + time_time()
         return cls._token_cache
 
-    def get_auth_header(self, _host):
+    def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
         token_dict = self._get_token_dict()
         if not token_dict:
-            return None
-        return '%(token_type)s %(access_token)s' % token_dict
+            return None, None
+        return '%(token_type)s %(access_token)s' % token_dict, None
 
     def debug_summary_state(self) -> str:
         # TODO(b/343230702) - report ambient account name.
@@ -355,8 +365,8 @@ class LuciContextAuthenticator(Authenticator):
         self._authenticator = auth.Authenticator(' '.join(
             [auth.OAUTH_SCOPE_EMAIL, auth.OAUTH_SCOPE_GERRIT]))
 
-    def get_auth_header(self, _host):
-        return 'Bearer %s' % self._authenticator.get_access_token().token
+    def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
+        return 'Bearer %s' % self._authenticator.get_access_token().token, None
 
     def debug_summary_state(self) -> str:
         # TODO(b/343230702) - report ambient account name.
@@ -373,22 +383,22 @@ def CreateHttpConn(host,
     headers = headers or {}
     bare_host = host.partition(':')[0]
 
-    a = Authenticator.get()
+    authenticator = Authenticator.get()
     # TODO(crbug.com/1059384): Automatically detect when running on cloudtop.
-    if isinstance(a, GceAuthenticator):
+    if isinstance(authenticator, GceAuthenticator):
         print('If you\'re on a cloudtop instance, export '
               'SKIP_GCE_AUTH_FOR_GIT=1 in your env.')
 
-    a = a.get_auth_header(bare_host)
-    if a:
-        headers.setdefault('Authorization', a)
+    auth_header, proxy = authenticator.get_auth_info(bare_host)
+    if auth_header:
+        headers.setdefault('Authorization', auth_header)
     else:
         LOGGER.debug('No authorization found for %s.' % bare_host)
 
     url = path
     if not url.startswith('/'):
         url = '/' + url
-    if 'Authorization' in headers and not url.startswith('/a/'):
+    if auth_header and not url.startswith('/a/'):
         url = '/a%s' % url
 
     if body:
@@ -402,7 +412,7 @@ def CreateHttpConn(host,
             LOGGER.debug('%s: %s' % (key, val))
         if body:
             LOGGER.debug(body)
-    conn = httplib2.Http(timeout=timeout)
+    conn = httplib2.Http(timeout=timeout, proxy_info=proxy)
     # HACK: httplib2.Http has no such attribute; we store req_host here for
     # later use in ReadHttpResponse.
     conn.req_host = host

+ 3 - 3
git_cl.py

@@ -2312,12 +2312,12 @@ class Changelist(object):
         git_host = self._GetGitHost()
         assert self._gerrit_server and self._gerrit_host and git_host
 
-        gerrit_auth = cookie_auth.get_auth_header(self._gerrit_host)
-        git_auth = cookie_auth.get_auth_header(git_host)
+        gerrit_auth, _ = cookie_auth.get_auth_info(self._gerrit_host)
+        git_auth, _ = cookie_auth.get_auth_info(git_host)
         if gerrit_auth and git_auth:
             if gerrit_auth == git_auth:
                 return
-            all_gsrc = cookie_auth.get_auth_header(
+            all_gsrc, _ = cookie_auth.get_auth_info(
                 'd0esN0tEx1st.googlesource.com')
             print(
                 'WARNING: You have different credentials for Gerrit and git hosts:\n'

+ 28 - 22
tests/gerrit_util_test.py

@@ -151,13 +151,13 @@ class CookiesAuthenticatorTest(unittest.TestCase):
             'Basic Z2l0LXVzZXIuY2hyb21pdW0ub3JnOjEvY2hyb21pdW0tc2VjcmV0')
 
         auth = gerrit_util.CookiesAuthenticator()
-        self.assertEqual(expected_chromium_header,
-                         auth.get_auth_header('chromium.googlesource.com'))
+        self.assertEqual((expected_chromium_header, None),
+                         auth.get_auth_info('chromium.googlesource.com'))
         self.assertEqual(
-            expected_chromium_header,
-            auth.get_auth_header('chromium-review.googlesource.com'))
-        self.assertEqual('Bearer example-bearer-token',
-                         auth.get_auth_header('some-review.example.com'))
+            (expected_chromium_header, None),
+            auth.get_auth_info('chromium-review.googlesource.com'))
+        self.assertEqual(('Bearer example-bearer-token', None),
+                         auth.get_auth_info('some-review.example.com'))
 
     def testGetAuthEmail(self):
         auth = gerrit_util.CookiesAuthenticator()
@@ -226,15 +226,21 @@ class GceAuthenticatorTest(unittest.TestCase):
 
     def testGetAuthHeader_Error(self):
         httplib2.Http().request.side_effect = httplib2.HttpLib2Error
-        self.assertIsNone(self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(
+            (None, None),
+            self.GceAuthenticator().get_auth_info(''))
 
     def testGetAuthHeader_500(self):
         httplib2.Http().request.return_value = (mock.Mock(status=500), None)
-        self.assertIsNone(self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(
+            (None, None),
+            self.GceAuthenticator().get_auth_info(''))
 
     def testGetAuthHeader_Non200(self):
         httplib2.Http().request.return_value = (mock.Mock(status=403), None)
-        self.assertIsNone(self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(
+            (None, None),
+            self.GceAuthenticator().get_auth_info(''))
 
     def testGetAuthHeader_OK(self):
         httplib2.Http().request.return_value = (
@@ -242,8 +248,8 @@ class GceAuthenticatorTest(unittest.TestCase):
             '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
         )
         gerrit_util.time_time.return_value = 0
-        self.assertEqual('TYPE TOKEN',
-                         self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(('TYPE TOKEN', None),
+                         self.GceAuthenticator().get_auth_info(''))
 
     def testGetAuthHeader_Cache(self):
         httplib2.Http().request.return_value = (
@@ -251,10 +257,10 @@ class GceAuthenticatorTest(unittest.TestCase):
             '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
         )
         gerrit_util.time_time.return_value = 0
-        self.assertEqual('TYPE TOKEN',
-                         self.GceAuthenticator().get_auth_header(''))
-        self.assertEqual('TYPE TOKEN',
-                         self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(('TYPE TOKEN', None),
+                         self.GceAuthenticator().get_auth_info(''))
+        self.assertEqual(('TYPE TOKEN', None),
+                         self.GceAuthenticator().get_auth_info(''))
         httplib2.Http().request.assert_called_once()
 
     def testGetAuthHeader_CacheOld(self):
@@ -263,10 +269,10 @@ class GceAuthenticatorTest(unittest.TestCase):
             '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
         )
         gerrit_util.time_time.side_effect = [0, 100, 200]
-        self.assertEqual('TYPE TOKEN',
-                         self.GceAuthenticator().get_auth_header(''))
-        self.assertEqual('TYPE TOKEN',
-                         self.GceAuthenticator().get_auth_header(''))
+        self.assertEqual(('TYPE TOKEN', None),
+                         self.GceAuthenticator().get_auth_info(''))
+        self.assertEqual(('TYPE TOKEN', None),
+                         self.GceAuthenticator().get_auth_info(''))
         self.assertEqual(2, len(httplib2.Http().request.mock_calls))
 
 
@@ -294,7 +300,7 @@ class GerritUtilTest(unittest.TestCase):
 
     @mock.patch('gerrit_util.Authenticator')
     def testCreateHttpConn_Basic(self, mockAuth):
-        mockAuth.get().get_auth_header.return_value = None
+        mockAuth.get().get_auth_info.return_value = None, None
         conn = gerrit_util.CreateHttpConn('host.example.com', 'foo/bar')
         self.assertEqual('host.example.com', conn.req_host)
         self.assertEqual(
@@ -307,7 +313,7 @@ class GerritUtilTest(unittest.TestCase):
 
     @mock.patch('gerrit_util.Authenticator')
     def testCreateHttpConn_Authenticated(self, mockAuth):
-        mockAuth.get().get_auth_header.return_value = 'Bearer token'
+        mockAuth.get().get_auth_info.return_value = 'Bearer token', None
         conn = gerrit_util.CreateHttpConn('host.example.com',
                                           'foo/bar',
                                           headers={'header': 'value'})
@@ -325,7 +331,7 @@ class GerritUtilTest(unittest.TestCase):
 
     @mock.patch('gerrit_util.Authenticator')
     def testCreateHttpConn_Body(self, mockAuth):
-        mockAuth.get().get_auth_header.return_value = None
+        mockAuth.get().get_auth_info.return_value = None, None
         conn = gerrit_util.CreateHttpConn('host.example.com',
                                           'foo/bar',
                                           body={

+ 1 - 1
tests/git_cl_test.py

@@ -2598,7 +2598,7 @@ class TestGitCl(unittest.TestCase):
                 'chromium-review.googlesource.com': ('', None, 'secret'),
             })
         self.assertIsNone(cl.EnsureAuthenticated(force=False))
-        header = gerrit_util.CookiesAuthenticator().get_auth_header(
+        header, _ = gerrit_util.CookiesAuthenticator().get_auth_info(
             'chromium.googlesource.com')
         self.assertTrue('Bearer' in header)