diff --git a/internal/transport/client_stream.go b/internal/transport/client_stream.go index cd8152ef13c7..961f35dc5de1 100644 --- a/internal/transport/client_stream.go +++ b/internal/transport/client_stream.go @@ -19,6 +19,7 @@ package transport import ( + "sync" "sync/atomic" "golang.org/x/net/http2" @@ -28,6 +29,12 @@ import ( "google.golang.org/grpc/status" ) +// nonGRPCDataMaxLen is the maximum number of non-gRPC response body bytes collected into nonGRPCDataBuf. +// +// NOTE: If you change this value, you MUST update the corresponding test in: +// - /test/end2end_test.go:TestHTTPServerSendsNonGRPCHeaderSurfaceFurtherData +const nonGRPCDataMaxLen = 1024 + // ClientStream implements streaming functionality for a gRPC client. type ClientStream struct { Stream // Embed for common stream functionality. @@ -46,7 +53,12 @@ type ClientStream struct { // headerValid indicates whether a valid header was received. Only // meaningful after headerChan is closed (always call waitOnHeader() before // reading its value). - headerValid bool + headerValid bool + + collectionMu sync.Mutex // protects nonGRPCStatus and nonGRPCDataBuf during the non-gRPC data collection lifecycle. + nonGRPCStatus *status.Status // the initial status from the non-gRPC response header, finalized with collected data before closing; non-nil only while collection is in progress. + nonGRPCDataBuf []byte // stores up to nonGRPCDataMaxLen bytes of the body of a non-gRPC response. + noHeaders bool // set if the client never received headers (set only after the stream is done). headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream @@ -54,6 +66,38 @@ type ClientStream struct { statsHandler stats.Handler // nil for internal streams (e.g., health check, ORCA) where telemetry is not supported. 
} +func (s *ClientStream) startNonGRPCDataCollection(st *status.Status) { + s.collectionMu.Lock() + defer s.collectionMu.Unlock() + if s.nonGRPCStatus != nil { + // The stream is already in the non-gRPC data collection lifecycle; + // keep the status and buffer recorded by the first call. + return + } + s.nonGRPCStatus = st + s.nonGRPCDataBuf = make([]byte, 0, nonGRPCDataMaxLen) +} + +// tryHandleNonGRPCData appends the payload of f to nonGRPCDataBuf, truncated so the buffer never exceeds nonGRPCDataMaxLen bytes. +// It returns two booleans: +// handle reports whether the stream is collecting a non-gRPC response body (when false, f was not consumed), +// end reports whether the stream should be closed now: the buffer is full or f.StreamEnded() is true. +func (s *ClientStream) tryHandleNonGRPCData(f *parsedDataFrame) (handle bool, end bool) { + s.collectionMu.Lock() + defer s.collectionMu.Unlock() + if s.nonGRPCStatus == nil { + // If not in the non-gRPC data collection lifecycle, do not handle this frame. + return false, false + } + + n := min(f.data.Len(), nonGRPCDataMaxLen-len(s.nonGRPCDataBuf)) + s.nonGRPCDataBuf = append(s.nonGRPCDataBuf, f.data.ReadOnlyData()[0:n]...) + if len(s.nonGRPCDataBuf) >= nonGRPCDataMaxLen || f.StreamEnded() { + return true, true + } + return true, false +} + // Read reads an n byte message from the input stream. func (s *ClientStream) Read(n int) (mem.BufferSlice, error) { b, err := s.Stream.read(n) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index d6bc6a6cc730..dbe83bfffe12 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -944,9 +944,21 @@ func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode <-s.done return } - // status and trailers can be updated here without any synchronization because the stream goroutine will - // only read it after it sees an io.EOF error from read or write and we'll write those errors - // only after updating this. 
+ + // If the stream is in the non-gRPC data collection lifecycle, the st and + // err passed by the caller are replaced: the final status keeps the code of + // nonGRPCStatus and appends "\ndata: " plus the quoted nonGRPCDataBuf to its + // message, so the collected non-gRPC data reaches the user. + s.collectionMu.Lock() + if s.nonGRPCStatus != nil { + data := "\ndata: " + strconv.Quote(string(s.nonGRPCDataBuf)) + st = status.New(s.nonGRPCStatus.Code(), s.nonGRPCStatus.Message()+data) + err = st.Err() + // Clear the nonGRPCStatus to indicate the non-gRPC data collection is done. + s.nonGRPCStatus = nil + } + s.collectionMu.Unlock() + s.status = st if len(mdata) > 0 { s.trailer = mdata @@ -1224,6 +1236,23 @@ func (t *http2Client) handleData(f *parsedDataFrame) { t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) return } + + handle, end := s.tryHandleNonGRPCData(f) + if handle { + if w := s.fc.onRead(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: s.id, + increment: w, + }) + } + if end { + // Close the stream; closeStream builds err and st from nonGRPCStatus and nonGRPCDataBuf. + // NOTE(review): the last argument is true even when end came from a full buffer rather than f.StreamEnded(); operateHeaders only passes true when endStream is set — confirm this is intended. 
+ t.closeStream(s, nil, true, http2.ErrCodeProtocol, nil, nil, true) + } + return + } + dataLen := f.data.Len() if f.Header().Flags.Has(http2.FlagDataPadded) { if w := s.fc.onRead(size - uint32(dataLen)); w > 0 { @@ -1568,7 +1597,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } se := status.New(grpcErrorCode, strings.Join(errs, "; ")) - t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + if endStream { + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, true) + return + } + // The stream is still open: keep se pending so handleData can collect the response body before closing. + s.startNonGRPCDataCollection(se) return } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 66c80387c03c..4ee7ecff6415 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -111,7 +111,8 @@ const ( notifyCall misbehaved encodingRequiredStatus - invalidHeaderField + invalidContentType + malformedHeader delayRead pingpong ) @@ -220,7 +221,7 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *ServerStream) s.Read(math.MaxInt) } -func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) { +func (h *testStreamHandler) handleStreamInvalidContentType(s *ServerStream) { headerFields := []hpack.HeaderField{} headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) h.t.controlBuf.put(&headerFrame{ @@ -230,6 +231,19 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) { }) } +func (h *testStreamHandler) handleStreamMalformedHeader(s *ServerStream) { + headerFields := []hpack.HeaderField{ + {Name: ":status", Value: "200"}, + {Name: "content-type", Value: "application/grpc"}, + {Name: "x-bad-bin", Value: "!!!invalid-base64!!!"}, // '!' is outside the base64 alphabet, so this value cannot decode + } + h.t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + }) +} + // handleStreamDelayRead delays reads so that the other side has to halt on // stream-level flow control. 
// This handler assumes dynamic flow control is turned off and assumes window @@ -425,12 +439,23 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT }) wg.Done() }() - case invalidHeaderField: + case invalidContentType: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { wg.Add(1) go func() { - h.handleStreamInvalidHeaderField(s) + h.handleStreamInvalidContentType(s) + wg.Done() + }() + }) + wg.Done() + }() + case malformedHeader: + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + wg.Add(1) + go func() { + h.handleStreamMalformedHeader(s) wg.Done() }() }) @@ -1638,8 +1663,8 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { s.Read(math.MaxInt) } -func (s) TestInvalidHeaderField(t *testing.T) { - server, ct, cancel := setUp(t, 0, invalidHeaderField) +func (s) TestInvalidContentType(t *testing.T) { + server, ct, cancel := setUp(t, 0, invalidContentType) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1660,8 +1685,32 @@ func (s) TestInvalidHeaderField(t *testing.T) { server.stop() } +func (s) TestHeaderChanClosedAfterReceivingNonGRPCResponse(t *testing.T) { + server, ct, cancel := setUp(t, 0, invalidContentType) + defer cancel() + defer server.stop() + defer ct.Close(fmt.Errorf("closed manually by test")) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}, nil) + if err != nil { + t.Fatalf("ct.NewStream() failed: %v", err) + } + // The server sends a non-gRPC response without ending the stream, so the + // stream enters data collection mode. headerChan is not closed until the + // stream itself closes, so Header() must fail and headerChan must be closed once it returns. 
+ if _, err := s.Header(); err == nil { + t.Fatalf("Header() succeeded, want error") + } + select { + case <-s.headerChan: + default: + t.Errorf("s.headerChan: got open, want closed") + } +} + func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { - server, ct, cancel := setUp(t, 0, invalidHeaderField) + server, ct, cancel := setUp(t, 0, malformedHeader) defer cancel() defer server.stop() defer ct.Close(fmt.Errorf("closed manually by test")) @@ -2685,6 +2734,7 @@ func (s) TestClientDecodeHeader(t *testing.T) { name string metaHeaderFrame *http2.MetaHeadersFrame wantStatus *status.Status + isNonGRPCStatus bool }{ { name: "valid_header", @@ -2708,6 +2758,7 @@ func (s) TestClientDecodeHeader(t *testing.T) { codes.Unknown, "unexpected HTTP status code received from server: 200 (OK); malformed header: missing HTTP content-type", ), + isNonGRPCStatus: true, }, { name: "invalid_grpc_status", @@ -2734,6 +2785,7 @@ func (s) TestClientDecodeHeader(t *testing.T) { codes.Internal, "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", ), + isNonGRPCStatus: true, }, { name: "invalid_content_type_with_http_status_504", @@ -2747,6 +2799,7 @@ func (s) TestClientDecodeHeader(t *testing.T) { codes.Unavailable, "unexpected HTTP status code received from server: 504 (Gateway Timeout); transport: received unexpected content-type \"application/json\"", ), + isNonGRPCStatus: true, }, { name: "http_fallback_and_invalid_http_status", @@ -2803,7 +2856,12 @@ func (s) TestClientDecodeHeader(t *testing.T) { } s.operateHeaders(tc.metaHeaderFrame) - got := cs.status + var got *status.Status + if tc.isNonGRPCStatus { + got = cs.nonGRPCStatus + } else { + got = cs.status + } want := tc.wantStatus if got.Code() != want.Code() || got.Message() != want.Message() { t.Errorf("operateHeaders(%v) got status %q, want %q", tc.metaHeaderFrame, got, want) diff --git a/test/end2end_test.go b/test/end2end_test.go index 
ddf6fd7b46e2..fb492fcb93b3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -34,6 +34,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -6798,6 +6799,118 @@ func (s) TestAuthorityHeader(t *testing.T) { } } +func (s) TestHTTPServerSendsNonGRPCHeaderSurfaceFurtherData(t *testing.T) { + const nonGRPCDataMaxLen = 1024 // mirrors nonGRPCDataMaxLen in internal/transport/client_stream.go; keep the two values in sync + tests := []struct { + name string + responses []httpServerResponse + wantCode codes.Code + wantErr string + }{ + { + name: "non-gRPC content-type without payload", + responses: []httpServerResponse{ + { + headers: [][]string{ + { + ":status", "200", + "content-type", "text/html", + }, + }, + // payload: nil + }, + }, + wantCode: codes.Unknown, + wantErr: `unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html" +data: ""`, + }, + { + name: "non-gRPC content-type with payload", + responses: []httpServerResponse{ + { + headers: [][]string{ + { + ":status", "200", + "content-type", "text/html", + }, + }, + payload: []byte(`Hello World`), + }, + }, + wantCode: codes.Unknown, + wantErr: `unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html" +data: "Hello World"`, + }, + { + name: "non-gRPC content-type with bytes payload length more than nonGRPCDataMaxLen", + responses: []httpServerResponse{ + { + headers: [][]string{ + { + ":status", "200", + "content-type", "text/html", + }, + }, + payload: bytes.Repeat([]byte("a"), nonGRPCDataMaxLen+1), // one byte over the limit; wantErr below expects exactly nonGRPCDataMaxLen bytes + }, + }, + wantCode: codes.Unknown, + wantErr: `unexpected HTTP status code received from server: 200 (OK); transport: received unexpected content-type "text/html" +data: ` + strconv.Quote(strings.Repeat("a", nonGRPCDataMaxLen)), + }, + { + name: "content-type not provided", + responses: []httpServerResponse{ + { + headers: [][]string{{ + ":status", "502", + }}, + payload: []byte("hello"), + }, + }, + wantCode: codes.Unavailable, + wantErr: `unexpected 
HTTP status code received from server: 502 (Bad Gateway); malformed header: missing HTTP content-type +data: "hello"`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen() failed: %v", err) + } + defer lis.Close() + + hs := &httpServer{responses: test.responses} + hs.start(t, lis) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient() failed: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + _, err = client.EmptyCall(ctx, &testpb.Empty{}) + if err == nil { + t.Fatalf("EmptyCall() = nil; want non-nil error due to non-gRPC response") + } + + if got, want := status.Code(err), test.wantCode; got != want { + t.Fatalf("Unexpected error code: got %v, want %v\nfull error:\n%v", got, want, err) + } + + if got := status.Convert(err).Message(); got != test.wantErr { + t.Errorf("Unexpected error message: \ngot:\n%v\nwant:\n%v", got, test.wantErr) + } + }) + } +} + // wrapCloseListener tracks Accepts/Closes and maintains a counter of the // number of open connections. type wrapCloseListener struct {