threelegged.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package auth
  15. import (
  16. "bytes"
  17. "context"
  18. "encoding/json"
  19. "errors"
  20. "fmt"
  21. "mime"
  22. "net/http"
  23. "net/url"
  24. "strconv"
  25. "strings"
  26. "time"
  27. "cloud.google.com/go/auth/internal"
  28. )
  29. // AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for
  30. // OAuth consent at the specified auth code URL and returns an auth code and
  31. // state upon approval.
  32. type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
  33. // Options3LO are the options for doing a 3-legged OAuth2 flow.
  34. type Options3LO struct {
  35. // ClientID is the application's ID.
  36. ClientID string
  37. // ClientSecret is the application's secret. Not required if AuthHandlerOpts
  38. // is set.
  39. ClientSecret string
  40. // AuthURL is the URL for authenticating.
  41. AuthURL string
  42. // TokenURL is the URL for retrieving a token.
  43. TokenURL string
  44. // AuthStyle is used to describe how to client info in the token request.
  45. AuthStyle Style
  46. // RefreshToken is the token used to refresh the credential. Not required
  47. // if AuthHandlerOpts is set.
  48. RefreshToken string
  49. // RedirectURL is the URL to redirect users to. Optional.
  50. RedirectURL string
  51. // Scopes specifies requested permissions for the Token. Optional.
  52. Scopes []string
  53. // URLParams are the set of values to apply to the token exchange. Optional.
  54. URLParams url.Values
  55. // Client is the client to be used to make the underlying token requests.
  56. // Optional.
  57. Client *http.Client
  58. // EarlyTokenExpiry is the time before the token expires that it should be
  59. // refreshed. If not set the default value is 3 minutes and 45 seconds.
  60. // Optional.
  61. EarlyTokenExpiry time.Duration
  62. // AuthHandlerOpts provides a set of options for doing a
  63. // 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
  64. AuthHandlerOpts *AuthorizationHandlerOptions
  65. }
  66. func (o *Options3LO) validate() error {
  67. if o == nil {
  68. return errors.New("auth: options must be provided")
  69. }
  70. if o.ClientID == "" {
  71. return errors.New("auth: client ID must be provided")
  72. }
  73. if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
  74. return errors.New("auth: client secret must be provided")
  75. }
  76. if o.AuthURL == "" {
  77. return errors.New("auth: auth URL must be provided")
  78. }
  79. if o.TokenURL == "" {
  80. return errors.New("auth: token URL must be provided")
  81. }
  82. if o.AuthStyle == StyleUnknown {
  83. return errors.New("auth: auth style must be provided")
  84. }
  85. if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
  86. return errors.New("auth: refresh token must be provided")
  87. }
  88. return nil
  89. }
  90. // PKCEOptions holds parameters to support PKCE.
  91. type PKCEOptions struct {
  92. // Challenge is the un-padded, base64-url-encoded string of the encrypted code verifier.
  93. Challenge string // The un-padded, base64-url-encoded string of the encrypted code verifier.
  94. // ChallengeMethod is the encryption method (ex. S256).
  95. ChallengeMethod string
  96. // Verifier is the original, non-encrypted secret.
  97. Verifier string // The original, non-encrypted secret.
  98. }
  99. type tokenJSON struct {
  100. AccessToken string `json:"access_token"`
  101. TokenType string `json:"token_type"`
  102. RefreshToken string `json:"refresh_token"`
  103. ExpiresIn int `json:"expires_in"`
  104. // error fields
  105. ErrorCode string `json:"error"`
  106. ErrorDescription string `json:"error_description"`
  107. ErrorURI string `json:"error_uri"`
  108. }
  109. func (e *tokenJSON) expiry() (t time.Time) {
  110. if v := e.ExpiresIn; v != 0 {
  111. return time.Now().Add(time.Duration(v) * time.Second)
  112. }
  113. return
  114. }
  115. func (o *Options3LO) client() *http.Client {
  116. if o.Client != nil {
  117. return o.Client
  118. }
  119. return internal.DefaultClient()
  120. }
  121. // authCodeURL returns a URL that points to a OAuth2 consent page.
  122. func (o *Options3LO) authCodeURL(state string, values url.Values) string {
  123. var buf bytes.Buffer
  124. buf.WriteString(o.AuthURL)
  125. v := url.Values{
  126. "response_type": {"code"},
  127. "client_id": {o.ClientID},
  128. }
  129. if o.RedirectURL != "" {
  130. v.Set("redirect_uri", o.RedirectURL)
  131. }
  132. if len(o.Scopes) > 0 {
  133. v.Set("scope", strings.Join(o.Scopes, " "))
  134. }
  135. if state != "" {
  136. v.Set("state", state)
  137. }
  138. if o.AuthHandlerOpts != nil {
  139. if o.AuthHandlerOpts.PKCEOpts != nil &&
  140. o.AuthHandlerOpts.PKCEOpts.Challenge != "" {
  141. v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge)
  142. }
  143. if o.AuthHandlerOpts.PKCEOpts != nil &&
  144. o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" {
  145. v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod)
  146. }
  147. }
  148. for k := range values {
  149. v.Set(k, v.Get(k))
  150. }
  151. if strings.Contains(o.AuthURL, "?") {
  152. buf.WriteByte('&')
  153. } else {
  154. buf.WriteByte('?')
  155. }
  156. buf.WriteString(v.Encode())
  157. return buf.String()
  158. }
  159. // New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
  160. // configuration. The TokenProvider is caches and auto-refreshes tokens by
  161. // default.
  162. func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
  163. if err := opts.validate(); err != nil {
  164. return nil, err
  165. }
  166. if opts.AuthHandlerOpts != nil {
  167. return new3LOTokenProviderWithAuthHandler(opts), nil
  168. }
  169. return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: opts.RefreshToken, client: opts.client()}, &CachedTokenProviderOptions{
  170. ExpireEarly: opts.EarlyTokenExpiry,
  171. }), nil
  172. }
  173. // AuthorizationHandlerOptions provides a set of options to specify for doing a
  174. // 3-legged OAuth2 flow with a custom [AuthorizationHandler].
  175. type AuthorizationHandlerOptions struct {
  176. // AuthorizationHandler specifies the handler used to for the authorization
  177. // part of the flow.
  178. Handler AuthorizationHandler
  179. // State is used verify that the "state" is identical in the request and
  180. // response before exchanging the auth code for OAuth2 token.
  181. State string
  182. // PKCEOpts allows setting configurations for PKCE. Optional.
  183. PKCEOpts *PKCEOptions
  184. }
  185. func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
  186. return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{
  187. ExpireEarly: opts.EarlyTokenExpiry,
  188. })
  189. }
  190. // exchange handles the final exchange portion of the 3lo flow. Returns a Token,
  191. // refreshToken, and error.
  192. func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
  193. // Build request
  194. v := url.Values{
  195. "grant_type": {"authorization_code"},
  196. "code": {code},
  197. }
  198. if o.RedirectURL != "" {
  199. v.Set("redirect_uri", o.RedirectURL)
  200. }
  201. if o.AuthHandlerOpts != nil &&
  202. o.AuthHandlerOpts.PKCEOpts != nil &&
  203. o.AuthHandlerOpts.PKCEOpts.Verifier != "" {
  204. v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier)
  205. }
  206. for k := range o.URLParams {
  207. v.Set(k, o.URLParams.Get(k))
  208. }
  209. return fetchToken(ctx, o, v)
  210. }
  211. // This struct is not safe for concurrent access alone, but the way it is used
  212. // in this package by wrapping it with a cachedTokenProvider makes it so.
  213. type tokenProvider3LO struct {
  214. opts *Options3LO
  215. client *http.Client
  216. refreshToken string
  217. }
  218. func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
  219. if tp.refreshToken == "" {
  220. return nil, errors.New("auth: token expired and refresh token is not set")
  221. }
  222. v := url.Values{
  223. "grant_type": {"refresh_token"},
  224. "refresh_token": {tp.refreshToken},
  225. }
  226. for k := range tp.opts.URLParams {
  227. v.Set(k, tp.opts.URLParams.Get(k))
  228. }
  229. tk, rt, err := fetchToken(ctx, tp.opts, v)
  230. if err != nil {
  231. return nil, err
  232. }
  233. if tp.refreshToken != rt && rt != "" {
  234. tp.refreshToken = rt
  235. }
  236. return tk, err
  237. }
  238. type tokenProviderWithHandler struct {
  239. opts *Options3LO
  240. state string
  241. }
  242. func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
  243. url := tp.opts.authCodeURL(tp.state, nil)
  244. code, state, err := tp.opts.AuthHandlerOpts.Handler(url)
  245. if err != nil {
  246. return nil, err
  247. }
  248. if state != tp.state {
  249. return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
  250. }
  251. tok, _, err := tp.opts.exchange(ctx, code)
  252. return tok, err
  253. }
  254. // fetchToken returns a Token, refresh token, and/or an error.
  255. func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
  256. var refreshToken string
  257. if o.AuthStyle == StyleInParams {
  258. if o.ClientID != "" {
  259. v.Set("client_id", o.ClientID)
  260. }
  261. if o.ClientSecret != "" {
  262. v.Set("client_secret", o.ClientSecret)
  263. }
  264. }
  265. req, err := http.NewRequestWithContext(ctx, "POST", o.TokenURL, strings.NewReader(v.Encode()))
  266. if err != nil {
  267. return nil, refreshToken, err
  268. }
  269. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  270. if o.AuthStyle == StyleInHeader {
  271. req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
  272. }
  273. // Make request
  274. resp, body, err := internal.DoRequest(o.client(), req)
  275. if err != nil {
  276. return nil, refreshToken, err
  277. }
  278. failureStatus := resp.StatusCode < 200 || resp.StatusCode > 299
  279. tokError := &Error{
  280. Response: resp,
  281. Body: body,
  282. }
  283. var token *Token
  284. // errors ignored because of default switch on content
  285. content, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
  286. switch content {
  287. case "application/x-www-form-urlencoded", "text/plain":
  288. // some endpoints return a query string
  289. vals, err := url.ParseQuery(string(body))
  290. if err != nil {
  291. if failureStatus {
  292. return nil, refreshToken, tokError
  293. }
  294. return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
  295. }
  296. tokError.code = vals.Get("error")
  297. tokError.description = vals.Get("error_description")
  298. tokError.uri = vals.Get("error_uri")
  299. token = &Token{
  300. Value: vals.Get("access_token"),
  301. Type: vals.Get("token_type"),
  302. Metadata: make(map[string]interface{}, len(vals)),
  303. }
  304. for k, v := range vals {
  305. token.Metadata[k] = v
  306. }
  307. refreshToken = vals.Get("refresh_token")
  308. e := vals.Get("expires_in")
  309. expires, _ := strconv.Atoi(e)
  310. if expires != 0 {
  311. token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
  312. }
  313. default:
  314. var tj tokenJSON
  315. if err = json.Unmarshal(body, &tj); err != nil {
  316. if failureStatus {
  317. return nil, refreshToken, tokError
  318. }
  319. return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
  320. }
  321. tokError.code = tj.ErrorCode
  322. tokError.description = tj.ErrorDescription
  323. tokError.uri = tj.ErrorURI
  324. token = &Token{
  325. Value: tj.AccessToken,
  326. Type: tj.TokenType,
  327. Expiry: tj.expiry(),
  328. Metadata: make(map[string]interface{}),
  329. }
  330. json.Unmarshal(body, &token.Metadata) // optional field, skip err check
  331. refreshToken = tj.RefreshToken
  332. }
  333. // according to spec, servers should respond status 400 in error case
  334. // https://www.rfc-editor.org/rfc/rfc6749#section-5.2
  335. // but some unorthodox servers respond 200 in error case
  336. if failureStatus || tokError.code != "" {
  337. return nil, refreshToken, tokError
  338. }
  339. if token.Value == "" {
  340. return nil, refreshToken, errors.New("auth: server response missing access_token")
  341. }
  342. return token, refreshToken, nil
  343. }