diff --git a/runner/run.go b/runner/run.go
index ae468b9..50c2fbc 100644
--- a/runner/run.go
+++ b/runner/run.go
@@ -362,7 +362,7 @@ func getRequestFromTest(testInput test.Input) *ftwhttp.Request {
 
 	// If we use raw or encoded request, then we don't use other fields
 	if raw != nil {
-		req = ftwhttp.NewRawRequest(raw, !*testInput.AutocompleteHeaders)
+		req = ftwhttp.NewRawRequest(raw, *testInput.AutocompleteHeaders)
 	} else {
 		rline := &ftwhttp.RequestLine{
 			Method:  testInput.GetMethod(),
@@ -373,7 +373,7 @@ func getRequestFromTest(testInput test.Input) *ftwhttp.Request {
 		data := testInput.ParseData()
 		// create a new request
 		req = ftwhttp.NewRequest(rline, testInput.Headers,
-			data, !*testInput.AutocompleteHeaders)
+			data, *testInput.AutocompleteHeaders)
 	}
 
 	return req
diff --git a/runner/run_test.go b/runner/run_test.go
index 9b801f9..e05dfb2 100644
--- a/runner/run_test.go
+++ b/runner/run_test.go
@@ -209,7 +209,14 @@ func (s *runTestSuite) BeforeTest(_ string, name string) {
 	}
 	// get tests template from file
 	tmpl, err := template.ParseFiles(fmt.Sprintf("testdata/%s.yaml", name))
+	if os.IsNotExist(err) {
+		// Tests without a testdata template are fine; any other failure
+		// (e.g. a malformed template) must still fail the suite below.
+		log.Info().Msgf("No test data found for test %s, assuming that's ok", name)
+		return
+	}
 	s.Require().NoError(err)
+
 	// create a temporary file to hold the test
 	testFileContents, err := os.CreateTemp("testdata", "mock-test-*.yaml")
 	s.Require().NoError(err, "cannot create temporary file")
@@ -347,3 +354,103 @@ func (s *runTestSuite) TestIgnoredTestsRun() {
 	s.Require().NoError(err)
 	s.Equal(res.Stats.TotalFailed(), 1, "Oops, test run failed!")
 }
+
+// TestGetRequestFromTestWithAutocompleteHeaders builds a POST request from a
+// test.Input with AutocompleteHeaders enabled, sends it to the suite's
+// destination and asserts that Content-Length and Connection were filled in.
+func (s *runTestSuite) TestGetRequestFromTestWithAutocompleteHeaders() {
+	autocomplete := true
+	method := "POST"
+	input := test.Input{
+		AutocompleteHeaders: &autocomplete,
+		Method:              &method,
+		Headers:             ftwhttp.Header{},
+		DestAddr:            &s.dest.DestAddr,
+		Port:                &s.dest.Port,
+		Protocol:            &s.dest.Protocol,
+	}
+	request := getRequestFromTest(input)
+
+	client, err := ftwhttp.NewClient(ftwhttp.NewClientConfig())
+	s.Require().NoError(err)
+
+	dest := ftwhttp.Destination{
+		DestAddr: input.GetDestAddr(),
+		Port:     input.GetPort(),
+		Protocol: input.GetProtocol(),
+	}
+	err = client.NewConnection(dest)
+	s.Require().NoError(err)
+	_, err = client.Do(*request)
+	s.Require().NoError(err)
+
+	s.Equal("0", request.Headers().Get("Content-Length"), "Autocompletion should add 'Content-Length' header to POST requests")
+	s.Equal("close", request.Headers().Get("Connection"), "Autocompletion should add 'Connection: close' header")
+}
+
+// TestGetRawRequestFromTestWithAutocompleteHeaders builds a request from a
+// RAWRequest payload with AutocompleteHeaders enabled, sends it and asserts
+// that no Content-Length or Connection header shows up on the request.
+func (s *runTestSuite) TestGetRawRequestFromTestWithAutocompleteHeaders() {
+	autocomplete := true
+	method := "POST"
+	input := test.Input{
+		AutocompleteHeaders: &autocomplete,
+		Method:              &method,
+		Headers:             ftwhttp.Header{},
+		DestAddr:            &s.dest.DestAddr,
+		Port:                &s.dest.Port,
+		Protocol:            &s.dest.Protocol,
+		RAWRequest:          "POST / HTTP/1.1\r\nHost: localhost\r\nUser-Agent: test\r\n\r\n",
+	}
+	request := getRequestFromTest(input)
+
+	client, err := ftwhttp.NewClient(ftwhttp.NewClientConfig())
+	s.Require().NoError(err)
+
+	dest := ftwhttp.Destination{
+		DestAddr: input.GetDestAddr(),
+		Port:     input.GetPort(),
+		Protocol: input.GetProtocol(),
+	}
+	err = client.NewConnection(dest)
+	s.Require().NoError(err)
+	_, err = client.Do(*request)
+	s.Require().NoError(err)
+
+	s.Equal("", request.Headers().Get("Content-Length"), "Raw requests should not be modified")
+	s.Equal("", request.Headers().Get("Connection"), "Raw requests should not be modified")
+}
+
+// TestGetRequestFromTestWithoutAutocompleteHeaders builds a POST request with
+// AutocompleteHeaders disabled, sends it and asserts that neither
+// Content-Length nor Connection was added to the request.
+func (s *runTestSuite) TestGetRequestFromTestWithoutAutocompleteHeaders() {
+	autocomplete := false
+	method := "POST"
+	input := test.Input{
+		AutocompleteHeaders: &autocomplete,
+		Method:              &method,
+		Headers:             ftwhttp.Header{},
+		DestAddr:            &s.dest.DestAddr,
+		Port:                &s.dest.Port,
+		Protocol:            &s.dest.Protocol,
+	}
+	request := getRequestFromTest(input)
+
+	client, err := ftwhttp.NewClient(ftwhttp.NewClientConfig())
+	s.Require().NoError(err)
+
+	dest := ftwhttp.Destination{
+		DestAddr: input.GetDestAddr(),
+		Port:     input.GetPort(),
+		Protocol: input.GetProtocol(),
+	}
+	err = client.NewConnection(dest)
+	s.Require().NoError(err)
+	_, err = client.Do(*request)
+	s.Require().NoError(err)
+
+	s.Equal("", request.Headers().Get("Content-Length"), "Autocompletion is disabled")
+	s.Equal("", request.Headers().Get("Connection"), "Autocompletion is disabled")
+}