Skip to content

Commit 7486a46

Browse files
committed
feat: add ability to skip validation checks
As noted in #63, an equivalent to Echo's `Skipper` would allow for middleware users to opt-out of validation in a more straightforward way. In a slightly different implementation to our `echo-middleware`, this does not allow the `Skipper` to consume the body of the original request, and instead duplicates it for the `Skipper`, and the other uses of it.
1 parent 229a92f commit 7486a46

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

oapi_validate.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
package nethttpmiddleware
99

1010
import (
11+
"bytes"
1112
"context"
1213
"errors"
1314
"fmt"
15+
"io"
1416
"log"
1517
"net/http"
1618
"strings"
@@ -74,6 +76,11 @@ type ErrorHandlerOptsMatchedRoute struct {
7476
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
7577
type MultiErrorHandler func(openapi3.MultiError) (int, error)
7678

79+
// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware
80+
//
81+
// Return `true` if the request should be skipped
82+
type Skipper func(r *http.Request) bool
83+
7784
// Options allows configuring the OapiRequestValidator.
7885
type Options struct {
7986
// Options contains any configuration for the underlying `openapi3filter`
@@ -100,6 +107,9 @@ type Options struct {
100107
SilenceServersWarning bool
101108
// DoNotValidateServers ensures that there is no Host validation performed (see `SilenceServersWarning` and https://github.com/deepmap/oapi-codegen/issues/882 for more details)
102109
DoNotValidateServers bool
110+
111+
// Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware
112+
Skipper Skipper
103113
}
104114

105115
// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
@@ -126,6 +136,15 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
126136

127137
return func(next http.Handler) http.Handler {
128138
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
139+
if options != nil && options.Skipper != nil {
140+
r2, err := copyHTTPRequest(r)
141+
if err == nil && options.Skipper(r2) {
142+
// serve with the original request
143+
next.ServeHTTP(w, r)
144+
return
145+
}
146+
}
147+
129148
if options == nil {
130149
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
131150
} else if options.ErrorHandlerWithOpts != nil {
@@ -141,6 +160,22 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
141160

142161
}
143162

163+
func copyHTTPRequest(r *http.Request) (*http.Request, error) {
164+
r2 := r.Clone(r.Context())
165+
166+
if r.Body != nil {
167+
bodyBytes, err := io.ReadAll(r.Body)
168+
if err != nil {
169+
return nil, err
170+
}
171+
// keep the original request body available
172+
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
173+
// and have it available for the copy
174+
r2.Body = io.NopCloser(bytes.NewReader(bodyBytes))
175+
}
176+
return r2, nil
177+
}
178+
144179
func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) {
145180
// validate request
146181
statusCode, err := validateRequest(r, router, options)

oapi_validate_example_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,145 @@ paths:
839839
// Received an HTTP 400 response. Expected HTTP 400
840840
// Response body: There was a bad request
841841
}
842+
843+
func ExampleOapiRequestValidatorWithOptions_withSkipper() {
844+
rawSpec := `
845+
openapi: "3.0.0"
846+
info:
847+
version: 1.0.0
848+
title: TestServer
849+
servers:
850+
- url: http://example.com/
851+
paths:
852+
# we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests
853+
/resource:
854+
post:
855+
operationId: createResource
856+
responses:
857+
'204':
858+
description: No content
859+
requestBody:
860+
required: true
861+
content:
862+
text/plain: {}
863+
`
864+
865+
must := func(err error) {
866+
if err != nil {
867+
panic(err)
868+
}
869+
}
870+
871+
use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler {
872+
var s http.Handler
873+
s = r
874+
875+
for _, mw := range middlewares {
876+
s = mw(s)
877+
}
878+
879+
return s
880+
}
881+
882+
logResponseBody := func(rr *httptest.ResponseRecorder) {
883+
if rr.Result().Body != nil {
884+
data, _ := io.ReadAll(rr.Result().Body)
885+
if len(data) > 0 {
886+
fmt.Printf("Response body: %s", data)
887+
}
888+
}
889+
}
890+
891+
spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec))
892+
must(err)
893+
894+
// NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec
895+
// See also: Options#SilenceServersWarning
896+
spec.Servers = nil
897+
898+
router := http.NewServeMux()
899+
900+
router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
901+
fmt.Printf("%s /resource was called\n", r.Method)
902+
903+
if r.Method == http.MethodPost {
904+
data, err := io.ReadAll(r.Body)
905+
if err != nil {
906+
w.WriteHeader(http.StatusInternalServerError)
907+
return
908+
}
909+
fmt.Printf("Request body: %s\n", data)
910+
w.WriteHeader(http.StatusNoContent)
911+
return
912+
}
913+
914+
w.WriteHeader(http.StatusMethodNotAllowed)
915+
})
916+
917+
router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
918+
w.WriteHeader(http.StatusOK)
919+
})
920+
921+
authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error {
922+
fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName)
923+
return fmt.Errorf("this check always fails - don't let anyone in!")
924+
}
925+
926+
skipperFunc := func(r *http.Request) bool {
927+
// always consume the request body, because we're not following best practices
928+
_, _ = io.ReadAll(r.Body)
929+
930+
// skip the undocumented healthcheck endpoint
931+
if r.URL.Path == "/healthz" {
932+
return true
933+
}
934+
935+
return false
936+
}
937+
938+
// create middleware
939+
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
940+
Options: openapi3filter.Options{
941+
AuthenticationFunc: authenticationFunc,
942+
},
943+
Skipper: skipperFunc,
944+
})
945+
946+
// then wire it in
947+
server := use(router, mw)
948+
949+
// ================================================================================
950+
fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated")
951+
952+
req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody)
953+
must(err)
954+
955+
rr := httptest.NewRecorder()
956+
957+
server.ServeHTTP(rr, req)
958+
959+
fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code)
960+
logResponseBody(rr)
961+
962+
// ================================================================================
963+
fmt.Println("# The skipper cannot consume the request body")
964+
965+
req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there")))
966+
must(err)
967+
req.Header.Set("Content-Type", "text/plain")
968+
969+
rr = httptest.NewRecorder()
970+
971+
server.ServeHTTP(rr, req)
972+
973+
fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code)
974+
logResponseBody(rr)
975+
976+
// Output:
977+
// # A request that is made to the undocumented healthcheck endpoint does not get validated
978+
// Received an HTTP 200 response. Expected HTTP 200
979+
// # The skipper cannot consume the request body
980+
// POST /resource was called
981+
// Request body: Hello there
982+
// Received an HTTP 204 response. Expected HTTP 204
983+
}

0 commit comments

Comments
 (0)