Kaynağa Gözat

[auth] Be able to generate id_token

Make auth be able to generate id_token. Some services on Cloud Run will
need it (e.g. luci-config v2).

Bug: 1487020
Change-Id: Icfe95002f93ee552b99ab2694c7b777e2322484b
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/4899437
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
Commit-Queue: Yuanjun Huang <yuanjunh@google.com>
Yuanjun Huang 1 yıl önce
ebeveyn
işleme
4c1d6d90bc
2 değiştirilmiş dosya ile 118 ekleme ve 38 silme
  1. 51 22
      auth.py
  2. 67 16
      tests/auth_test.py

+ 51 - 22
auth.py

@@ -29,14 +29,14 @@ def datetime_now():
     return datetime.datetime.utcnow()
     return datetime.datetime.utcnow()
 
 
 
 
-# OAuth access token with its expiration time (UTC datetime or None if unknown).
-class AccessToken(
-        collections.namedtuple('AccessToken', [
-            'token',
-            'expires_at',
-        ])):
+# OAuth access token or ID token with its expiration time (UTC datetime or None
+# if unknown).
+class Token(collections.namedtuple('Token', [
+        'token',
+        'expires_at',
+])):
     def needs_refresh(self):
     def needs_refresh(self):
-        """True if this AccessToken should be refreshed."""
+        """True if this token should be refreshed."""
         if self.expires_at is not None:
         if self.expires_at is not None:
             # Allow 30s of clock skew between client and backend.
             # Allow 30s of clock skew between client and backend.
             return datetime_now() + datetime.timedelta(
             return datetime_now() + datetime.timedelta(
@@ -67,22 +67,27 @@ def has_luci_context_local_auth():
 
 
 
 
 class Authenticator(object):
 class Authenticator(object):
-    """Object that knows how to refresh access tokens when needed.
+    """Object that knows how to refresh access tokens or id tokens when needed.
 
 
   Args:
   Args:
-    scopes: space separated oauth scopes. Defaults to OAUTH_SCOPE_EMAIL.
+    scopes: space separated oauth scopes. It's used to generate access tokens.
+            Defaults to OAUTH_SCOPE_EMAIL.
+    audience: An audience in ID tokens to claim which clients should accept it.
   """
   """
-    def __init__(self, scopes=OAUTH_SCOPE_EMAIL):
+    def __init__(self, scopes=OAUTH_SCOPE_EMAIL, audience=None):
         self._access_token = None
         self._access_token = None
         self._scopes = scopes
         self._scopes = scopes
+        self._id_token = None
+        self._audience = audience
 
 
     def has_cached_credentials(self):
     def has_cached_credentials(self):
         """Returns True if credentials can be obtained.
         """Returns True if credentials can be obtained.
 
 
-    If returns False, get_access_token() later will probably ask for interactive
-    login by raising LoginRequiredError.
+    If returns False, get_access_token() or get_id_token() later will probably
+    ask for interactive login by raising LoginRequiredError.
 
 
-    If returns True, get_access_token() won't ask for interactive login.
+    If returns True, get_access_token() or get_id_token() won't ask for
+    interactive login.
     """
     """
         return bool(self._get_luci_auth_token())
         return bool(self._get_luci_auth_token())
 
 
@@ -105,7 +110,27 @@ class Authenticator(object):
         logging.error('Failed to create access token')
         logging.error('Failed to create access token')
         raise LoginRequiredError(self._scopes)
         raise LoginRequiredError(self._scopes)
 
 
-    def authorize(self, http):
+    def get_id_token(self):
+        """Returns id token, refreshing it if necessary.
+
+    Returns:
+       A Token object.
+
+    Raises:
+      LoginRequiredError if user interaction is required.
+    """
+        if self._id_token and not self._id_token.needs_refresh():
+            return self._id_token
+
+        self._id_token = self._get_luci_auth_token(use_id_token=True)
+        if self._id_token and not self._id_token.needs_refresh():
+            return self._id_token
+
+        # Nope, still expired. Needs user interaction.
+        logging.error('Failed to create id token')
+        raise LoginRequiredError()
+
+    def authorize(self, http, use_id_token=False):
         """Monkey patches authentication logic of httplib2.Http instance.
         """Monkey patches authentication logic of httplib2.Http instance.
 
 
     The modified http.request method will add authentication headers to each
     The modified http.request method will add authentication headers to each
@@ -128,8 +153,9 @@ class Authenticator(object):
                         redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                         redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                         connection_type=None):
                         connection_type=None):
             headers = (headers or {}).copy()
             headers = (headers or {}).copy()
-            headers['Authorization'] = 'Bearer %s' % self.get_access_token(
-            ).token
+            auth_token = self.get_access_token(
+            ) if not use_id_token else self.get_id_token()
+            headers['Authorization'] = 'Bearer %s' % auth_token.token
             return request_orig(uri, method, body, headers, redirections,
             return request_orig(uri, method, body, headers, redirections,
                                 connection_type)
                                 connection_type)
 
 
@@ -148,18 +174,21 @@ class Authenticator(object):
         subprocess2.check_call(['luci-auth', 'login', '-scopes', self._scopes])
         subprocess2.check_call(['luci-auth', 'login', '-scopes', self._scopes])
         return self._get_luci_auth_token()
         return self._get_luci_auth_token()
 
 
-    def _get_luci_auth_token(self):
+    def _get_luci_auth_token(self, use_id_token=False):
         logging.debug('Running luci-auth token')
         logging.debug('Running luci-auth token')
+        if use_id_token:
+            args = ['-use-id-token'] + ['-audience', self._audience
+                                        ] if self._audience else []
+        else:
+            args = ['-scopes', self._scopes]
         try:
         try:
-            out, err = subprocess2.check_call_out([
-                'luci-auth', 'token', '-scopes', self._scopes, '-json-output',
-                '-'
-            ],
+            out, err = subprocess2.check_call_out(['luci-auth', 'token'] +
+                                                  args + ['-json-output', '-'],
                                                   stdout=subprocess2.PIPE,
                                                   stdout=subprocess2.PIPE,
                                                   stderr=subprocess2.PIPE)
                                                   stderr=subprocess2.PIPE)
             logging.debug('luci-auth token stderr:\n%s', err)
             logging.debug('luci-auth token stderr:\n%s', err)
             token_info = json.loads(out)
             token_info = json.loads(out)
-            return AccessToken(
+            return Token(
                 token_info['token'],
                 token_info['token'],
                 datetime.datetime.utcfromtimestamp(token_info['expiry']))
                 datetime.datetime.utcfromtimestamp(token_info['expiry']))
         except subprocess2.CalledProcessError as e:
         except subprocess2.CalledProcessError as e:

+ 67 - 16
tests/auth_test.py

@@ -52,8 +52,8 @@ class AuthenticatorTest(unittest.TestCase):
 
 
     def testGetAccessToken_CachedToken(self):
     def testGetAccessToken_CachedToken(self):
         authenticator = auth.Authenticator()
         authenticator = auth.Authenticator()
-        authenticator._access_token = auth.AccessToken('token', None)
-        self.assertEqual(auth.AccessToken('token', None),
+        authenticator._access_token = auth.Token('token', None)
+        self.assertEqual(auth.Token('token', None),
                          authenticator.get_access_token())
                          authenticator.get_access_token())
         subprocess2.check_call_out.assert_not_called()
         subprocess2.check_call_out.assert_not_called()
 
 
@@ -63,7 +63,7 @@ class AuthenticatorTest(unittest.TestCase):
             'token': 'token',
             'token': 'token',
             'expiry': expiry
             'expiry': expiry
         }), '')
         }), '')
-        self.assertEqual(auth.AccessToken('token', VALID_EXPIRY),
+        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                          auth.Authenticator().get_access_token())
                          auth.Authenticator().get_access_token())
         subprocess2.check_call_out.assert_called_with([
         subprocess2.check_call_out.assert_called_with([
             'luci-auth', 'token', '-scopes', auth.OAUTH_SCOPE_EMAIL,
             'luci-auth', 'token', '-scopes', auth.OAUTH_SCOPE_EMAIL,
@@ -78,7 +78,7 @@ class AuthenticatorTest(unittest.TestCase):
             'token': 'token',
             'token': 'token',
             'expiry': expiry
             'expiry': expiry
         }), '')
         }), '')
-        self.assertEqual(auth.AccessToken('token', VALID_EXPIRY),
+        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                          auth.Authenticator('custom scopes').get_access_token())
                          auth.Authenticator('custom scopes').get_access_token())
         subprocess2.check_call_out.assert_called_with([
         subprocess2.check_call_out.assert_called_with([
             'luci-auth', 'token', '-scopes', 'custom scopes', '-json-output',
             'luci-auth', 'token', '-scopes', 'custom scopes', '-json-output',
@@ -87,41 +87,92 @@ class AuthenticatorTest(unittest.TestCase):
                                                       stdout=subprocess2.PIPE,
                                                       stdout=subprocess2.PIPE,
                                                       stderr=subprocess2.PIPE)
                                                       stderr=subprocess2.PIPE)
 
 
-    def testAuthorize(self):
+    def testAuthorize_AccessToken(self):
         http = mock.Mock()
         http = mock.Mock()
         http_request = http.request
         http_request = http.request
         http_request.__name__ = '__name__'
         http_request.__name__ = '__name__'
 
 
         authenticator = auth.Authenticator()
         authenticator = auth.Authenticator()
-        authenticator._access_token = auth.AccessToken('token', None)
+        authenticator._access_token = auth.Token('access_token', None)
+        authenticator._id_token = auth.Token('id_token', None)
 
 
         authorized = authenticator.authorize(http)
         authorized = authenticator.authorize(http)
         authorized.request('https://example.com',
         authorized.request('https://example.com',
                            method='POST',
                            method='POST',
                            body='body',
                            body='body',
                            headers={'header': 'value'})
                            headers={'header': 'value'})
-        http_request.assert_called_once_with('https://example.com', 'POST',
-                                             'body', {
-                                                 'header': 'value',
-                                                 'Authorization': 'Bearer token'
-                                             }, mock.ANY, mock.ANY)
+        http_request.assert_called_once_with(
+            'https://example.com', 'POST', 'body', {
+                'header': 'value',
+                'Authorization': 'Bearer access_token'
+            }, mock.ANY, mock.ANY)
+
+    def testGetIdToken_NotLoggedIn(self):
+        subprocess2.check_call_out.side_effect = [
+            subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout',
+                                           'stderr')
+        ]
+        self.assertRaises(auth.LoginRequiredError,
+                          auth.Authenticator().get_id_token)
+
+    def testGetIdToken_CachedToken(self):
+        authenticator = auth.Authenticator()
+        authenticator._id_token = auth.Token('token', None)
+        self.assertEqual(auth.Token('token', None),
+                         authenticator.get_id_token())
+        subprocess2.check_call_out.assert_not_called()
+
+    def testGetIdToken_LoggedIn(self):
+        expiry = calendar.timegm(VALID_EXPIRY.timetuple())
+        subprocess2.check_call_out.return_value = (json.dumps({
+            'token': 'token',
+            'expiry': expiry
+        }), '')
+        self.assertEqual(
+            auth.Token('token', VALID_EXPIRY),
+            auth.Authenticator(audience='https://test.com').get_id_token())
+        subprocess2.check_call_out.assert_called_with([
+            'luci-auth', 'token', '-use-id-token', '-audience',
+            'https://test.com', '-json-output', '-'
+        ],
+                                                      stdout=subprocess2.PIPE,
+                                                      stderr=subprocess2.PIPE)
+
+    def testAuthorize_IdToken(self):
+        http = mock.Mock()
+        http_request = http.request
+        http_request.__name__ = '__name__'
+
+        authenticator = auth.Authenticator()
+        authenticator._access_token = auth.Token('access_token', None)
+        authenticator._id_token = auth.Token('id_token', None)
+
+        authorized = authenticator.authorize(http, use_id_token=True)
+        authorized.request('https://example.com',
+                           method='POST',
+                           body='body',
+                           headers={'header': 'value'})
+        http_request.assert_called_once_with(
+            'https://example.com', 'POST', 'body', {
+                'header': 'value',
+                'Authorization': 'Bearer id_token'
+            }, mock.ANY, mock.ANY)
 
 
 
 
-class AccessTokenTest(unittest.TestCase):
+class TokenTest(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         mock.patch('auth.datetime_now', return_value=NOW).start()
         mock.patch('auth.datetime_now', return_value=NOW).start()
         self.addCleanup(mock.patch.stopall)
         self.addCleanup(mock.patch.stopall)
 
 
     def testNeedsRefresh_NoExpiry(self):
     def testNeedsRefresh_NoExpiry(self):
-        self.assertFalse(auth.AccessToken('token', None).needs_refresh())
+        self.assertFalse(auth.Token('token', None).needs_refresh())
 
 
     def testNeedsRefresh_Expired(self):
     def testNeedsRefresh_Expired(self):
         expired = NOW + datetime.timedelta(seconds=30)
         expired = NOW + datetime.timedelta(seconds=30)
-        self.assertTrue(auth.AccessToken('token', expired).needs_refresh())
+        self.assertTrue(auth.Token('token', expired).needs_refresh())
 
 
     def testNeedsRefresh_Valid(self):
     def testNeedsRefresh_Valid(self):
-        self.assertFalse(
-            auth.AccessToken('token', VALID_EXPIRY).needs_refresh())
+        self.assertFalse(auth.Token('token', VALID_EXPIRY).needs_refresh())
 
 
 
 
 class HasLuciContextLocalAuthTest(unittest.TestCase):
 class HasLuciContextLocalAuthTest(unittest.TestCase):