auth.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package auth provides utilities for managing Google Cloud credentials,
  15. // including functionality for creating, caching, and refreshing OAuth2 tokens.
  16. // It offers customizable options for different OAuth2 flows, such as 2-legged
  17. // (2LO) and 3-legged (3LO) OAuth, along with support for PKCE and automatic
  18. // token management.
  19. package auth
  20. import (
  21. "context"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "net/http"
  26. "net/url"
  27. "strings"
  28. "sync"
  29. "time"
  30. "cloud.google.com/go/auth/internal"
  31. "cloud.google.com/go/auth/internal/jwt"
  32. )
  33. const (
  34. // Parameter keys for AuthCodeURL method to support PKCE.
  35. codeChallengeKey = "code_challenge"
  36. codeChallengeMethodKey = "code_challenge_method"
  37. // Parameter key for Exchange method to support PKCE.
  38. codeVerifierKey = "code_verifier"
  39. // 3 minutes and 45 seconds before expiration. The shortest MDS cache is 4 minutes,
  40. // so we give it 15 seconds to refresh it's cache before attempting to refresh a token.
  41. defaultExpiryDelta = 225 * time.Second
  42. universeDomainDefault = "googleapis.com"
  43. )
  44. // tokenState represents different states for a [Token].
  45. type tokenState int
  46. const (
  47. // fresh indicates that the [Token] is valid. It is not expired or close to
  48. // expired, or the token has no expiry.
  49. fresh tokenState = iota
  50. // stale indicates that the [Token] is close to expired, and should be
  51. // refreshed. The token can be used normally.
  52. stale
  53. // invalid indicates that the [Token] is expired or invalid. The token
  54. // cannot be used for a normal operation.
  55. invalid
  56. )
  57. var (
  58. defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
  59. defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType}
  60. // for testing
  61. timeNow = time.Now
  62. )
  63. // TokenProvider specifies an interface for anything that can return a token.
  64. type TokenProvider interface {
  65. // Token returns a Token or an error.
  66. // The Token returned must be safe to use
  67. // concurrently.
  68. // The returned Token must not be modified.
  69. // The context provided must be sent along to any requests that are made in
  70. // the implementing code.
  71. Token(context.Context) (*Token, error)
  72. }
  73. // Token holds the credential token used to authorized requests. All fields are
  74. // considered read-only.
  75. type Token struct {
  76. // Value is the token used to authorize requests. It is usually an access
  77. // token but may be other types of tokens such as ID tokens in some flows.
  78. Value string
  79. // Type is the type of token Value is. If uninitialized, it should be
  80. // assumed to be a "Bearer" token.
  81. Type string
  82. // Expiry is the time the token is set to expire.
  83. Expiry time.Time
  84. // Metadata may include, but is not limited to, the body of the token
  85. // response returned by the server.
  86. Metadata map[string]interface{} // TODO(codyoss): maybe make a method to flatten metadata to avoid []string for url.Values
  87. }
  88. // IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not
  89. // expired. A token is considered expired if [Token.Expiry] has passed or will
  90. // pass in the next 225 seconds.
  91. func (t *Token) IsValid() bool {
  92. return t.isValidWithEarlyExpiry(defaultExpiryDelta)
  93. }
  94. // MetadataString is a convenience method for accessing string values in the
  95. // token's metadata. Returns an empty string if the metadata is nil or the value
  96. // for the given key cannot be cast to a string.
  97. func (t *Token) MetadataString(k string) string {
  98. if t.Metadata == nil {
  99. return ""
  100. }
  101. s, ok := t.Metadata[k].(string)
  102. if !ok {
  103. return ""
  104. }
  105. return s
  106. }
  107. func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool {
  108. if t.isEmpty() {
  109. return false
  110. }
  111. if t.Expiry.IsZero() {
  112. return true
  113. }
  114. return !t.Expiry.Round(0).Add(-earlyExpiry).Before(timeNow())
  115. }
  116. func (t *Token) isEmpty() bool {
  117. return t == nil || t.Value == ""
  118. }
  119. // Credentials holds Google credentials, including
  120. // [Application Default Credentials].
  121. //
  122. // [Application Default Credentials]: https://developers.google.com/accounts/docs/application-default-credentials
  123. type Credentials struct {
  124. json []byte
  125. projectID CredentialsPropertyProvider
  126. quotaProjectID CredentialsPropertyProvider
  127. // universeDomain is the default service domain for a given Cloud universe.
  128. universeDomain CredentialsPropertyProvider
  129. TokenProvider
  130. }
  131. // JSON returns the bytes associated with the the file used to source
  132. // credentials if one was used.
  133. func (c *Credentials) JSON() []byte {
  134. return c.json
  135. }
  136. // ProjectID returns the associated project ID from the underlying file or
  137. // environment.
  138. func (c *Credentials) ProjectID(ctx context.Context) (string, error) {
  139. if c.projectID == nil {
  140. return internal.GetProjectID(c.json, ""), nil
  141. }
  142. v, err := c.projectID.GetProperty(ctx)
  143. if err != nil {
  144. return "", err
  145. }
  146. return internal.GetProjectID(c.json, v), nil
  147. }
  148. // QuotaProjectID returns the associated quota project ID from the underlying
  149. // file or environment.
  150. func (c *Credentials) QuotaProjectID(ctx context.Context) (string, error) {
  151. if c.quotaProjectID == nil {
  152. return internal.GetQuotaProject(c.json, ""), nil
  153. }
  154. v, err := c.quotaProjectID.GetProperty(ctx)
  155. if err != nil {
  156. return "", err
  157. }
  158. return internal.GetQuotaProject(c.json, v), nil
  159. }
  160. // UniverseDomain returns the default service domain for a given Cloud universe.
  161. // The default value is "googleapis.com".
  162. func (c *Credentials) UniverseDomain(ctx context.Context) (string, error) {
  163. if c.universeDomain == nil {
  164. return universeDomainDefault, nil
  165. }
  166. v, err := c.universeDomain.GetProperty(ctx)
  167. if err != nil {
  168. return "", err
  169. }
  170. if v == "" {
  171. return universeDomainDefault, nil
  172. }
  173. return v, err
  174. }
  175. // CredentialsPropertyProvider provides an implementation to fetch a property
  176. // value for [Credentials].
  177. type CredentialsPropertyProvider interface {
  178. GetProperty(context.Context) (string, error)
  179. }
  180. // CredentialsPropertyFunc is a type adapter to allow the use of ordinary
  181. // functions as a [CredentialsPropertyProvider].
  182. type CredentialsPropertyFunc func(context.Context) (string, error)
  183. // GetProperty loads the properly value provided the given context.
  184. func (p CredentialsPropertyFunc) GetProperty(ctx context.Context) (string, error) {
  185. return p(ctx)
  186. }
  187. // CredentialsOptions are used to configure [Credentials].
  188. type CredentialsOptions struct {
  189. // TokenProvider is a means of sourcing a token for the credentials. Required.
  190. TokenProvider TokenProvider
  191. // JSON is the raw contents of the credentials file if sourced from a file.
  192. JSON []byte
  193. // ProjectIDProvider resolves the project ID associated with the
  194. // credentials.
  195. ProjectIDProvider CredentialsPropertyProvider
  196. // QuotaProjectIDProvider resolves the quota project ID associated with the
  197. // credentials.
  198. QuotaProjectIDProvider CredentialsPropertyProvider
  199. // UniverseDomainProvider resolves the universe domain with the credentials.
  200. UniverseDomainProvider CredentialsPropertyProvider
  201. }
  202. // NewCredentials returns new [Credentials] from the provided options.
  203. func NewCredentials(opts *CredentialsOptions) *Credentials {
  204. creds := &Credentials{
  205. TokenProvider: opts.TokenProvider,
  206. json: opts.JSON,
  207. projectID: opts.ProjectIDProvider,
  208. quotaProjectID: opts.QuotaProjectIDProvider,
  209. universeDomain: opts.UniverseDomainProvider,
  210. }
  211. return creds
  212. }
  213. // CachedTokenProviderOptions provides options for configuring a cached
  214. // [TokenProvider].
  215. type CachedTokenProviderOptions struct {
  216. // DisableAutoRefresh makes the TokenProvider always return the same token,
  217. // even if it is expired. The default is false. Optional.
  218. DisableAutoRefresh bool
  219. // ExpireEarly configures the amount of time before a token expires, that it
  220. // should be refreshed. If unset, the default value is 3 minutes and 45
  221. // seconds. Optional.
  222. ExpireEarly time.Duration
  223. // DisableAsyncRefresh configures a synchronous workflow that refreshes
  224. // tokens in a blocking manner. The default is false. Optional.
  225. DisableAsyncRefresh bool
  226. }
  227. func (ctpo *CachedTokenProviderOptions) autoRefresh() bool {
  228. if ctpo == nil {
  229. return true
  230. }
  231. return !ctpo.DisableAutoRefresh
  232. }
  233. func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration {
  234. if ctpo == nil || ctpo.ExpireEarly == 0 {
  235. return defaultExpiryDelta
  236. }
  237. return ctpo.ExpireEarly
  238. }
  239. func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool {
  240. if ctpo == nil {
  241. return false
  242. }
  243. return ctpo.DisableAsyncRefresh
  244. }
  245. // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned
  246. // by the underlying provider. By default it will refresh tokens asynchronously
  247. // a few minutes before they expire.
  248. func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider {
  249. if ctp, ok := tp.(*cachedTokenProvider); ok {
  250. return ctp
  251. }
  252. return &cachedTokenProvider{
  253. tp: tp,
  254. autoRefresh: opts.autoRefresh(),
  255. expireEarly: opts.expireEarly(),
  256. blockingRefresh: opts.blockingRefresh(),
  257. }
  258. }
  259. type cachedTokenProvider struct {
  260. tp TokenProvider
  261. autoRefresh bool
  262. expireEarly time.Duration
  263. blockingRefresh bool
  264. mu sync.Mutex
  265. cachedToken *Token
  266. // isRefreshRunning ensures that the non-blocking refresh will only be
  267. // attempted once, even if multiple callers enter the Token method.
  268. isRefreshRunning bool
  269. // isRefreshErr ensures that the non-blocking refresh will only be attempted
  270. // once per refresh window if an error is encountered.
  271. isRefreshErr bool
  272. }
  273. func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) {
  274. if c.blockingRefresh {
  275. return c.tokenBlocking(ctx)
  276. }
  277. return c.tokenNonBlocking(ctx)
  278. }
  279. func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) {
  280. switch c.tokenState() {
  281. case fresh:
  282. c.mu.Lock()
  283. defer c.mu.Unlock()
  284. return c.cachedToken, nil
  285. case stale:
  286. // Call tokenAsync with a new Context because the user-provided context
  287. // may have a short timeout incompatible with async token refresh.
  288. c.tokenAsync(context.Background())
  289. // Return the stale token immediately to not block customer requests to Cloud services.
  290. c.mu.Lock()
  291. defer c.mu.Unlock()
  292. return c.cachedToken, nil
  293. default: // invalid
  294. return c.tokenBlocking(ctx)
  295. }
  296. }
  297. // tokenState reports the token's validity.
  298. func (c *cachedTokenProvider) tokenState() tokenState {
  299. c.mu.Lock()
  300. defer c.mu.Unlock()
  301. t := c.cachedToken
  302. now := timeNow()
  303. if t == nil || t.Value == "" {
  304. return invalid
  305. } else if t.Expiry.IsZero() {
  306. return fresh
  307. } else if now.After(t.Expiry.Round(0)) {
  308. return invalid
  309. } else if now.After(t.Expiry.Round(0).Add(-c.expireEarly)) {
  310. return stale
  311. }
  312. return fresh
  313. }
  314. // tokenAsync uses a bool to ensure that only one non-blocking token refresh
  315. // happens at a time, even if multiple callers have entered this function
  316. // concurrently. This avoids creating an arbitrary number of concurrent
  317. // goroutines. Retries should be attempted and managed within the Token method.
  318. // If the refresh attempt fails, no further attempts are made until the refresh
  319. // window expires and the token enters the invalid state, at which point the
  320. // blocking call to Token should likely return the same error on the main goroutine.
  321. func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
  322. fn := func() {
  323. c.mu.Lock()
  324. c.isRefreshRunning = true
  325. c.mu.Unlock()
  326. t, err := c.tp.Token(ctx)
  327. c.mu.Lock()
  328. defer c.mu.Unlock()
  329. c.isRefreshRunning = false
  330. if err != nil {
  331. // Discard errors from the non-blocking refresh, but prevent further
  332. // attempts.
  333. c.isRefreshErr = true
  334. return
  335. }
  336. c.cachedToken = t
  337. }
  338. c.mu.Lock()
  339. defer c.mu.Unlock()
  340. if !c.isRefreshRunning && !c.isRefreshErr {
  341. go fn()
  342. }
  343. }
  344. func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) {
  345. c.mu.Lock()
  346. defer c.mu.Unlock()
  347. c.isRefreshErr = false
  348. if c.cachedToken.IsValid() || (!c.autoRefresh && !c.cachedToken.isEmpty()) {
  349. return c.cachedToken, nil
  350. }
  351. t, err := c.tp.Token(ctx)
  352. if err != nil {
  353. return nil, err
  354. }
  355. c.cachedToken = t
  356. return t, nil
  357. }
  358. // Error is a error associated with retrieving a [Token]. It can hold useful
  359. // additional details for debugging.
  360. type Error struct {
  361. // Response is the HTTP response associated with error. The body will always
  362. // be already closed and consumed.
  363. Response *http.Response
  364. // Body is the HTTP response body.
  365. Body []byte
  366. // Err is the underlying wrapped error.
  367. Err error
  368. // code returned in the token response
  369. code string
  370. // description returned in the token response
  371. description string
  372. // uri returned in the token response
  373. uri string
  374. }
  375. func (e *Error) Error() string {
  376. if e.code != "" {
  377. s := fmt.Sprintf("auth: %q", e.code)
  378. if e.description != "" {
  379. s += fmt.Sprintf(" %q", e.description)
  380. }
  381. if e.uri != "" {
  382. s += fmt.Sprintf(" %q", e.uri)
  383. }
  384. return s
  385. }
  386. return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
  387. }
  388. // Temporary returns true if the error is considered temporary and may be able
  389. // to be retried.
  390. func (e *Error) Temporary() bool {
  391. if e.Response == nil {
  392. return false
  393. }
  394. sc := e.Response.StatusCode
  395. return sc == http.StatusInternalServerError || sc == http.StatusServiceUnavailable || sc == http.StatusRequestTimeout || sc == http.StatusTooManyRequests
  396. }
  397. func (e *Error) Unwrap() error {
  398. return e.Err
  399. }
  400. // Style describes how the token endpoint wants to receive the ClientID and
  401. // ClientSecret.
  402. type Style int
  403. const (
  404. // StyleUnknown means the value has not been initiated. Sending this in
  405. // a request will cause the token exchange to fail.
  406. StyleUnknown Style = iota
  407. // StyleInParams sends client info in the body of a POST request.
  408. StyleInParams
  409. // StyleInHeader sends client info using Basic Authorization header.
  410. StyleInHeader
  411. )
  412. // Options2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow.
  413. type Options2LO struct {
  414. // Email is the OAuth2 client ID. This value is set as the "iss" in the
  415. // JWT.
  416. Email string
  417. // PrivateKey contains the contents of an RSA private key or the
  418. // contents of a PEM file that contains a private key. It is used to sign
  419. // the JWT created.
  420. PrivateKey []byte
  421. // TokenURL is th URL the JWT is sent to. Required.
  422. TokenURL string
  423. // PrivateKeyID is the ID of the key used to sign the JWT. It is used as the
  424. // "kid" in the JWT header. Optional.
  425. PrivateKeyID string
  426. // Subject is the used for to impersonate a user. It is used as the "sub" in
  427. // the JWT.m Optional.
  428. Subject string
  429. // Scopes specifies requested permissions for the token. Optional.
  430. Scopes []string
  431. // Expires specifies the lifetime of the token. Optional.
  432. Expires time.Duration
  433. // Audience specifies the "aud" in the JWT. Optional.
  434. Audience string
  435. // PrivateClaims allows specifying any custom claims for the JWT. Optional.
  436. PrivateClaims map[string]interface{}
  437. // Client is the client to be used to make the underlying token requests.
  438. // Optional.
  439. Client *http.Client
  440. // UseIDToken requests that the token returned be an ID token if one is
  441. // returned from the server. Optional.
  442. UseIDToken bool
  443. }
  444. func (o *Options2LO) client() *http.Client {
  445. if o.Client != nil {
  446. return o.Client
  447. }
  448. return internal.DefaultClient()
  449. }
  450. func (o *Options2LO) validate() error {
  451. if o == nil {
  452. return errors.New("auth: options must be provided")
  453. }
  454. if o.Email == "" {
  455. return errors.New("auth: email must be provided")
  456. }
  457. if len(o.PrivateKey) == 0 {
  458. return errors.New("auth: private key must be provided")
  459. }
  460. if o.TokenURL == "" {
  461. return errors.New("auth: token URL must be provided")
  462. }
  463. return nil
  464. }
  465. // New2LOTokenProvider returns a [TokenProvider] from the provided options.
  466. func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
  467. if err := opts.validate(); err != nil {
  468. return nil, err
  469. }
  470. return tokenProvider2LO{opts: opts, Client: opts.client()}, nil
  471. }
  472. type tokenProvider2LO struct {
  473. opts *Options2LO
  474. Client *http.Client
  475. }
  476. func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) {
  477. pk, err := internal.ParseKey(tp.opts.PrivateKey)
  478. if err != nil {
  479. return nil, err
  480. }
  481. claimSet := &jwt.Claims{
  482. Iss: tp.opts.Email,
  483. Scope: strings.Join(tp.opts.Scopes, " "),
  484. Aud: tp.opts.TokenURL,
  485. AdditionalClaims: tp.opts.PrivateClaims,
  486. Sub: tp.opts.Subject,
  487. }
  488. if t := tp.opts.Expires; t > 0 {
  489. claimSet.Exp = time.Now().Add(t).Unix()
  490. }
  491. if aud := tp.opts.Audience; aud != "" {
  492. claimSet.Aud = aud
  493. }
  494. h := *defaultHeader
  495. h.KeyID = tp.opts.PrivateKeyID
  496. payload, err := jwt.EncodeJWS(&h, claimSet, pk)
  497. if err != nil {
  498. return nil, err
  499. }
  500. v := url.Values{}
  501. v.Set("grant_type", defaultGrantType)
  502. v.Set("assertion", payload)
  503. req, err := http.NewRequestWithContext(ctx, "POST", tp.opts.TokenURL, strings.NewReader(v.Encode()))
  504. if err != nil {
  505. return nil, err
  506. }
  507. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  508. resp, body, err := internal.DoRequest(tp.Client, req)
  509. if err != nil {
  510. return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
  511. }
  512. if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
  513. return nil, &Error{
  514. Response: resp,
  515. Body: body,
  516. }
  517. }
  518. // tokenRes is the JSON response body.
  519. var tokenRes struct {
  520. AccessToken string `json:"access_token"`
  521. TokenType string `json:"token_type"`
  522. IDToken string `json:"id_token"`
  523. ExpiresIn int64 `json:"expires_in"`
  524. }
  525. if err := json.Unmarshal(body, &tokenRes); err != nil {
  526. return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
  527. }
  528. token := &Token{
  529. Value: tokenRes.AccessToken,
  530. Type: tokenRes.TokenType,
  531. }
  532. token.Metadata = make(map[string]interface{})
  533. json.Unmarshal(body, &token.Metadata) // no error checks for optional fields
  534. if secs := tokenRes.ExpiresIn; secs > 0 {
  535. token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
  536. }
  537. if v := tokenRes.IDToken; v != "" {
  538. // decode returned id token to get expiry
  539. claimSet, err := jwt.DecodeJWS(v)
  540. if err != nil {
  541. return nil, fmt.Errorf("auth: error decoding JWT token: %w", err)
  542. }
  543. token.Expiry = time.Unix(claimSet.Exp, 0)
  544. }
  545. if tp.opts.UseIDToken {
  546. if tokenRes.IDToken == "" {
  547. return nil, fmt.Errorf("auth: response doesn't have JWT token")
  548. }
  549. token.Value = tokenRes.IDToken
  550. }
  551. return token, nil
  552. }