diff --git a/ebpf/bpfwrapper/eventCallbacks.go b/ebpf/bpfwrapper/eventCallbacks.go index 22082b28..355e1249 100644 --- a/ebpf/bpfwrapper/eventCallbacks.go +++ b/ebpf/bpfwrapper/eventCallbacks.go @@ -140,7 +140,7 @@ func SocketDataEventCallback(inputChan chan []byte, connectionFactory *connectio bytesSent := event.Attr.Bytes_sent // The 4 bytes are being lost in padding, thus, not taking them into consideration. - eventAttributesLogicalSize := 45 + eventAttributesLogicalSize := 53 if len(data) > eventAttributesLogicalSize { copy(event.Msg[:], data[eventAttributesLogicalSize:eventAttributesLogicalSize+int(utils.Abs(bytesSent))]) diff --git a/ebpf/connections/factory.go b/ebpf/connections/factory.go index ac762cd8..07bdafd4 100644 --- a/ebpf/connections/factory.go +++ b/ebpf/connections/factory.go @@ -17,6 +17,7 @@ import ( ) var httpBytes = []byte("HTTP") +var http2Preface = []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") // Factory is a routine-safe container that holds a trackers with unique ID, and able to create new tracker. 
type Factory struct { @@ -82,6 +83,12 @@ var ( trackerDataProcessInterval = 100 ) +const ( + protocolUnknown = "UNKN" + protocolhttp1 = "HTTP1" + protocolhttp2 = "HTTP2" +) + func init() { utils.InitVar("TRAFFIC_DISABLE_EGRESS", &disableEgress) utils.InitVar("TRAFFIC_MAX_ACTIVE_CONN", &maxActiveConnections) @@ -91,6 +98,14 @@ func init() { utils.InitVar("TRACKER_DATA_PROCESS_INTERVAL", &trackerDataProcessInterval) } +func hasHTTPResponse(buffer []byte) bool { + return len(buffer) >= len(httpBytes) && bytes.Equal(buffer[:len(httpBytes)], httpBytes) +} + +func hasHTTP2Preface(buffer []byte) bool { + return len(buffer) >= len(http2Preface) && bytes.Equal(buffer[:len(http2Preface)], http2Preface) +} + func ProcessTrackerData(connID structs.ConnID, tracker *Tracker, isComplete bool) { tracker.mutex.Lock() defer tracker.mutex.Unlock() @@ -120,14 +135,14 @@ func ProcessTrackerData(connID structs.ConnID, tracker *Tracker, isComplete bool hostName = kafkaUtil.PodInformerInstance.GetPodNameByProcessId(int32(connID.Id >> 32)) } - if len(sentBuffer) >= len(httpBytes) && (bytes.Equal(sentBuffer[:len(httpBytes)], httpBytes)) { - tryReadFromBD(destIpStr, srcIpStr, receiveBuffer, sentBuffer, isComplete, 1, connID.Id, connID.Fd, uniqueDaemonsetId, hostName) + protocol := tracker.protocol + + if hasHTTPResponse(sentBuffer) || hasHTTP2Preface(receiveBuffer) { + tryReadFromBD(destIpStr, srcIpStr, receiveBuffer, sentBuffer, isComplete, 1, connID.Id, connID.Fd, uniqueDaemonsetId, hostName, protocol) } - if !disableEgress { - // attempt to parse the egress as well by switching the recv and sent buffers. 
- if len(receiveBuffer) >= len(httpBytes) && (bytes.Equal(receiveBuffer[:len(httpBytes)], httpBytes)) { - tryReadFromBD(srcIpStr, destIpStr, sentBuffer, receiveBuffer, isComplete, 2, connID.Id, connID.Fd, uniqueDaemonsetId, hostName) - } + + if !disableEgress && (hasHTTPResponse(receiveBuffer) || hasHTTP2Preface(sentBuffer)) { + tryReadFromBD(srcIpStr, destIpStr, sentBuffer, receiveBuffer, isComplete, 2, connID.Id, connID.Fd, uniqueDaemonsetId, hostName, protocol) } } diff --git a/ebpf/connections/parser.go b/ebpf/connections/parser.go index 65b4a69d..3dfd2f52 100644 --- a/ebpf/connections/parser.go +++ b/ebpf/connections/parser.go @@ -4,7 +4,7 @@ import ( "github.com/akto-api-security/mirroring-api-logging/trafficUtil/kafkaUtil" ) -func tryReadFromBD(ip string, destIp string, receiveBuffer []byte, sentBuffer []byte, isComplete bool, direction int, id uint64, fd uint32, daemonsetIdentifier, hostName string) { +func tryReadFromBD(ip string, destIp string, receiveBuffer []byte, sentBuffer []byte, isComplete bool, direction int, id uint64, fd uint32, daemonsetIdentifier, hostName string, protocol string) { ctx := kafkaUtil.TrafficContext{ SourceIP: ip, DestIP: destIp, @@ -17,6 +17,7 @@ func tryReadFromBD(ip string, destIp string, receiveBuffer []byte, sentBuffer [] SocketFD: fd, DaemonsetIdentifier: daemonsetIdentifier, HostName: hostName, + Protocol: protocol, } kafkaUtil.ParseAndProduce(receiveBuffer, sentBuffer, ctx) } diff --git a/ebpf/connections/tracker.go b/ebpf/connections/tracker.go index 0d8d8acb..9cfb46fe 100644 --- a/ebpf/connections/tracker.go +++ b/ebpf/connections/tracker.go @@ -30,6 +30,7 @@ type Tracker struct { srcPort uint16 foundHTTP bool + protocol string // eBPF-reported label: "UNKN", "HTTP1", or "HTTP2" (the protocol* constants in factory.go) } func NewTracker(connID structs.ConnID) *Tracker { @@ -40,6 +41,7 @@ func NewTracker(connID structs.ConnID) *Tracker { mutex: sync.RWMutex{}, ssl: false, foundHTTP: false, + protocol: protocolUnknown, } } @@ 
-58,6 +60,29 @@ func (conn *Tracker) 
AddDataEvent(event structs.SocketDataEvent) { conn.mutex.Lock() defer conn.mutex.Unlock() + // Update protocol from eBPF if it has changed from UNKN + protocolBytes := event.Attr.Protocol[:] + nullIndex := -1 + for i, b := range protocolBytes { + if b == 0 { + nullIndex = i + break + } + } + if nullIndex > 0 { + protocolStr := string(protocolBytes[:nullIndex]) + if protocolStr != protocolUnknown && conn.protocol != protocolStr { + switch protocolStr { + case protocolhttp1: + conn.protocol = protocolhttp1 + case protocolhttp2: + conn.protocol = protocolhttp2 + default: + conn.protocol = protocolUnknown + } + } + } + if !conn.ssl && event.Attr.Ssl { for k := range conn.sentBuf { conn.sentBuf[k] = []byte{} diff --git a/ebpf/go.mod b/ebpf/go.mod index f3162c1e..62c7e126 100644 --- a/ebpf/go.mod +++ b/ebpf/go.mod @@ -12,6 +12,7 @@ require ( ) require ( + github.com/akto-api-security/gomiddleware v0.1.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/ebpf/go.sum b/ebpf/go.sum index b48ff900..54c4c45b 100644 --- a/ebpf/go.sum +++ b/ebpf/go.sum @@ -1,3 +1,5 @@ +github.com/akto-api-security/gomiddleware v0.1.4 h1:jz3Umei5ItlyCBwHROdh9oRrLPhr0V6mwLT9vfvspc8= +github.com/akto-api-security/gomiddleware v0.1.4/go.mod h1:zDsxe1UTr+rGvHt6r1h+c8RkBBzy/A7iTMGoXiTZ5oI= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/ebpf/kernel/module.cc b/ebpf/kernel/module.cc index a17a1251..eaa9cbd7 100644 --- a/ebpf/kernel/module.cc +++ b/ebpf/kernel/module.cc @@ -49,6 +49,7 @@ struct conn_info_t { bool ssl; u32 readEventsCount; u32 writeEventsCount; + char protocol[8]; }; union sockaddr_t { @@ -87,6 +88,7 @@ struct socket_open_event_t { u32 src_ip; 
unsigned short src_port; u64 socket_open_ns; + char protocol[8]; // Protocol detected from payload: "HTTP1", "HTTP2", "UNKN" }; struct socket_close_event_t { @@ -108,6 +110,7 @@ struct socket_data_event_t { u32 readEventsCount; u32 writeEventsCount; bool ssl; + char protocol[8]; // Protocol detected: "HTTP1", "HTTP2", "UNKN" char msg[MAX_MSG_SIZE]; }; @@ -233,6 +236,9 @@ static __inline void process_syscall_accept(struct pt_regs* ret, const struct ac conn_info.readEventsCount = 0; conn_info.writeEventsCount = 0; + // Initialize protocol as unknown - will be detected from first data packet + __builtin_memcpy(conn_info.protocol, "UNKN", 5); + u32 tgid = id >> 32; u64 tgid_fd = 0; if(isConnect){ @@ -280,6 +286,7 @@ static __inline void process_syscall_accept(struct pt_regs* ret, const struct ac socket_open_event.ip = conn_info.ip; socket_open_event.src_ip = srcIp; socket_open_event.src_port = lport; + __builtin_memcpy(socket_open_event.protocol, conn_info.protocol, 8); if (PRINT_BPF_LOGS){ bpf_trace_printk("accept call: %llu %d %d", socket_open_event.id, socket_open_event.fd, isConnect); @@ -318,7 +325,42 @@ static __inline void process_syscall_close(struct pt_regs* ret, const struct clo socket_close_event.socket_close_ns = bpf_ktime_get_ns(); socket_close_events.perf_submit(ret, &socket_close_event, sizeof(struct socket_close_event_t)); - conn_info_map.delete(&tgid_fd); + conn_info_map.delete(&tgid_fd); +} + +static __inline void detect_protocol_from_data(struct conn_info_t *conn_info, const char *buf, size_t count) { + // Only detect if protocol is still unknown + if (conn_info->protocol[0] != 'U') { + return; + } + + if (count < 6) { + return; + } + + // HTTP/2 connection preface: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + if (buf[0] == 'P' && buf[1] == 'R' && buf[2] == 'I' && buf[3] == ' ' && buf[4] == '*') { + __builtin_memcpy(conn_info->protocol, "HTTP2", 6); + return; + } + + // HTTP/1.x request methods (verbs): GET, POST, PUT, DELETE, HEAD, PATCH + // GET /path 
HTTP/1.1 + if ((buf[0] == 'G' && buf[1] == 'E' && buf[2] == 'T' && buf[3] == ' ') || + (buf[0] == 'P' && buf[1] == 'O' && buf[2] == 'S' && buf[3] == 'T') || + (buf[0] == 'P' && buf[1] == 'U' && buf[2] == 'T' && buf[3] == ' ') || + (buf[0] == 'D' && buf[1] == 'E' && buf[2] == 'L' && buf[3] == 'E') || + (buf[0] == 'H' && buf[1] == 'E' && buf[2] == 'A' && buf[3] == 'D') || + (buf[0] == 'P' && buf[1] == 'A' && buf[2] == 'T' && buf[3] == 'C')) { + __builtin_memcpy(conn_info->protocol, "HTTP1", 6); + return; + } + + // HTTP/1.x response: "HTTP/1.0 200 OK" or "HTTP/1.1 200 OK" + if (buf[0] == 'H' && buf[1] == 'T' && buf[2] == 'T' && buf[3] == 'P' && buf[4] == '/') { + __builtin_memcpy(conn_info->protocol, "HTTP1", 6); + return; + } } static __inline void process_syscall_data(struct pt_regs* ret, const struct data_args_t* args, u64 id, bool is_send, bool ssl) { @@ -370,8 +412,9 @@ static __inline void process_syscall_data(struct pt_regs* ret, const struct data socket_data_event->fd = conn_info->fd; socket_data_event->conn_start_ns = conn_info->conn_start_ns; socket_data_event->port = conn_info->port; - socket_data_event->ip = conn_info->ip; + socket_data_event->ip = conn_info->ip; socket_data_event->ssl = conn_info->ssl; + __builtin_memcpy(socket_data_event->protocol, conn_info->protocol, 8); int bytes_sent = 0; size_t size_to_save = 0; @@ -401,6 +444,10 @@ static __inline void process_syscall_data(struct pt_regs* ret, const struct data size_to_save = MAX_MSG_SIZE; } + // Detect protocol from first packet payload, then re-stamp the event: its protocol field was copied from conn_info before detection ran and would otherwise still carry the stale label for this packet + detect_protocol_from_data(conn_info, socket_data_event->msg, size_to_save); + __builtin_memcpy(socket_data_event->protocol, conn_info->protocol, 8); + if (is_send){ conn_info->writeEventsCount = (conn_info->writeEventsCount) + 1u; } else { diff --git a/ebpf/structs/structs.go b/ebpf/structs/structs.go 
index 3a88bd09..342f0d36 100644 --- a/ebpf/structs/structs.go +++ b/ebpf/structs/structs.go @@ -18,6 +18,7 @@ type SocketDataEventAttr struct { ReadEventsCount uint32 WriteEventsCount uint32 Ssl bool + Protocol [8]byte } /* @@ -46,6 
+47,7 @@ type SocketOpenEvent struct { SrcPort uint16 Padding [2]byte Socket_open_ns uint64 + Protocol [8]byte } type SocketCloseEvent struct { diff --git a/trafficUtil/kafkaUtil/parser.go b/trafficUtil/kafkaUtil/parser.go index 811baead..b2d9183b 100644 --- a/trafficUtil/kafkaUtil/parser.go +++ b/trafficUtil/kafkaUtil/parser.go @@ -10,11 +10,13 @@ import ( "io" "log/slog" "net/http" + "net/url" "os" "strings" "sync" "time" + http2parser "github.com/akto-api-security/gomiddleware/http2parser" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/apiProcessor" trafficpb "github.com/akto-api-security/mirroring-api-logging/trafficUtil/protobuf/traffic_payload" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/trafficMetrics" @@ -35,6 +37,7 @@ type TrafficContext struct { SocketFD uint32 DaemonsetIdentifier string HostName string + Protocol string } // ParsedTraffic holds the parsed HTTP requests and responses with their bodies. @@ -242,6 +245,10 @@ var ( const ONE_MINUTE = 60 +const ( + protocolhttp2 = "HTTP2" +) + func init() { utils.InitVar("DEBUG_MODE", &debugMode) utils.InitVar("OUTPUT_BANDWIDTH_LIMIT", &outputBandwidthLimitPerMin) @@ -465,10 +472,17 @@ func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, ctx TrafficContext shouldPrint := debugMode && strings.Contains(string(receiveBuffer), "x-debug-token") if shouldPrint { - slog.Debug("ParseAndProduce", "receiveBuffer", string(receiveBuffer), "sentBuffer", string(sentBuffer)) + slog.Debug("ParseAndProduce", "receiveBuffer", string(receiveBuffer), "sentBuffer", string(sentBuffer), "protocol", ctx.Protocol) } - parsed := parseHTTPTraffic(receiveBuffer, sentBuffer, shouldPrint) + // Parse based on protocol + var parsed *ParsedTraffic + if ctx.Protocol == protocolhttp2 { + // slog.Debug("Using HTTP/2 parser", "sourceIp", ctx.SourceIP, "destIp", ctx.DestIP) + parsed = parseHTTP2Traffic(receiveBuffer, sentBuffer, ctx, shouldPrint) + } else { + parsed = parseHTTPTraffic(receiveBuffer, 
sentBuffer, shouldPrint) + } if parsed == nil { return } @@ -501,10 +515,10 @@ func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, ctx TrafficContext for i := 0; i < len(requests); i++ { req := &requests[i] resp := &responses[i] - + url := req.URL.String() checkDebugUrlAndPrint(url, req.Host, "URL,host found in ParseAndProduce") - + // Convert headers in a single pass (both protobuf and string map formats) headers := convertHeaders(req, resp, shouldPrint) @@ -590,3 +604,146 @@ func sendMetrics(headers ConvertedHeaders, ctx TrafficContext, outgoingBytes int } } } + +func parseHTTP2Traffic(receiveBuffer []byte, sentBuffer []byte, ctx TrafficContext, shouldPrint bool) *ParsedTraffic { + http2Preface := []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + if len(receiveBuffer) >= len(http2Preface) && bytes.Equal(receiveBuffer[:len(http2Preface)], http2Preface) { + receiveBuffer = receiveBuffer[len(http2Preface):] + if shouldPrint { + slog.Debug("Skipped HTTP/2 connection preface from receive buffer") + } + } + if len(sentBuffer) >= len(http2Preface) && bytes.Equal(sentBuffer[:len(http2Preface)], http2Preface) { + sentBuffer = sentBuffer[len(http2Preface):] + if shouldPrint { + slog.Debug("Skipped HTTP/2 connection preface from sent buffer") + } + } + + opts := http2parser.NewParseOptions( + http2parser.WithBase64Encoding(true), + http2parser.WithWaitForEndStream(true), + http2parser.WithGRPCTrailers(true), + ) + streams := make(map[uint32]*http2parser.HTTP2Stream) + + err := http2parser.ParseHTTP2Frames(receiveBuffer, streams, true, opts) + if err != nil { + slog.Debug("Error parsing HTTP/2 requests", "error", err) + } + + err = http2parser.ParseHTTP2Frames(sentBuffer, streams, false, opts) + if err != nil { + slog.Debug("Error parsing HTTP/2 responses", "error", err) + } + + // Convert HTTP/2 streams to ParsedTraffic + parsed := &ParsedTraffic{ + Requests: []http.Request{}, + RequestBodies: []string{}, + Responses: []http.Response{}, + ResponseBodies: []string{}, 
+ } + + for streamID, stream := range streams { + if !stream.RequestComplete || !stream.ResponseComplete { + if shouldPrint { + slog.Debug("Incomplete HTTP/2 stream", "streamID", streamID, "requestComplete", stream.RequestComplete, "responseComplete", stream.ResponseComplete) + } + continue + } + + req, err := convertHTTP2ToRequest(stream) + if err != nil { + if shouldPrint { + slog.Debug("Failed to convert HTTP/2 stream to request", "streamID", streamID, "error", err) + } + continue + } + + resp := convertHTTP2ToResponse(stream) + + parsed.Requests = append(parsed.Requests, *req) + parsed.RequestBodies = append(parsed.RequestBodies, string(stream.RequestBody)) + parsed.Responses = append(parsed.Responses, *resp) + parsed.ResponseBodies = append(parsed.ResponseBodies, string(stream.ResponseBody)) + } + + if shouldPrint { + slog.Debug("Parsed HTTP/2 traffic", "streamCount", len(parsed.Requests)) + } + + return parsed +} + +func convertHTTP2ToRequest(stream *http2parser.HTTP2Stream) (*http.Request, error) { + method := stream.Method + if method == "" { + method = "GET" + } + + path := stream.Path + if path == "" { + path = "/" + } + + scheme := stream.RequestHeaders[":scheme"] + if scheme == "" { + scheme = "https" + } + + authority := stream.RequestHeaders[":authority"] + if authority == "" { + authority = stream.RequestHeaders["host"] + } + + urlStr := scheme + "://" + authority + path + parsedURL, err := url.Parse(urlStr) + if err != nil { + return nil, fmt.Errorf("failed to parse URL %s: %w", urlStr, err) + } + + header := make(http.Header) + for name, value := range stream.RequestHeaders { + if !strings.HasPrefix(name, ":") { + header.Set(name, value) + } + } + if authority != "" { + header.Set("Host", authority) + } + + proto := "HTTP/2.0" + if stream.IsGRPC { + proto = "gRPC" + } + + return &http.Request{ + Method: method, + URL: parsedURL, + Proto: proto, + Header: header, + Host: authority, + }, nil +} + +func convertHTTP2ToResponse(stream 
*http2parser.HTTP2Stream) *http.Response { + header := make(http.Header) + for name, value := range stream.ResponseHeaders { + if !strings.HasPrefix(name, ":") { + header.Set(name, value) + } + } + + proto := "HTTP/2.0" + if stream.IsGRPC { + proto = "gRPC" + } + + return &http.Response{ + StatusCode: stream.StatusCode, + Status: stream.Status, + Proto: proto, + Header: header, + } +}