gdch.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package gdch
  15. import (
  16. "context"
  17. "crypto/rsa"
  18. "crypto/tls"
  19. "crypto/x509"
  20. "encoding/json"
  21. "errors"
  22. "fmt"
  23. "net/http"
  24. "net/url"
  25. "os"
  26. "strings"
  27. "time"
  28. "cloud.google.com/go/auth"
  29. "cloud.google.com/go/auth/internal"
  30. "cloud.google.com/go/auth/internal/credsfile"
  31. "cloud.google.com/go/auth/internal/jwt"
  32. )
  // Form values sent in the token-exchange request built by gdchProvider.Token.
const (
	// GrantType is the grant type for the token request.
	//
	// NOTE(review): RFC 8693 defines the token-exchange grant type as
	// "urn:ietf:params:oauth:grant-type:token-exchange"; this value uses
	// "token-type" instead. Presumably this is what the GDCH STS expects —
	// confirm against the server before changing it.
	GrantType = "urn:ietf:params:oauth:token-type:token-exchange"

	// requestTokenType is sent as "requested_token_type": an access token is
	// being requested.
	requestTokenType = "urn:ietf:params:oauth:token-type:access_token"

	// subjectTokenType is sent as "subject_token_type", describing the
	// self-signed service-account JWT passed as "subject_token".
	subjectTokenType = "urn:k8s:params:oauth:token-type:serviceaccount"
)
  // gdchSupportFormatVersions is the set of gdch_service_account
// "format_version" values that NewTokenProvider accepts; any version not
// present (or mapped to false) is rejected.
var gdchSupportFormatVersions = map[string]bool{
	"1": true,
}
  44. // Options for [NewTokenProvider].
  45. type Options struct {
  46. STSAudience string
  47. Client *http.Client
  48. }
  // NewTokenProvider returns a [cloud.google.com/go/auth.TokenProvider] from a
// GDCH cred file.
//
// It rejects unsupported format versions, requires o.STSAudience, parses the
// file's private key and loads the CA certificate(s) found at f.CertPath.
// A nil o is treated as an empty Options value, so it fails the STSAudience
// check instead of panicking. A nil o.Client is replaced with a new zero-value
// *http.Client, because the returned provider configures and uses the client
// on every Token call.
func NewTokenProvider(f *credsfile.GDCHServiceAccountFile, o *Options) (auth.TokenProvider, error) {
	if o == nil {
		o = &Options{}
	}
	if !gdchSupportFormatVersions[f.FormatVersion] {
		return nil, fmt.Errorf("credentials: unsupported gdch_service_account format %q", f.FormatVersion)
	}
	if o.STSAudience == "" {
		return nil, errors.New("credentials: STSAudience must be set for the GDCH auth flows")
	}
	pk, err := internal.ParseKey([]byte(f.PrivateKey))
	if err != nil {
		return nil, err
	}
	certPool, err := loadCertPool(f.CertPath)
	if err != nil {
		return nil, err
	}

	client := o.Client
	if client == nil {
		// Token dereferences the client (via addCertToTransport) before every
		// request, so a nil client would panic there. Use a private client
		// rather than mutating the caller's Options.
		client = &http.Client{}
	}

	return gdchProvider{
		serviceIdentity: fmt.Sprintf("system:serviceaccount:%s:%s", f.Project, f.Name),
		tokenURL:        f.TokenURL,
		aud:             o.STSAudience,
		pk:              pk,
		pkID:            f.PrivateKeyID,
		certPool:        certPool,
		client:          client,
	}, nil
}
  // loadCertPool reads the PEM-encoded certificate(s) at path and returns a new
// pool containing them.
//
// It returns an error if the file cannot be read, or if it contains no
// parsable PEM certificate. Without that check, an empty pool would only
// surface later as an opaque TLS verification failure against the token
// endpoint.
func loadCertPool(path string) (*x509.CertPool, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("credentials: failed to read certificate: %w", err)
	}
	pool := x509.NewCertPool()
	// AppendCertsFromPEM reports whether at least one certificate was parsed.
	if !pool.AppendCertsFromPEM(data) {
		return nil, fmt.Errorf("credentials: no valid PEM certificate found in %q", path)
	}
	return pool, nil
}
  86. type gdchProvider struct {
  87. serviceIdentity string
  88. tokenURL string
  89. aud string
  90. pk *rsa.PrivateKey
  91. pkID string
  92. certPool *x509.CertPool
  93. client *http.Client
  94. }
  // Token fetches a new access token for the GDCH service account.
//
// It first installs g.certPool on g.client's transport (see
// addCertToTransport). It then signs a one-hour RS256 JWT whose issuer and
// subject are the service identity and whose audience is g.tokenURL, and
// POSTs it as "subject_token" in a form-encoded token-exchange request.
//
// Non-2xx responses are returned as *auth.Error carrying the response and
// body. A 2xx response that does not contain an access_token is an error.
// The full decoded JSON response is exposed as the token's Metadata.
func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
	addCertToTransport(g.client, g.certPool)

	now := time.Now()
	claims := jwt.Claims{
		Iss: g.serviceIdentity,
		Sub: g.serviceIdentity,
		Aud: g.tokenURL,
		Iat: now.Unix(),
		Exp: now.Add(time.Hour).Unix(),
	}
	header := jwt.Header{
		Algorithm: jwt.HeaderAlgRSA256,
		Type:      jwt.HeaderType,
		KeyID:     g.pkID,
	}
	payload, err := jwt.EncodeJWS(&header, &claims, g.pk)
	if err != nil {
		return nil, err
	}

	form := url.Values{
		"grant_type":           {GrantType},
		"audience":             {g.aud},
		"requested_token_type": {requestTokenType},
		"subject_token":        {payload},
		"subject_token_type":   {subjectTokenType},
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, g.tokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, body, err := internal.DoRequest(g.client, req)
	if err != nil {
		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
	}
	// Only 2xx is success. The previous `> StatusMultipleChoices` bound
	// wrongly accepted 300 Multiple Choices.
	if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
		return nil, &auth.Error{
			Response: resp,
			Body:     body,
		}
	}

	var tokenRes struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
	}
	if err := json.Unmarshal(body, &tokenRes); err != nil {
		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
	}
	// A success status with no token (e.g. a JSON `null` or `{}` body) must
	// not yield a usable-looking empty token.
	if tokenRes.AccessToken == "" {
		return nil, errors.New("credentials: server response missing access_token")
	}

	token := &auth.Token{
		Value: tokenRes.AccessToken,
		Type:  tokenRes.TokenType,
	}
	raw := make(map[string]interface{})
	// The error is deliberately ignored: Metadata only carries optional fields.
	_ = json.Unmarshal(body, &raw)
	token.Metadata = raw
	if secs := tokenRes.ExpiresIn; secs > 0 {
		token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
	}
	return token, nil
}
  // addCertToTransport makes a best effort attempt at installing certPool as the
// root CA set used by hc.
//
// If hc already uses its own *http.Transport, that transport is updated in
// place, so its other settings (proxy, timeouts, connection pool) are kept.
// Any TLS options it already had are also kept; only RootCAs is replaced.
//
// Otherwise, hc.Transport is replaced by a clone of http.DefaultTransport
// carrying the certs. This covers a nil transport, a non-*http.Transport
// RoundTripper, or the shared http.DefaultTransport itself. As a result, the
// process-wide default transport is never modified.
//
// A nil hc is a no-op.
func addCertToTransport(hc *http.Client, certPool *x509.CertPool) {
	if hc == nil {
		return
	}
	trans, ok := hc.Transport.(*http.Transport)
	if !ok || hc.Transport == http.DefaultTransport {
		trans = http.DefaultTransport.(*http.Transport).Clone()
		// The clone must be installed on the client, or the certs never take
		// effect.
		hc.Transport = trans
	}
	cfg := &tls.Config{}
	if trans.TLSClientConfig != nil {
		cfg = trans.TLSClientConfig.Clone()
	}
	cfg.RootCAs = certPool
	trans.TLSClientConfig = cfg
}