Jelajahi Sumber

[auth] Be able to generate id_token

Make auth be able to generate id_token. Some services on Cloud Run will
need it (e.g. luci-config v2).

Bug: 1487020
Change-Id: Icfe95002f93ee552b99ab2694c7b777e2322484b
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/4899437
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
Commit-Queue: Yuanjun Huang <yuanjunh@google.com>
Yuanjun Huang 1 tahun lalu
induk
melakukan
4c1d6d90bc
2 mengubah file dengan 118 tambahan dan 38 penghapusan
  1. 51 22
      auth.py
  2. 67 16
      tests/auth_test.py

+ 51 - 22
auth.py

@@ -29,14 +29,14 @@ def datetime_now():
     return datetime.datetime.utcnow()
 
 
-# OAuth access token with its expiration time (UTC datetime or None if unknown).
-class AccessToken(
-        collections.namedtuple('AccessToken', [
-            'token',
-            'expires_at',
-        ])):
+# OAuth access token or ID token with its expiration time (UTC datetime or None
+# if unknown).
+class Token(collections.namedtuple('Token', [
+        'token',
+        'expires_at',
+])):
     def needs_refresh(self):
-        """True if this AccessToken should be refreshed."""
+        """True if this token should be refreshed."""
         if self.expires_at is not None:
             # Allow 30s of clock skew between client and backend.
             return datetime_now() + datetime.timedelta(
@@ -67,22 +67,27 @@ def has_luci_context_local_auth():
 
 
 class Authenticator(object):
-    """Object that knows how to refresh access tokens when needed.
+    """Object that knows how to refresh access tokens or id tokens when needed.
 
   Args:
-    scopes: space separated oauth scopes. Defaults to OAUTH_SCOPE_EMAIL.
+    scopes: space separated oauth scopes. It's used to generate access tokens.
+            Defaults to OAUTH_SCOPE_EMAIL.
+    audience: An audience in ID tokens to claim which clients should accept it.
   """
-    def __init__(self, scopes=OAUTH_SCOPE_EMAIL):
+    def __init__(self, scopes=OAUTH_SCOPE_EMAIL, audience=None):
         self._access_token = None
         self._scopes = scopes
+        self._id_token = None
+        self._audience = audience
 
     def has_cached_credentials(self):
         """Returns True if credentials can be obtained.
 
-    If returns False, get_access_token() later will probably ask for interactive
-    login by raising LoginRequiredError.
+    If returns False, get_access_token() or get_id_token() later will probably
+    ask for interactive login by raising LoginRequiredError.
 
-    If returns True, get_access_token() won't ask for interactive login.
+    If returns True, get_access_token() or get_id_token() won't ask for
+    interactive login.
     """
         return bool(self._get_luci_auth_token())
 
@@ -105,7 +110,27 @@ class Authenticator(object):
         logging.error('Failed to create access token')
         raise LoginRequiredError(self._scopes)
 
-    def authorize(self, http):
+    def get_id_token(self):
+        """Returns id token, refreshing it if necessary.
+
+    Returns:
+       A Token object.
+
+    Raises:
+      LoginRequiredError if user interaction is required.
+    """
+        if self._id_token and not self._id_token.needs_refresh():
+            return self._id_token
+
+        self._id_token = self._get_luci_auth_token(use_id_token=True)
+        if self._id_token and not self._id_token.needs_refresh():
+            return self._id_token
+
+        # Nope, still expired. Needs user interaction.
+        logging.error('Failed to create id token')
+        raise LoginRequiredError()
+
+    def authorize(self, http, use_id_token=False):
         """Monkey patches authentication logic of httplib2.Http instance.
 
     The modified http.request method will add authentication headers to each
@@ -128,8 +153,9 @@ class Authenticator(object):
                         redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                         connection_type=None):
             headers = (headers or {}).copy()
-            headers['Authorization'] = 'Bearer %s' % self.get_access_token(
-            ).token
+            auth_token = self.get_access_token(
+            ) if not use_id_token else self.get_id_token()
+            headers['Authorization'] = 'Bearer %s' % auth_token.token
             return request_orig(uri, method, body, headers, redirections,
                                 connection_type)
 
@@ -148,18 +174,21 @@ class Authenticator(object):
         subprocess2.check_call(['luci-auth', 'login', '-scopes', self._scopes])
         return self._get_luci_auth_token()
 
-    def _get_luci_auth_token(self):
+    def _get_luci_auth_token(self, use_id_token=False):
         logging.debug('Running luci-auth token')
+        if use_id_token:
+            args = ['-use-id-token'] + ['-audience', self._audience
+                                        ] if self._audience else []
+        else:
+            args = ['-scopes', self._scopes]
         try:
-            out, err = subprocess2.check_call_out([
-                'luci-auth', 'token', '-scopes', self._scopes, '-json-output',
-                '-'
-            ],
+            out, err = subprocess2.check_call_out(['luci-auth', 'token'] +
+                                                  args + ['-json-output', '-'],
                                                   stdout=subprocess2.PIPE,
                                                   stderr=subprocess2.PIPE)
             logging.debug('luci-auth token stderr:\n%s', err)
             token_info = json.loads(out)
-            return AccessToken(
+            return Token(
                 token_info['token'],
                 datetime.datetime.utcfromtimestamp(token_info['expiry']))
         except subprocess2.CalledProcessError as e:

+ 67 - 16
tests/auth_test.py

@@ -52,8 +52,8 @@ class AuthenticatorTest(unittest.TestCase):
 
     def testGetAccessToken_CachedToken(self):
         authenticator = auth.Authenticator()
-        authenticator._access_token = auth.AccessToken('token', None)
-        self.assertEqual(auth.AccessToken('token', None),
+        authenticator._access_token = auth.Token('token', None)
+        self.assertEqual(auth.Token('token', None),
                          authenticator.get_access_token())
         subprocess2.check_call_out.assert_not_called()
 
@@ -63,7 +63,7 @@ class AuthenticatorTest(unittest.TestCase):
             'token': 'token',
             'expiry': expiry
         }), '')
-        self.assertEqual(auth.AccessToken('token', VALID_EXPIRY),
+        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                          auth.Authenticator().get_access_token())
         subprocess2.check_call_out.assert_called_with([
             'luci-auth', 'token', '-scopes', auth.OAUTH_SCOPE_EMAIL,
@@ -78,7 +78,7 @@ class AuthenticatorTest(unittest.TestCase):
             'token': 'token',
             'expiry': expiry
         }), '')
-        self.assertEqual(auth.AccessToken('token', VALID_EXPIRY),
+        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                          auth.Authenticator('custom scopes').get_access_token())
         subprocess2.check_call_out.assert_called_with([
             'luci-auth', 'token', '-scopes', 'custom scopes', '-json-output',
@@ -87,41 +87,92 @@ class AuthenticatorTest(unittest.TestCase):
                                                       stdout=subprocess2.PIPE,
                                                       stderr=subprocess2.PIPE)
 
-    def testAuthorize(self):
+    def testAuthorize_AccessToken(self):
         http = mock.Mock()
         http_request = http.request
         http_request.__name__ = '__name__'
 
         authenticator = auth.Authenticator()
-        authenticator._access_token = auth.AccessToken('token', None)
+        authenticator._access_token = auth.Token('access_token', None)
+        authenticator._id_token = auth.Token('id_token', None)
 
         authorized = authenticator.authorize(http)
         authorized.request('https://example.com',
                            method='POST',
                            body='body',
                            headers={'header': 'value'})
-        http_request.assert_called_once_with('https://example.com', 'POST',
-                                             'body', {
-                                                 'header': 'value',
-                                                 'Authorization': 'Bearer token'
-                                             }, mock.ANY, mock.ANY)
+        http_request.assert_called_once_with(
+            'https://example.com', 'POST', 'body', {
+                'header': 'value',
+                'Authorization': 'Bearer access_token'
+            }, mock.ANY, mock.ANY)
+
+    def testGetIdToken_NotLoggedIn(self):
+        subprocess2.check_call_out.side_effect = [
+            subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout',
+                                           'stderr')
+        ]
+        self.assertRaises(auth.LoginRequiredError,
+                          auth.Authenticator().get_id_token)
+
+    def testGetIdToken_CachedToken(self):
+        authenticator = auth.Authenticator()
+        authenticator._id_token = auth.Token('token', None)
+        self.assertEqual(auth.Token('token', None),
+                         authenticator.get_id_token())
+        subprocess2.check_call_out.assert_not_called()
+
+    def testGetIdToken_LoggedIn(self):
+        expiry = calendar.timegm(VALID_EXPIRY.timetuple())
+        subprocess2.check_call_out.return_value = (json.dumps({
+            'token': 'token',
+            'expiry': expiry
+        }), '')
+        self.assertEqual(
+            auth.Token('token', VALID_EXPIRY),
+            auth.Authenticator(audience='https://test.com').get_id_token())
+        subprocess2.check_call_out.assert_called_with([
+            'luci-auth', 'token', '-use-id-token', '-audience',
+            'https://test.com', '-json-output', '-'
+        ],
+                                                      stdout=subprocess2.PIPE,
+                                                      stderr=subprocess2.PIPE)
+
+    def testAuthorize_IdToken(self):
+        http = mock.Mock()
+        http_request = http.request
+        http_request.__name__ = '__name__'
+
+        authenticator = auth.Authenticator()
+        authenticator._access_token = auth.Token('access_token', None)
+        authenticator._id_token = auth.Token('id_token', None)
+
+        authorized = authenticator.authorize(http, use_id_token=True)
+        authorized.request('https://example.com',
+                           method='POST',
+                           body='body',
+                           headers={'header': 'value'})
+        http_request.assert_called_once_with(
+            'https://example.com', 'POST', 'body', {
+                'header': 'value',
+                'Authorization': 'Bearer id_token'
+            }, mock.ANY, mock.ANY)
 
 
-class AccessTokenTest(unittest.TestCase):
+class TokenTest(unittest.TestCase):
     def setUp(self):
         mock.patch('auth.datetime_now', return_value=NOW).start()
         self.addCleanup(mock.patch.stopall)
 
     def testNeedsRefresh_NoExpiry(self):
-        self.assertFalse(auth.AccessToken('token', None).needs_refresh())
+        self.assertFalse(auth.Token('token', None).needs_refresh())
 
     def testNeedsRefresh_Expired(self):
         expired = NOW + datetime.timedelta(seconds=30)
-        self.assertTrue(auth.AccessToken('token', expired).needs_refresh())
+        self.assertTrue(auth.Token('token', expired).needs_refresh())
 
     def testNeedsRefresh_Valid(self):
-        self.assertFalse(
-            auth.AccessToken('token', VALID_EXPIRY).needs_refresh())
+        self.assertFalse(auth.Token('token', VALID_EXPIRY).needs_refresh())
 
 
 class HasLuciContextLocalAuthTest(unittest.TestCase):