Browse Source

[presubmit] Extend depot tools auth to use luci context

Bug: 509672
Change-Id: Ie3cb2fa1a2276f1fe658cdf7b9ffb657d03556e8
Reviewed-on: https://chromium-review.googlesource.com/754340
Commit-Queue: Mun Yong Jang <myjang@google.com>
Reviewed-by: Nodir Turakulov <nodir@chromium.org>
Mun Yong Jang 7 năm trước cách đây
mục cha
commit
acc8e3ebaa
4 tập tin đã thay đổi với 257 bổ sung10 xóa
  1. 148 6
      auth.py
  2. 2 2
      presubmit_canned_checks.py
  3. 106 0
      tests/auth_test.py
  4. 1 2
      tests/presubmit_unittest.py

+ 148 - 6
auth.py

@@ -16,6 +16,7 @@ import os
 import socket
 import socket
 import sys
 import sys
 import threading
 import threading
+import time
 import urllib
 import urllib
 import urlparse
 import urlparse
 import webbrowser
 import webbrowser
@@ -102,6 +103,119 @@ class LoginRequiredError(AuthenticationError):
     super(LoginRequiredError, self).__init__(msg)
     super(LoginRequiredError, self).__init__(msg)
 
 
 
 
+class LuciContextAuthError(Exception):
+  """Raised on errors related to unsuccessful attempts to load LUCI_CONTEXT"""
+
+
+def get_luci_context_access_token():
+  """Returns a valid AccessToken from the local LUCI context auth server.
+
+  Adapted from
+  https://chromium.googlesource.com/infra/luci/luci-py/+/master/client/libs/luci_context/luci_context.py
+  See the link above for more details.
+
+  Returns:
+    AccessToken if LUCI_CONTEXT is present and attempt to load it is successful.
+    None if LUCI_CONTEXT is absent.
+
+  Raises:
+    LuciContextAuthError if the attempt to load LUCI_CONTEXT
+        and request its access token is unsuccessful.
+  """
+  return _get_luci_context_access_token(os.environ, datetime.datetime.utcnow())
+
+
+def _get_luci_context_access_token(env, now):
+  ctx_path = env.get('LUCI_CONTEXT')
+  if not ctx_path:
+    return None
+  ctx_path = ctx_path.decode(sys.getfilesystemencoding())
+  logging.debug('Loading LUCI_CONTEXT: %r', ctx_path)
+
+  def authErr(msg, *args):
+    error_msg = msg % args
+    ex = sys.exc_info()[1]
+    if not ex:
+      logging.error(error_msg)
+      raise LuciContextAuthError(error_msg)
+    logging.exception(error_msg)
+    raise LuciContextAuthError('%s: %s' % (error_msg, ex))
+
+  try:
+    loaded = _load_luci_context(ctx_path)
+  except (OSError, IOError, ValueError):
+    authErr('Failed to open, read or decode LUCI_CONTEXT')
+  try:
+    local_auth = loaded.get('local_auth')
+  except AttributeError:
+    authErr('LUCI_CONTEXT not in proper format')
+  # failed to grab local_auth from LUCI context
+  if not local_auth:
+    logging.debug('local_auth: no local auth found')
+    return None
+  try:
+    account_id = local_auth.get('default_account_id')
+    secret = local_auth.get('secret')
+    rpc_port = int(local_auth.get('rpc_port'))
+  except (AttributeError, ValueError):
+    authErr('local_auth: unexpected local auth format')
+
+  if not secret:
+    authErr('local_auth: no secret returned')
+  # if account_id not specified, LUCI_CONTEXT should not be picked up
+  if not account_id:
+    return None
+
+  logging.debug('local_auth: requesting an access token for account "%s"',
+      account_id)
+  http = httplib2.Http()
+  host = '127.0.0.1:%d' % rpc_port
+  resp, content = http.request(
+      uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host,
+      method='POST',
+      body=json.dumps({
+        'account_id': account_id,
+        'scopes': OAUTH_SCOPES.split(' '),
+        'secret': secret,
+      }),
+      headers={'Content-Type': 'application/json'})
+  if resp.status != 200:
+    err = ('local_auth: Failed to grab access token from '
+           'LUCI context server with status %d: %r')
+    authErr(err, resp.status, content)
+  try:
+    token = json.loads(content)
+    error_code = token.get('error_code')
+    error_message = token.get('error_message')
+    access_token = token.get('access_token')
+    expiry = token.get('expiry')
+  except (AttributeError, ValueError):
+    authErr('local_auth: Unexpected access token response format')
+  if error_code:
+    authErr('local_auth: Error %d in retrieving access token: %s',
+        error_code, error_message)
+  if not access_token:
+    authErr('local_auth: No access token returned from LUCI context server')
+  expiry_dt = None
+  if expiry:
+    try:
+      expiry_dt = datetime.datetime.utcfromtimestamp(expiry)
+    except (TypeError, ValueError):
+      authErr('Invalid expiry in returned token')
+  logging.debug(
+      'local_auth: got an access token for account "%s" that expires in %d sec',
+      account_id, expiry - time.mktime(now.timetuple()))
+  access_token = AccessToken(access_token, expiry_dt)
+  if _needs_refresh(access_token, now=now):
+    authErr('local_auth: the returned access token needs to be refreshed')
+  return access_token
+
+
+def _load_luci_context(ctx_path):
+  with open(ctx_path) as f:
+    return json.load(f)
+
+
 def make_auth_config(
 def make_auth_config(
     use_oauth2=None,
     use_oauth2=None,
     save_cookies=None,
     save_cookies=None,
@@ -219,6 +333,9 @@ def get_authenticator_for_host(hostname, config):
 
 
   Returns:
   Returns:
     Authenticator object.
     Authenticator object.
+
+  Raises:
+    AuthenticationError if hostname is invalid.
   """
   """
   hostname = hostname.lower().rstrip('/')
   hostname = hostname.lower().rstrip('/')
   # Append some scheme, otherwise urlparse puts hostname into parsed.path.
   # Append some scheme, otherwise urlparse puts hostname into parsed.path.
@@ -303,23 +420,43 @@ class Authenticator(object):
     with self._lock:
     with self._lock:
       return bool(self._get_cached_credentials())
       return bool(self._get_cached_credentials())
 
 
-  def get_access_token(self, force_refresh=False, allow_user_interaction=False):
+  def get_access_token(self, force_refresh=False, allow_user_interaction=False,
+                       use_local_auth=True):
     """Returns AccessToken, refreshing it if necessary.
     """Returns AccessToken, refreshing it if necessary.
 
 
     Args:
     Args:
       force_refresh: forcefully refresh access token even if it is not expired.
       force_refresh: forcefully refresh access token even if it is not expired.
       allow_user_interaction: True to enable blocking for user input if needed.
       allow_user_interaction: True to enable blocking for user input if needed.
+      use_local_auth: default to local auth if needed.
 
 
     Raises:
     Raises:
       AuthenticationError on error or if authentication flow was interrupted.
       AuthenticationError on error or if authentication flow was interrupted.
       LoginRequiredError if user interaction is required, but
       LoginRequiredError if user interaction is required, but
           allow_user_interaction is False.
           allow_user_interaction is False.
     """
     """
+    def get_loc_auth_tkn():
+      exi = sys.exc_info()
+      if not use_local_auth:
+        logging.error('Failed to create access token')
+        raise
+      try:
+        self._access_token = get_luci_context_access_token()
+        if not self._access_token:
+          logging.error('Failed to create access token')
+          raise
+        return self._access_token
+      except LuciContextAuthError:
+        logging.exception('Failed to use local auth')
+        raise exi[0], exi[1], exi[2]
+
     with self._lock:
     with self._lock:
       if force_refresh:
       if force_refresh:
         logging.debug('Forcing access token refresh')
         logging.debug('Forcing access token refresh')
-        self._access_token = self._create_access_token(allow_user_interaction)
-        return self._access_token
+        try:
+          self._access_token = self._create_access_token(allow_user_interaction)
+          return self._access_token
+        except LoginRequiredError:
+          return get_loc_auth_tkn()
 
 
       # Load from on-disk cache on a first access.
       # Load from on-disk cache on a first access.
       if not self._access_token:
       if not self._access_token:
@@ -331,7 +468,11 @@ class Authenticator(object):
         self._access_token = self._load_access_token()
         self._access_token = self._load_access_token()
         # Nope, still expired, need to run the refresh flow.
         # Nope, still expired, need to run the refresh flow.
         if not self._access_token or _needs_refresh(self._access_token):
         if not self._access_token or _needs_refresh(self._access_token):
-          self._access_token = self._create_access_token(allow_user_interaction)
+          try:
+            self._access_token = self._create_access_token(
+                allow_user_interaction)
+          except LoginRequiredError:
+            get_loc_auth_tkn()
 
 
       return self._access_token
       return self._access_token
 
 
@@ -548,11 +689,12 @@ def _read_refresh_token_json(path):
         'Failed to read refresh token from %s: missing key %s' % (path, e))
         'Failed to read refresh token from %s: missing key %s' % (path, e))
 
 
 
 
-def _needs_refresh(access_token):
+def _needs_refresh(access_token, now=None):
   """True if AccessToken should be refreshed."""
   """True if AccessToken should be refreshed."""
   if access_token.expires_at is not None:
   if access_token.expires_at is not None:
+    now = now or datetime.datetime.utcnow()
     # Allow 5 min of clock skew between client and backend.
     # Allow 5 min of clock skew between client and backend.
-    now = datetime.datetime.utcnow() + datetime.timedelta(seconds=300)
+    now += datetime.timedelta(seconds=300)
     return now >= access_token.expires_at
     return now >= access_token.expires_at
   # Token without expiration time never expires.
   # Token without expiration time never expires.
   return False
   return False

+ 2 - 2
presubmit_canned_checks.py

@@ -71,7 +71,7 @@ def CheckChangedConfigs(input_api, output_api):
   try:
   try:
     authenticator = auth.get_authenticator_for_host(
     authenticator = auth.get_authenticator_for_host(
         LUCI_CONFIG_HOST_NAME, auth.make_auth_config())
         LUCI_CONFIG_HOST_NAME, auth.make_auth_config())
-    acc_tkn = authenticator.get_access_token(allow_user_interaction=True).token
+    acc_tkn = authenticator.get_access_token()
   except auth.AuthenticationError as e:
   except auth.AuthenticationError as e:
     return [output_api.PresubmitError(
     return [output_api.PresubmitError(
         'Error in authenticating user.', long_text=str(e))]
         'Error in authenticating user.', long_text=str(e))]
@@ -80,7 +80,7 @@ def CheckChangedConfigs(input_api, output_api):
     api_url = ('https://%s/_ah/api/config/v1/%s'
     api_url = ('https://%s/_ah/api/config/v1/%s'
                % (LUCI_CONFIG_HOST_NAME, endpoint))
                % (LUCI_CONFIG_HOST_NAME, endpoint))
     req = urllib2.Request(api_url)
     req = urllib2.Request(api_url)
-    req.add_header('Authorization', 'Bearer %s' % acc_tkn)
+    req.add_header('Authorization', 'Bearer %s' % acc_tkn.token)
     if body is not None:
     if body is not None:
       req.add_header('Content-Type', 'application/json')
       req.add_header('Content-Type', 'application/json')
       req.add_data(json.dumps(body))
       req.add_data(json.dumps(body))

+ 106 - 0
tests/auth_test.py

@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Unit Tests for auth.py"""
+
+import __builtin__
+import datetime
+import json
+import logging
+import os
+import unittest
+import sys
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+from testing_support import auto_stub
+from third_party import httplib2
+from third_party import mock
+
+import auth
+
+
+class TestGetLuciContextAccessToken(auto_stub.TestCase):
+  mock_env = {'LUCI_CONTEXT': 'default/test/path'}
+
+  def _mock_local_auth(self, account_id, secret, rpc_port):
+    self.mock(auth, '_load_luci_context', mock.Mock())
+    auth._load_luci_context.return_value = {
+      'local_auth': {
+        'default_account_id': account_id,
+        'secret': secret,
+        'rpc_port': rpc_port,
+      }
+    }
+
+  def _mock_loc_server_resp(self, status, content):
+    mock_resp = mock.Mock()
+    mock_resp.status = status
+    self.mock(httplib2.Http, 'request', mock.Mock())
+    httplib2.Http.request.return_value = (mock_resp, content)
+
+  def test_correct_local_auth_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    expiry_time = datetime.datetime.min + datetime.timedelta(minutes=60)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': time.mktime(expiry_time.timetuple()),
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    token = auth._get_luci_context_access_token(
+        self.mock_env, datetime.datetime.min)
+    self.assertEquals(token.token, 'token')
+
+  def test_incorrect_port_format(self):
+    self._mock_local_auth('foo', 'bar', 'bar')
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+  def test_no_account_id(self):
+    self._mock_local_auth(None, 'bar', 10)
+    token = auth._get_luci_context_access_token(
+        self.mock_env, datetime.datetime.min)
+    self.assertIsNone(token)
+
+  def test_expired_token(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': 1,
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(
+          self.mock_env, datetime.datetime.utcfromtimestamp(1))
+
+  def test_incorrect_expiry_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': 'dead',
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+  def test_incorrect_response_content_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    self._mock_loc_server_resp(200, '5')
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+
+if __name__ == '__main__':
+  if '-v' in sys.argv:
+    logging.basicConfig(level=logging.DEBUG)
+  unittest.main()

+ 1 - 2
tests/presubmit_unittest.py

@@ -1974,8 +1974,7 @@ class CannedChecksUnittest(PresubmitTestsBase):
     token_mock = self.mox.CreateMock(auth.AccessToken)
     token_mock = self.mox.CreateMock(auth.AccessToken)
     token_mock.token = 123
     token_mock.token = 123
     auth_mock = self.mox.CreateMock(auth.Authenticator)
     auth_mock = self.mox.CreateMock(auth.Authenticator)
-    auth_mock.get_access_token(
-        allow_user_interaction=True).AndReturn(token_mock)
+    auth_mock.get_access_token().AndReturn(token_mock)
     self.mox.StubOutWithMock(auth, 'get_authenticator_for_host')
     self.mox.StubOutWithMock(auth, 'get_authenticator_for_host')
     auth.get_authenticator_for_host(
     auth.get_authenticator_for_host(
         mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock)
         mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock)