jws: improve fix for CVE-2025-22868
The fix for CVE-2025-22868 relies on strings.Count, which isn't ideal
because it precludes failing fast when the token contains an unexpected
number of periods. Moreover, Verify still allocates more than necessary.
Eschew strings.Count in favor of strings.Cut. Some benchmark results:
goos: darwin
goarch: amd64
pkg: golang.org/x/oauth2/jws
cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
│ old │ new │
│ sec/op │ sec/op vs base │
Verify/full_of_periods-8 24862.50n ± 1% 57.87n ± 0% -99.77% (p=0.000 n=20)
Verify/two_trailing_periods-8 3.485m ± 1% 3.445m ± 1% -1.13% (p=0.003 n=20)
geomean 294.3µ 14.12µ -95.20%
│ old │ new │
│ B/op │ B/op vs base │
Verify/full_of_periods-8 16.00 ± 0% 16.00 ± 0% ~ (p=1.000 n=20) ¹
Verify/two_trailing_periods-8 2.001Mi ± 0% 1.001Mi ± 0% -49.98% (p=0.000 n=20)
geomean 5.658Ki 4.002Ki -29.27%
¹ all samples are equal
│ old │ new │
│ allocs/op │ allocs/op vs base │
Verify/full_of_periods-8 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹
Verify/two_trailing_periods-8 12.000 ± 0% 9.000 ± 0% -25.00% (p=0.000 n=20)
geomean 3.464 3.000 -13.40%
¹ all samples are equal
Also, remove all remaining calls to strings.Split.
Updates golang/go#71490
Change-Id: Icac3c7a81562161ab6533d892ba19247d6d5b943
GitHub-Last-Rev: 3a82900f747798f5f36065126385880277c0fce7
GitHub-Pull-Request: golang/oauth2#774
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/655455
Commit-Queue: Neal Patel <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Neal Patel <[email protected]>
Auto-Submit: Neal Patel <[email protected]>
diff --git a/jws/jws.go b/jws/jws.go
index 6f03a49..27ab061 100644
--- a/jws/jws.go
+++ b/jws/jws.go
@@ -116,12 +116,12 @@
// Decode decodes a claim set from a JWS payload.
func Decode(payload string) (*ClaimSet, error) {
// decode returned id token to get expiry
- s := strings.Split(payload, ".")
- if len(s) < 2 {
+ _, claims, _, ok := parseToken(payload)
+ if !ok {
// TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received")
}
- decoded, err := base64.RawURLEncoding.DecodeString(s[1])
+ decoded, err := base64.RawURLEncoding.DecodeString(claims)
if err != nil {
return nil, err
}
@@ -165,18 +165,34 @@
// Verify tests whether the provided JWT token's signature was produced by the private key
// associated with the supplied public key.
func Verify(token string, key *rsa.PublicKey) error {
- if strings.Count(token, ".") != 2 {
+ header, claims, sig, ok := parseToken(token)
+ if !ok {
return errors.New("jws: invalid token received, token must have 3 parts")
}
-
- parts := strings.SplitN(token, ".", 3)
- signedContent := parts[0] + "." + parts[1]
- signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
+ signatureString, err := base64.RawURLEncoding.DecodeString(sig)
if err != nil {
return err
}
h := sha256.New()
- h.Write([]byte(signedContent))
+ h.Write([]byte(header + tokenDelim + claims))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
}
+
+func parseToken(s string) (header, claims, sig string, ok bool) {
+ header, s, ok = strings.Cut(s, tokenDelim)
+ if !ok { // no period found
+ return "", "", "", false
+ }
+ claims, s, ok = strings.Cut(s, tokenDelim)
+ if !ok { // only one period found
+ return "", "", "", false
+ }
+ sig, _, ok = strings.Cut(s, tokenDelim)
+ if ok { // three periods found
+ return "", "", "", false
+ }
+ return header, claims, sig, true
+}
+
+const tokenDelim = "."
diff --git a/jws/jws_test.go b/jws/jws_test.go
index 39a136a..1776f56 100644
--- a/jws/jws_test.go
+++ b/jws/jws_test.go
@@ -7,6 +7,8 @@
import (
"crypto/rand"
"crypto/rsa"
+ "net/http"
+ "strings"
"testing"
)
@@ -39,8 +41,57 @@
}
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
- err := Verify("abc.def", nil)
- if err == nil {
- t.Error("got no errors; want improperly formed JWT not to be verified")
+ cases := []struct {
+ desc string
+ token string
+ }{
+ {
+ desc: "no periods",
+ token: "aa",
+ }, {
+ desc: "only one period",
+ token: "a.a",
+ }, {
+ desc: "more than two periods",
+ token: "a.a.a.a",
+ },
+ }
+ for _, tc := range cases {
+ f := func(t *testing.T) {
+ err := Verify(tc.token, nil)
+ if err == nil {
+ t.Error("got no errors; want improperly formed JWT not to be verified")
+ }
+ }
+ t.Run(tc.desc, f)
+ }
+}
+
+func BenchmarkVerify(b *testing.B) {
+ cases := []struct {
+ desc string
+ token string
+ }{
+ {
+ desc: "full of periods",
+ token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
+ }, {
+ desc: "two trailing periods",
+ token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
+ },
+ }
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ b.Fatal(err)
+ }
+ for _, bc := range cases {
+ f := func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ Verify(bc.token, &privateKey.PublicKey)
+ }
+ }
+ b.Run(bc.desc, f)
}
}