Skip to content

Commit efa4ed8

Browse files
🔥 feat: Add support for Sec-Fetch-Site header in CSRF middleware (#3913)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 130caa6 commit efa4ed8

File tree

6 files changed

+218
-10
lines changed

6 files changed

+218
-10
lines changed

AGENTS.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ These targets can be invoked via `make <target>` as needed during development an
5656
## Pull request guidelines
5757

5858
- PR titles must start with a category prefix describing the change: `🐛 bug:`, `🔥 feat:`, `📒 docs:`, or `🧹 chore:`.
59-
- Generated PR bodies should contain a **Summary** section that captures all changes included in the PR, not just the latest commit.
59+
- Generated PR titles and bodies must summarize the *entire* set of changes on the branch (for example, based on `git log --oneline <base>..HEAD` or the full diff), **not** just the latest commit. The Summary section should reflect all modifications that will be merged.
6060

6161
## Programmatic checks
6262

@@ -75,3 +75,7 @@ make test
7575
```
7676

7777
All checks must pass before the generated code can be merged.
78+
79+
After completing the programmatic checks above, confirm that any relevant
80+
documentation has been updated to reflect the changes made, including PR
81+
instructions when applicable.

constants.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ const (
256256
HeaderTE = "TE"
257257
HeaderTrailer = "Trailer"
258258
HeaderTransferEncoding = "Transfer-Encoding"
259+
HeaderSecFetchSite = "Sec-Fetch-Site"
259260
HeaderSecWebSocketAccept = "Sec-WebSocket-Accept"
260261
HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions"
261262
HeaderSecWebSocketKey = "Sec-WebSocket-Key"

docs/middleware/csrf.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ async function makeRequest(url, data) {
171171

172172
The middleware employs a robust, defense-in-depth strategy to protect against CSRF attacks. The primary defense is token-based validation, which operates in one of two modes depending on your configuration. This is supplemented by a mandatory secondary check on the request's origin.
173173

174+
### Fetch Metadata Guardrails
175+
176+
- **Sec-Fetch-Site**: For unsafe methods, the middleware inspects the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header when present. If the header value is not one of "same-origin", "none", "same-site", or "cross-site", the request is rejected with `ErrFetchSiteInvalid`. If the header is valid or absent, the request proceeds to the standard origin and token validation checks. This provides an early check to block requests with invalid `Sec-Fetch-Site` values, while allowing legitimate same-site and cross-site requests to be validated by the existing mechanisms.
177+
174178
### 1. Token Validation Patterns
175179

176180
#### Double Submit Cookie (Default Mode)

docs/whats_new.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,8 @@ The `Expiration` field in the CSRF middleware configuration has been renamed to
13111311

13121312
CSRF now redacts tokens and storage keys by default and exposes a `DisableValueRedaction` toggle (default `false`) if you must surface those values in diagnostics.
13131313

1314+
The CSRF middleware now validates the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header for unsafe HTTP methods. When present, requests with invalid `Sec-Fetch-Site` values (not one of "same-origin", "none", "same-site", or "cross-site") are rejected with `ErrFetchSiteInvalid`. Valid or absent headers proceed to standard origin and token validation checks, providing an early gate to catch malformed requests while maintaining compatibility with legitimate cross-site traffic.
1315+
13141316
### Idempotency
13151317

13161318
Idempotency middleware now redacts keys by default and offers a `DisableValueRedaction` configuration flag (default `false`) to expose them when debugging.

middleware/csrf/csrf.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@ import (
1414
)
1515

1616
var (
17-
ErrTokenNotFound = errors.New("csrf: token not found")
18-
ErrTokenInvalid = errors.New("csrf: token invalid")
19-
ErrRefererNotFound = errors.New("csrf: referer header missing")
20-
ErrRefererInvalid = errors.New("csrf: referer header invalid")
21-
ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
22-
ErrOriginInvalid = errors.New("csrf: origin header invalid")
23-
ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins")
24-
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
25-
dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value.
17+
ErrTokenNotFound = errors.New("csrf: token not found")
18+
ErrTokenInvalid = errors.New("csrf: token invalid")
19+
ErrFetchSiteInvalid = errors.New("csrf: sec-fetch-site header invalid")
20+
ErrRefererNotFound = errors.New("csrf: referer header missing")
21+
ErrRefererInvalid = errors.New("csrf: referer header invalid")
22+
ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
23+
ErrOriginInvalid = errors.New("csrf: origin header invalid")
24+
ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins")
25+
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
26+
dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value.
2627

2728
)
2829

@@ -127,6 +128,11 @@ func New(config ...Config) fiber.Handler {
127128
default:
128129
// Assume that anything not defined as 'safe' by RFC7231 needs protection
129130

131+
// Evaluate Sec-Fetch-Site to reject cross-site requests earlier when available.
132+
if err := validateSecFetchSite(c); err != nil {
133+
return cfg.ErrorHandler(c, err)
134+
}
135+
130136
// Enforce an origin check for unsafe requests.
131137
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)
132138

@@ -313,6 +319,21 @@ func (handler *Handler) DeleteToken(c fiber.Ctx) error {
313319
return nil
314320
}
315321

322+
func validateSecFetchSite(c fiber.Ctx) error {
323+
secFetchSite := utils.Trim(c.Get(fiber.HeaderSecFetchSite), ' ')
324+
325+
if secFetchSite == "" {
326+
return nil
327+
}
328+
329+
switch utils.ToLower(secFetchSite) {
330+
case "same-origin", "none", "cross-site", "same-site":
331+
return nil
332+
default:
333+
return ErrFetchSiteInvalid
334+
}
335+
}
336+
316337
// originMatchesHost checks that the origin header matches the host header
317338
// returns an error if the origin header is not present or is invalid
318339
// returns nil if the origin header is valid

middleware/csrf/csrf_test.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,182 @@ func Test_CSRF_Extractor_EmptyString(t *testing.T) {
823823
require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body()))
824824
}
825825

826+
func Test_CSRF_SecFetchSite(t *testing.T) {
827+
t.Parallel()
828+
829+
errorHandler := func(c fiber.Ctx, err error) error {
830+
return c.Status(fiber.StatusForbidden).SendString(err.Error())
831+
}
832+
833+
app := fiber.New()
834+
835+
app.Use(New(Config{ErrorHandler: errorHandler}))
836+
837+
app.All("/", func(c fiber.Ctx) error {
838+
return c.SendStatus(fiber.StatusOK)
839+
})
840+
841+
h := app.Handler()
842+
ctx := &fasthttp.RequestCtx{}
843+
ctx.Request.Header.SetMethod(fiber.MethodGet)
844+
ctx.Request.URI().SetScheme("http")
845+
ctx.Request.URI().SetHost("example.com")
846+
ctx.Request.Header.SetHost("example.com")
847+
h(ctx)
848+
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
849+
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
850+
851+
tests := []struct {
852+
name string
853+
method string
854+
secFetchSite string
855+
origin string
856+
expectedStatus int16
857+
https bool
858+
expectFetchSiteInvalid bool
859+
}{
860+
{
861+
name: "same-origin allowed",
862+
method: fiber.MethodPost,
863+
secFetchSite: "same-origin",
864+
origin: "http://example.com",
865+
expectedStatus: http.StatusOK,
866+
},
867+
{
868+
name: "none allowed",
869+
method: fiber.MethodPost,
870+
secFetchSite: "none",
871+
origin: "http://example.com",
872+
expectedStatus: http.StatusOK,
873+
},
874+
{
875+
name: "cross-site with origin allowed",
876+
method: fiber.MethodPost,
877+
secFetchSite: "cross-site",
878+
origin: "http://example.com",
879+
expectedStatus: http.StatusOK,
880+
},
881+
{
882+
name: "same-site with origin allowed",
883+
method: fiber.MethodPost,
884+
secFetchSite: "same-site",
885+
origin: "http://example.com",
886+
expectedStatus: http.StatusOK,
887+
},
888+
{
889+
name: "cross-site with mismatched origin blocked",
890+
method: fiber.MethodPost,
891+
secFetchSite: "cross-site",
892+
origin: "https://attacker.example",
893+
expectedStatus: http.StatusForbidden,
894+
},
895+
{
896+
name: "same-site with null origin blocked",
897+
method: fiber.MethodPost,
898+
secFetchSite: "same-site",
899+
origin: "null",
900+
expectedStatus: http.StatusForbidden,
901+
https: true,
902+
},
903+
{
904+
name: "invalid header blocked",
905+
method: fiber.MethodPost,
906+
secFetchSite: "weird",
907+
origin: "http://example.com",
908+
expectedStatus: http.StatusForbidden,
909+
expectFetchSiteInvalid: true,
910+
},
911+
{
912+
name: "no header with no origin",
913+
method: fiber.MethodPost,
914+
origin: "",
915+
expectedStatus: http.StatusOK,
916+
},
917+
{
918+
name: "no header with matching origin",
919+
method: fiber.MethodPost,
920+
origin: "http://example.com",
921+
expectedStatus: http.StatusOK,
922+
},
923+
{
924+
name: "no header with mismatched origin",
925+
method: fiber.MethodPost,
926+
origin: "https://attacker.example",
927+
expectedStatus: http.StatusForbidden,
928+
},
929+
{
930+
name: "no header with null origin",
931+
method: fiber.MethodPost,
932+
origin: "null",
933+
expectedStatus: http.StatusForbidden,
934+
https: true,
935+
},
936+
{
937+
name: "GET allowed",
938+
method: fiber.MethodGet,
939+
secFetchSite: "cross-site",
940+
expectedStatus: http.StatusOK,
941+
},
942+
{
943+
name: "HEAD allowed",
944+
method: fiber.MethodHead,
945+
secFetchSite: "cross-site",
946+
expectedStatus: http.StatusOK,
947+
},
948+
{
949+
name: "OPTIONS allowed",
950+
method: fiber.MethodOptions,
951+
secFetchSite: "cross-site",
952+
expectedStatus: http.StatusOK,
953+
},
954+
{
955+
name: "PUT with mismatched origin blocked",
956+
method: fiber.MethodPut,
957+
secFetchSite: "cross-site",
958+
origin: "https://attacker.example",
959+
expectedStatus: http.StatusForbidden,
960+
},
961+
}
962+
963+
for _, tt := range tests {
964+
t.Run(tt.name, func(t *testing.T) {
965+
t.Parallel()
966+
c := &fasthttp.RequestCtx{}
967+
scheme := "http"
968+
if tt.https {
969+
scheme = "https"
970+
}
971+
c.Request.Header.SetMethod(tt.method)
972+
c.Request.URI().SetScheme(scheme)
973+
c.Request.URI().SetHost("example.com")
974+
c.Request.Header.SetHost("example.com")
975+
c.Request.Header.SetProtocol(scheme)
976+
if scheme == "https" {
977+
c.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
978+
}
979+
if tt.origin != "" {
980+
c.Request.Header.Set(fiber.HeaderOrigin, tt.origin)
981+
}
982+
if tt.secFetchSite != "" {
983+
c.Request.Header.Set(fiber.HeaderSecFetchSite, tt.secFetchSite)
984+
}
985+
986+
safe := tt.method == fiber.MethodGet || tt.method == fiber.MethodHead || tt.method == fiber.MethodOptions || tt.method == fiber.MethodTrace
987+
988+
if !safe {
989+
c.Request.Header.Set(HeaderName, token)
990+
c.Request.Header.SetCookie(ConfigDefault.CookieName, token)
991+
}
992+
993+
h(c)
994+
require.Equal(t, int(tt.expectedStatus), c.Response.StatusCode())
995+
if tt.expectFetchSiteInvalid {
996+
require.Equal(t, ErrFetchSiteInvalid.Error(), string(c.Response.Body()))
997+
}
998+
})
999+
}
1000+
}
1001+
8261002
func Test_CSRF_Origin(t *testing.T) {
8271003
t.Parallel()
8281004
app := fiber.New()

0 commit comments

Comments
 (0)