Эх сурвалжийг харах

[presubmit] Extend depot tools auth to use luci context

Bug: 509672
Change-Id: Ie3cb2fa1a2276f1fe658cdf7b9ffb657d03556e8
Reviewed-on: https://chromium-review.googlesource.com/754340
Commit-Queue: Mun Yong Jang <myjang@google.com>
Reviewed-by: Nodir Turakulov <nodir@chromium.org>
Mun Yong Jang 7 жил өмнө
parent
commit
acc8e3ebaa

+ 148 - 6
auth.py

@@ -16,6 +16,7 @@ import os
 import socket
 import sys
 import threading
+import time
 import urllib
 import urlparse
 import webbrowser
@@ -102,6 +103,119 @@ class LoginRequiredError(AuthenticationError):
     super(LoginRequiredError, self).__init__(msg)
 
 
+class LuciContextAuthError(Exception):
+  """Raised on errors related to unsuccessful attempts to load LUCI_CONTEXT"""
+
+
+def get_luci_context_access_token():
+  """Returns a valid AccessToken from the local LUCI context auth server.
+
+  Adapted from
+  https://chromium.googlesource.com/infra/luci/luci-py/+/master/client/libs/luci_context/luci_context.py
+  See the link above for more details.
+
+  Returns:
+    AccessToken if LUCI_CONTEXT is present and attempt to load it is successful.
+    None if LUCI_CONTEXT is absent.
+
+  Raises:
+    LuciContextAuthError if the attempt to load LUCI_CONTEXT
+        and request its access token is unsuccessful.
+  """
+  return _get_luci_context_access_token(os.environ, datetime.datetime.utcnow())
+
+
+def _get_luci_context_access_token(env, now):
+  ctx_path = env.get('LUCI_CONTEXT')
+  if not ctx_path:
+    return None
+  ctx_path = ctx_path.decode(sys.getfilesystemencoding())
+  logging.debug('Loading LUCI_CONTEXT: %r', ctx_path)
+
+  def authErr(msg, *args):
+    error_msg = msg % args
+    ex = sys.exc_info()[1]
+    if not ex:
+      logging.error(error_msg)
+      raise LuciContextAuthError(error_msg)
+    logging.exception(error_msg)
+    raise LuciContextAuthError('%s: %s' % (error_msg, ex))
+
+  try:
+    loaded = _load_luci_context(ctx_path)
+  except (OSError, IOError, ValueError):
+    authErr('Failed to open, read or decode LUCI_CONTEXT')
+  try:
+    local_auth = loaded.get('local_auth')
+  except AttributeError:
+    authErr('LUCI_CONTEXT not in proper format')
+  # failed to grab local_auth from LUCI context
+  if not local_auth:
+    logging.debug('local_auth: no local auth found')
+    return None
+  try:
+    account_id = local_auth.get('default_account_id')
+    secret = local_auth.get('secret')
+    rpc_port = int(local_auth.get('rpc_port'))
+  except (AttributeError, ValueError):
+    authErr('local_auth: unexpected local auth format')
+
+  if not secret:
+    authErr('local_auth: no secret returned')
+  # if account_id not specified, LUCI_CONTEXT should not be picked up
+  if not account_id:
+    return None
+
+  logging.debug('local_auth: requesting an access token for account "%s"',
+      account_id)
+  http = httplib2.Http()
+  host = '127.0.0.1:%d' % rpc_port
+  resp, content = http.request(
+      uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host,
+      method='POST',
+      body=json.dumps({
+        'account_id': account_id,
+        'scopes': OAUTH_SCOPES.split(' '),
+        'secret': secret,
+      }),
+      headers={'Content-Type': 'application/json'})
+  if resp.status != 200:
+    err = ('local_auth: Failed to grab access token from '
+           'LUCI context server with status %d: %r')
+    authErr(err, resp.status, content)
+  try:
+    token = json.loads(content)
+    error_code = token.get('error_code')
+    error_message = token.get('error_message')
+    access_token = token.get('access_token')
+    expiry = token.get('expiry')
+  except (AttributeError, ValueError):
+    authErr('local_auth: Unexpected access token response format')
+  if error_code:
+    authErr('local_auth: Error %d in retrieving access token: %s',
+        error_code, error_message)
+  if not access_token:
+    authErr('local_auth: No access token returned from LUCI context server')
+  expiry_dt = None
+  if expiry:
+    try:
+      expiry_dt = datetime.datetime.utcfromtimestamp(expiry)
+    except (TypeError, ValueError):
+      authErr('Invalid expiry in returned token')
+  logging.debug(
+      'local_auth: got an access token for account "%s" that expires in %d sec',
+      account_id, expiry - time.mktime(now.timetuple()))
+  access_token = AccessToken(access_token, expiry_dt)
+  if _needs_refresh(access_token, now=now):
+    authErr('local_auth: the returned access token needs to be refreshed')
+  return access_token
+
+
+def _load_luci_context(ctx_path):
+  with open(ctx_path) as f:
+    return json.load(f)
+
+
 def make_auth_config(
     use_oauth2=None,
     save_cookies=None,
@@ -219,6 +333,9 @@ def get_authenticator_for_host(hostname, config):
 
   Returns:
     Authenticator object.
+
+  Raises:
+    AuthenticationError if hostname is invalid.
   """
   hostname = hostname.lower().rstrip('/')
   # Append some scheme, otherwise urlparse puts hostname into parsed.path.
@@ -303,23 +420,43 @@ class Authenticator(object):
     with self._lock:
       return bool(self._get_cached_credentials())
 
-  def get_access_token(self, force_refresh=False, allow_user_interaction=False):
+  def get_access_token(self, force_refresh=False, allow_user_interaction=False,
+                       use_local_auth=True):
     """Returns AccessToken, refreshing it if necessary.
 
     Args:
       force_refresh: forcefully refresh access token even if it is not expired.
       allow_user_interaction: True to enable blocking for user input if needed.
+      use_local_auth: default to local auth if needed.
 
     Raises:
       AuthenticationError on error or if authentication flow was interrupted.
       LoginRequiredError if user interaction is required, but
           allow_user_interaction is False.
     """
+    def get_loc_auth_tkn():
+      exi = sys.exc_info()
+      if not use_local_auth:
+        logging.error('Failed to create access token')
+        raise
+      try:
+        self._access_token = get_luci_context_access_token()
+        if not self._access_token:
+          logging.error('Failed to create access token')
+          raise
+        return self._access_token
+      except LuciContextAuthError:
+        logging.exception('Failed to use local auth')
+        raise exi[0], exi[1], exi[2]
+
     with self._lock:
       if force_refresh:
         logging.debug('Forcing access token refresh')
-        self._access_token = self._create_access_token(allow_user_interaction)
-        return self._access_token
+        try:
+          self._access_token = self._create_access_token(allow_user_interaction)
+          return self._access_token
+        except LoginRequiredError:
+          return get_loc_auth_tkn()
 
       # Load from on-disk cache on a first access.
       if not self._access_token:
@@ -331,7 +468,11 @@ class Authenticator(object):
         self._access_token = self._load_access_token()
         # Nope, still expired, need to run the refresh flow.
         if not self._access_token or _needs_refresh(self._access_token):
-          self._access_token = self._create_access_token(allow_user_interaction)
+          try:
+            self._access_token = self._create_access_token(
+                allow_user_interaction)
+          except LoginRequiredError:
+            get_loc_auth_tkn()
 
       return self._access_token
 
@@ -548,11 +689,12 @@ def _read_refresh_token_json(path):
         'Failed to read refresh token from %s: missing key %s' % (path, e))
 
 
-def _needs_refresh(access_token):
+def _needs_refresh(access_token, now=None):
   """True if AccessToken should be refreshed."""
   if access_token.expires_at is not None:
+    now = now or datetime.datetime.utcnow()
     # Allow 5 min of clock skew between client and backend.
-    now = datetime.datetime.utcnow() + datetime.timedelta(seconds=300)
+    now += datetime.timedelta(seconds=300)
     return now >= access_token.expires_at
   # Token without expiration time never expires.
   return False

+ 2 - 2
presubmit_canned_checks.py

@@ -71,7 +71,7 @@ def CheckChangedConfigs(input_api, output_api):
   try:
     authenticator = auth.get_authenticator_for_host(
         LUCI_CONFIG_HOST_NAME, auth.make_auth_config())
-    acc_tkn = authenticator.get_access_token(allow_user_interaction=True).token
+    acc_tkn = authenticator.get_access_token()
   except auth.AuthenticationError as e:
     return [output_api.PresubmitError(
         'Error in authenticating user.', long_text=str(e))]
@@ -80,7 +80,7 @@ def CheckChangedConfigs(input_api, output_api):
     api_url = ('https://%s/_ah/api/config/v1/%s'
                % (LUCI_CONFIG_HOST_NAME, endpoint))
     req = urllib2.Request(api_url)
-    req.add_header('Authorization', 'Bearer %s' % acc_tkn)
+    req.add_header('Authorization', 'Bearer %s' % acc_tkn.token)
     if body is not None:
       req.add_header('Content-Type', 'application/json')
       req.add_data(json.dumps(body))

+ 106 - 0
tests/auth_test.py

@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Unit Tests for auth.py"""
+
+import __builtin__
+import datetime
+import json
+import logging
+import os
+import unittest
+import sys
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+from testing_support import auto_stub
+from third_party import httplib2
+from third_party import mock
+
+import auth
+
+
+class TestGetLuciContextAccessToken(auto_stub.TestCase):
+  mock_env = {'LUCI_CONTEXT': 'default/test/path'}
+
+  def _mock_local_auth(self, account_id, secret, rpc_port):
+    self.mock(auth, '_load_luci_context', mock.Mock())
+    auth._load_luci_context.return_value = {
+      'local_auth': {
+        'default_account_id': account_id,
+        'secret': secret,
+        'rpc_port': rpc_port,
+      }
+    }
+
+  def _mock_loc_server_resp(self, status, content):
+    mock_resp = mock.Mock()
+    mock_resp.status = status
+    self.mock(httplib2.Http, 'request', mock.Mock())
+    httplib2.Http.request.return_value = (mock_resp, content)
+
+  def test_correct_local_auth_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    expiry_time = datetime.datetime.min + datetime.timedelta(minutes=60)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': time.mktime(expiry_time.timetuple()),
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    token = auth._get_luci_context_access_token(
+        self.mock_env, datetime.datetime.min)
+    self.assertEquals(token.token, 'token')
+
+  def test_incorrect_port_format(self):
+    self._mock_local_auth('foo', 'bar', 'bar')
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+  def test_no_account_id(self):
+    self._mock_local_auth(None, 'bar', 10)
+    token = auth._get_luci_context_access_token(
+        self.mock_env, datetime.datetime.min)
+    self.assertIsNone(token)
+
+  def test_expired_token(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': 1,
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(
+          self.mock_env, datetime.datetime.utcfromtimestamp(1))
+
+  def test_incorrect_expiry_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    resp_content = {
+      'error_code': None,
+      'error_message': None,
+      'access_token': 'token',
+      'expiry': 'dead',
+    }
+    self._mock_loc_server_resp(200, json.dumps(resp_content))
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+  def test_incorrect_response_content_format(self):
+    self._mock_local_auth('dead', 'beef', 10)
+    self._mock_loc_server_resp(200, '5')
+    with self.assertRaises(auth.LuciContextAuthError):
+      auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
+
+
+if __name__ == '__main__':
+  if '-v' in sys.argv:
+    logging.basicConfig(level=logging.DEBUG)
+  unittest.main()

+ 1 - 2
tests/presubmit_unittest.py

@@ -1974,8 +1974,7 @@ class CannedChecksUnittest(PresubmitTestsBase):
     token_mock = self.mox.CreateMock(auth.AccessToken)
     token_mock.token = 123
     auth_mock = self.mox.CreateMock(auth.Authenticator)
-    auth_mock.get_access_token(
-        allow_user_interaction=True).AndReturn(token_mock)
+    auth_mock.get_access_token().AndReturn(token_mock)
     self.mox.StubOutWithMock(auth, 'get_authenticator_for_host')
     auth.get_authenticator_for_host(
         mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock)