diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bf4df63e..99a8de03 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,7 +36,16 @@ on: options: - legacy - ebpf - default: legacy + default: legacy + Architecture: + description: "The target architecture(s) for the Docker image." + required: true + type: choice + options: + - both + - arm64 + - amd64 + default: both # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: @@ -82,11 +91,19 @@ jobs: ECR_REPOSITORY: akto-api-security REGISTRY_ALIAS: p7q3h0z2 IMAGE_TAG: ${{ github.event.inputs.Tag }} + ARCH_INPUT: ${{ github.event.inputs.Architecture }} run: | # Build a docker container and push it to DockerHub docker buildx create --use - echo "Building and Pushing image to ECR..." - docker buildx build --platform linux/arm64/v8,linux/amd64 -t $ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG . --push + if [ "$ARCH_INPUT" == "arm64" ]; then + PLATFORM="linux/arm64/v8" + elif [ "$ARCH_INPUT" == "amd64" ]; then + PLATFORM="linux/amd64" + else + PLATFORM="linux/arm64/v8,linux/amd64" + fi + echo "Building and Pushing image to ECR with platform: $PLATFORM" + docker buildx build --platform $PLATFORM -t $ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG . --push echo "::set-output name=image::$ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG" - name: Build, tag, and push the image to Amazon ECR -ebpf @@ -97,11 +114,19 @@ jobs: ECR_REPOSITORY: akto-api-security REGISTRY_ALIAS: p7q3h0z2 IMAGE_TAG: ${{ github.event.inputs.EbpfTag }} + ARCH_INPUT: ${{ github.event.inputs.Architecture }} run: | # Build a docker container and push it to DockerHub docker buildx create --use - echo "Building and Pushing image to ECR..." - docker buildx build --platform linux/arm64/v8,linux/amd64 -t $ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG -f Dockerfile.eBPF . 
--push + if [ "$ARCH_INPUT" == "arm64" ]; then + PLATFORM="linux/arm64/v8" + elif [ "$ARCH_INPUT" == "amd64" ]; then + PLATFORM="linux/amd64" + else + PLATFORM="linux/arm64/v8,linux/amd64" + fi + echo "Building and Pushing image to ECR with platform: $PLATFORM" + docker buildx build --platform $PLATFORM -t $ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG -f Dockerfile.eBPF . --push echo "::set-output name=image::$ECR_REGISTRY/$REGISTRY_ALIAS/mirror-api-logging:$IMAGE_TAG" build-docker: @@ -136,11 +161,19 @@ jobs: env: ECR_REGISTRY: aktosecurity IMAGE_TAG: ${{ github.event.inputs.Tag }} + ARCH_INPUT: ${{ github.event.inputs.Architecture }} run: | # Build a docker container and push it to DockerHub docker buildx create --use - echo "Building and Pushing image to DockerHub..." - docker buildx build --platform linux/arm64/v8,linux/amd64 -t $ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG . --push + if [ "$ARCH_INPUT" == "arm64" ]; then + PLATFORM="linux/arm64/v8" + elif [ "$ARCH_INPUT" == "amd64" ]; then + PLATFORM="linux/amd64" + else + PLATFORM="linux/arm64/v8,linux/amd64" + fi + echo "Building and Pushing image to DockerHub with platform: $PLATFORM" + docker buildx build --platform $PLATFORM -t $ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG . --push echo "::set-output name=image::$ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG" - name: Build, tag, and push the image to DockerHub - ebpf @@ -149,9 +182,17 @@ jobs: env: ECR_REGISTRY: aktosecurity IMAGE_TAG: ${{ github.event.inputs.EbpfTag }} + ARCH_INPUT: ${{ github.event.inputs.Architecture }} run: | # Build a docker container and push it to DockerHub docker buildx create --use - echo "Building and Pushing image to DockerHub..." - docker buildx build --platform linux/arm64/v8,linux/amd64 -t $ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG -f Dockerfile.eBPF . 
--push + if [ "$ARCH_INPUT" == "arm64" ]; then + PLATFORM="linux/arm64/v8" + elif [ "$ARCH_INPUT" == "amd64" ]; then + PLATFORM="linux/amd64" + else + PLATFORM="linux/arm64/v8,linux/amd64" + fi + echo "Building and Pushing image to DockerHub with platform: $PLATFORM" + docker buildx build --platform $PLATFORM -t $ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG -f Dockerfile.eBPF . --push echo "::set-output name=image::$ECR_REGISTRY/mirror-api-logging:$IMAGE_TAG" diff --git a/.gitignore b/.gitignore index 7c1c098f..b872b0c3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ mirroring-api-logging .idea/ **/.vscode/ temp -**temp \ No newline at end of file +**temp +data-* \ No newline at end of file diff --git a/Dockerfile.eBPF b/Dockerfile.eBPF index 96884c74..29d329d1 100644 --- a/Dockerfile.eBPF +++ b/Dockerfile.eBPF @@ -1,4 +1,4 @@ -FROM alpine:3.21 AS base +FROM alpine:3.22 AS base USER root RUN apk add bcc-tools bcc-dev bcc-doc linux-headers build-base @@ -6,7 +6,7 @@ RUN apk add bcc-tools bcc-dev bcc-doc linux-headers build-base FROM base AS builder # Install Go based on architecture -ARG GO_VERSION=1.24.3 +ARG GO_VERSION=1.25.8 ARG TARGETARCH RUN if [ "$TARGETARCH" = "arm64" ]; then \ diff --git a/ebpf-run.sh b/ebpf-run.sh index a85566c8..f549bf68 100644 --- a/ebpf-run.sh +++ b/ebpf-run.sh @@ -1,8 +1,13 @@ #!/bin/sh -LOG_FILE="/tmp/dump.log" +LOG_FILE=${LOG_FILE:-/tmp/dump.log} MAX_LOG_SIZE=${MAX_LOG_SIZE:-10485760} # Default to 10 MB if not set (10 MB = 10 * 1024 * 1024 bytes) -CHECK_INTERVAL=60 # Check interval in seconds +CHECK_INTERVAL=${CHECK_INTERVAL:-60} +CHECK_INTERVAL_MEM=${CHECK_INTERVAL_MEM:-5} # Check interval in seconds (configurable via env) +MEMORY_THRESHOLD=${MEMORY_THRESHOLD:-85} # Kill process at this % memory usage (configurable via env) +GOMEMLIMIT_PERCENT=${GOMEMLIMIT_PERCENT:-60} # GOMEMLIMIT as % of container memory limit (configurable via env) +AKTO_SUPPRESS_TRACE=${AKTO_SUPPRESS_TRACE:-true} 
+CRASH_RESTART_BACKOFF_SECONDS=${CRASH_RESTART_BACKOFF_SECONDS:-10} # Function to rotate the log file rotate_log() { @@ -14,6 +19,39 @@ rotate_log() { fi } +# Function to check memory usage and kill process if threshold exceeded +check_memory_and_kill() { + # Resolve container's cgroup path (needed when hostPID: true shifts cgroup root) + CGROUP_BASE=$(cut -d: -f3 /proc/self/cgroup | head -1) + + # Get current memory usage in bytes + if [ -f "/sys/fs/cgroup${CGROUP_BASE}/memory.current" ]; then + # cgroup v2 with hostPID + CURRENT_MEM=$(cat "/sys/fs/cgroup${CGROUP_BASE}/memory.current") + elif [ -f /sys/fs/cgroup/memory.current ]; then + # cgroup v2 normal + CURRENT_MEM=$(cat /sys/fs/cgroup/memory.current) + elif [ -f "/sys/fs/cgroup${CGROUP_BASE}/memory.usage_in_bytes" ]; then + # cgroup v1 with hostPID + CURRENT_MEM=$(cat "/sys/fs/cgroup${CGROUP_BASE}/memory.usage_in_bytes") + elif [ -f /sys/fs/cgroup/memory/memory.usage_in_bytes ]; then + # cgroup v1 normal + CURRENT_MEM=$(cat /sys/fs/cgroup/memory/memory.usage_in_bytes) + else + return + fi + + # Calculate percentage used + PERCENT_USED=$((CURRENT_MEM * 100 / MEM_LIMIT_BYTES)) + + echo "Memory usage: ${PERCENT_USED}% (${CURRENT_MEM} / ${MEM_LIMIT_BYTES} bytes)" + + if [ "$PERCENT_USED" -ge "$MEMORY_THRESHOLD" ]; then + echo "Memory threshold ${MEMORY_THRESHOLD}% exceeded (${PERCENT_USED}%), killing ebpf-logging process" + pkill -9 ebpf-logging + fi +} + # Start monitoring in the background if [[ "${ENABLE_LOGS}" == "false" ]]; then while true; do @@ -22,12 +60,128 @@ if [[ "${ENABLE_LOGS}" == "false" ]]; then done & fi -while : -do -if [[ "${ENABLE_LOGS}" == "false" ]]; then - ./ebpf-logging >> "$LOG_FILE" 2>&1 +# 1. 
Check if MEM_LIMIT is provided as env variable +if [ -z "$MEM_LIMIT" ]; then + # Resolve container's cgroup path (needed when hostPID: true shifts cgroup root) + CGROUP_BASE=$(cut -d: -f3 /proc/self/cgroup | head -1) + + # Not provided, detect and read cgroup memory limits + if [ -f "/sys/fs/cgroup${CGROUP_BASE}/memory.max" ]; then + # cgroup v2 with hostPID + MEM_LIMIT_BYTES=$(cat "/sys/fs/cgroup${CGROUP_BASE}/memory.max") + elif [ -f /sys/fs/cgroup/memory.max ]; then + # cgroup v2 normal + MEM_LIMIT_BYTES=$(cat /sys/fs/cgroup/memory.max) + elif [ -f "/sys/fs/cgroup${CGROUP_BASE}/memory.limit_in_bytes" ]; then + # cgroup v1 with hostPID + MEM_LIMIT_BYTES=$(cat "/sys/fs/cgroup${CGROUP_BASE}/memory.limit_in_bytes") + elif [ -f /sys/fs/cgroup/memory/memory.limit_in_bytes ]; then + # cgroup v1 normal + MEM_LIMIT_BYTES=$(cat /sys/fs/cgroup/memory/memory.limit_in_bytes) + else + # Fallback to free -b (bytes) if cgroup file not found + echo "Neither cgroup v2 nor v1 memory file found, defaulting to free -b" + MEM_LIMIT_BYTES=$(free -b | awk '/Mem:/ {print $2}') + fi + + # 2. Handle edge cases: "max" (cgroup v2) or 9223372036854775807 (cgroup v1 INT64_MAX) mean no limit + if [ "$MEM_LIMIT_BYTES" = "max" ] || [ "$MEM_LIMIT_BYTES" = "9223372036854775807" ]; then + echo "Cgroup memory limit is unlimited, defaulting to free memory" + MEM_LIMIT_BYTES=$(free -b | awk '/Mem:/ {print $2}') + fi + + # 3. 
Convert the memory limit from bytes to MB (integer division) + MEM_LIMIT_MB=$((MEM_LIMIT_BYTES / 1024 / 1024)) else - ./ebpf-logging + # MEM_LIMIT provided as env variable, treat as MB + echo "Using MEM_LIMIT from environment variable: ${MEM_LIMIT} MB" + MEM_LIMIT_MB=$MEM_LIMIT + # Convert MB to bytes for calculations + MEM_LIMIT_BYTES=$((MEM_LIMIT * 1024 * 1024)) fi - sleep 2 + +echo "Using container memory limit: ${MEM_LIMIT_MB} MB" + +# Set GOMEMLIMIT for the Go process +GOMEMLIMIT_MB=$((MEM_LIMIT_MB * GOMEMLIMIT_PERCENT / 100)) +export GOMEMLIMIT="${GOMEMLIMIT_MB}MiB" +echo "Setting GOMEMLIMIT to: ${GOMEMLIMIT} (${GOMEMLIMIT_PERCENT}% of ${MEM_LIMIT_MB} MB)" + +# ENABLE_LOGS (same intent as always): +# false -> append ebpf stdout+stderr to LOG_FILE (2>&1), not primary container streams. +# true (or anything else) -> ebpf inherits container stdout/stderr (kubectl logs). +# AKTO_SUPPRESS_TRACE=true -> optional stderr-only SIGSEGV/cgo filter; when off, file mode matches legacy exactly. +run_ebpf_once() { + log_to_file=false + [[ "${ENABLE_LOGS}" == "false" ]] && log_to_file=true + + # Legacy path: single merged stream, no FIFO. + if [ "${AKTO_SUPPRESS_TRACE}" != "true" ]; then + if [ "$log_to_file" = "true" ]; then + ./ebpf-logging >> "$LOG_FILE" 2>&1 + else + ./ebpf-logging + fi + return $? + fi + + # Filter path: stderr only through awk; stdout unchanged. FIFO connects ebpf stderr -> awk reader. + ERRPIPE="/tmp/ebpf-stderr-$$" + rm -f "$ERRPIPE" + if ! 
mkfifo "$ERRPIPE"; then + return 1 + fi + + to_logfile=0 + [ "$log_to_file" = "true" ] && to_logfile=1 + + awk -v to_logfile="$to_logfile" -v logf="$LOG_FILE" ' + BEGIN { quiet = 0 } + /^SIGSEGV:/ || /^signal arrived during cgo execution/ { + if (!quiet) { + msg = "SIGSEGV/cgo crash (multi-line trace suppressed; set AKTO_SUPPRESS_TRACE=false for full output)" + if (to_logfile) print msg >> logf + else print msg > "/dev/stderr" + } + quiet = 1 + next + } + quiet { next } + { + if (to_logfile) print >> logf + else print > "/dev/stderr" + } + ' < "$ERRPIPE" & + AWKPID=$! + + if [ "$log_to_file" = "true" ]; then + ./ebpf-logging >> "$LOG_FILE" 2>"$ERRPIPE" + else + ./ebpf-logging 2>"$ERRPIPE" + fi + ebpf_exit=$? + wait "$AWKPID" 2>/dev/null + rm -f "$ERRPIPE" + return "$ebpf_exit" +} + +# Start memory monitoring in the background +while true; do + check_memory_and_kill + sleep "$CHECK_INTERVAL_MEM" +done & + +while : +do + # Source environment file if it exists (contains vars set by processCommandMessage) + if [ -f /ebpf/.env ]; then + set -a + source /ebpf/.env + set +a + fi + + run_ebpf_once + ebpf_exit=$? 
+ + sleep "${CRASH_RESTART_BACKOFF_SECONDS}" done diff --git a/ebpf/bpfwrapper/eventCallbacks.go b/ebpf/bpfwrapper/eventCallbacks.go index 67b04dfb..22082b28 100644 --- a/ebpf/bpfwrapper/eventCallbacks.go +++ b/ebpf/bpfwrapper/eventCallbacks.go @@ -178,6 +178,7 @@ func SocketDataEventCallback(inputChan chan []byte, connectionFactory *connectio "data", dataStr, "rc", event.Attr.ReadEventsCount, "wc", event.Attr.WriteEventsCount, - "ssl", event.Attr.Ssl) + "ssl", event.Attr.Ssl, + "bytesSent", bytesSent) } } diff --git a/ebpf/connections/factory.go b/ebpf/connections/factory.go index ac762cd8..19718f68 100644 --- a/ebpf/connections/factory.go +++ b/ebpf/connections/factory.go @@ -18,6 +18,12 @@ import ( var httpBytes = []byte("HTTP") +var sequenceCheckSkip = false + +func init() { + utils.InitVar("AKTO_SKIP_SEQUENCE_CHECK", &sequenceCheckSkip) +} + // Factory is a routine-safe container that holds a trackers with unique ID, and able to create new tracker. type Factory struct { processor map[structs.ConnID]chan interface{} @@ -52,7 +58,12 @@ func convertToSingleByteArr(bufMap map[int][]byte) []byte { kPrev := -1 for _, k := range keys { if kPrev == -1 { - if k != 1 { + // C sets read, write event count=0 only on new connection open + // For requests arriving after a time gap on the same underlying connection the + // read,write count will not be 1, they will simply continue from the last request + // This can only be replicated when there is a time gap/inactivityThreshold between requests + // on the same underlying connection + if !sequenceCheckSkip && k != 1 { utils.LogProcessing("Bad start sequence", "key", k, "value", string(bufMap[k])) break } @@ -73,13 +84,15 @@ func convertToSingleByteArr(bufMap map[int][]byte) []byte { var ( disableEgress = false maxActiveConnections = 4096 - inactivityThreshold = 3 * time.Second + inactivityThreshold = 7 * time.Second // Value in MB bufferMemThreshold = 400 // unique id of daemonset uniqueDaemonsetId = uuid.New().String() 
trackerDataProcessInterval = 100 + + socketDataEventBytesThreshold = 10 * 1024 * 1024 ) func init() { @@ -89,6 +102,7 @@ func init() { utils.InitVar("TRAFFIC_BUFFER_THRESHOLD", &bufferMemThreshold) utils.InitVar("AKTO_MEM_SOFT_LIMIT", &bufferMemThreshold) utils.InitVar("TRACKER_DATA_PROCESS_INTERVAL", &trackerDataProcessInterval) + utils.InitVar("SOCKET_DATA_EVENT_BYTES_THRESHOLD", &socketDataEventBytesThreshold) } func ProcessTrackerData(connID structs.ConnID, tracker *Tracker, isComplete bool) { @@ -201,6 +215,26 @@ func (factory *Factory) CreateIfNotExists(connectionID structs.ConnID) { } } +// resetTimer stops, drains, and resets the timer to the given duration. +func resetTimer(t *time.Timer, d time.Duration) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) +} + +// Worker lifecycle: +// +// ACTIVE: +// - socket data/open -> reset inactivity timer on each event +// - socket close -> schedule delayed termination +// - inactivity timer -> terminate immediately +// +// TERMINATION is final and happens exactly once. 
+// either due to inactivityThreshold or due to socker close event func (factory *Factory) StartWorker(connectionID structs.ConnID, tracker *Tracker, ch chan interface{}) { go func(connID structs.ConnID, tracker *Tracker, ch chan interface{}) { @@ -216,9 +250,17 @@ func (factory *Factory) StartWorker(connectionID structs.ConnID, tracker *Tracke case *structs.SocketDataEvent: utils.LogProcessing("Received data event", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) tracker.AddDataEvent(*e) + if tracker.GetSentBytes()+tracker.GetRecvBytes() > uint64(socketDataEventBytesThreshold) { + utils.LogProcessing("Socket Data threshold data breached, processing current data", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) + factory.StopProcessing(connID) + return + } else { + resetTimer(inactivityTimer, inactivityThreshold) + } case *structs.SocketOpenEvent: utils.LogProcessing("Received open event", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) tracker.AddOpenEvent(*e) + resetTimer(inactivityTimer, inactivityThreshold) case *structs.SocketCloseEvent: utils.LogProcessing("Received close event", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) tracker.AddCloseEvent(*e) @@ -230,15 +272,13 @@ func (factory *Factory) StartWorker(connectionID structs.ConnID, tracker *Tracke case <-delayedDeleteChan: utils.LogProcessing("Stopping go routine (delayed close)", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) - factory.ProcessAndStopWorker(connID) - factory.DeleteWorker(connID) + factory.StopProcessing(connID) return case <-inactivityTimer.C: // Eat the go routine after inactive threshold, process the tracker and stop the worker utils.LogProcessing("Inactivity threshold reached, marking 
connection as inactive and processing", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) - factory.ProcessAndStopWorker(connID) - factory.DeleteWorker(connID) + factory.StopProcessing(connID) utils.LogProcessing("Stopping go routine", "fd", connID.Fd, "id", connID.Id, "timestamp", connID.Conn_start_ns, "ip", connID.Ip, "port", connID.Port) return } @@ -246,6 +286,11 @@ func (factory *Factory) StartWorker(connectionID structs.ConnID, tracker *Tracke }(connectionID, tracker, ch) } +func (factory *Factory) StopProcessing(connID structs.ConnID) { + factory.ProcessAndStopWorker(connID) + factory.DeleteWorker(connID) +} + func (factory *Factory) ProcessAndStopWorker(connectionID structs.ConnID) { tracker, connExists := factory.getTracker(connectionID) if connExists { diff --git a/ebpf/connections/parser.go b/ebpf/connections/parser.go index b8046d1a..65b4a69d 100644 --- a/ebpf/connections/parser.go +++ b/ebpf/connections/parser.go @@ -5,5 +5,18 @@ import ( ) func tryReadFromBD(ip string, destIp string, receiveBuffer []byte, sentBuffer []byte, isComplete bool, direction int, id uint64, fd uint32, daemonsetIdentifier, hostName string) { - kafkaUtil.ParseAndProduce(receiveBuffer, sentBuffer, ip, destIp, 0, false, "MIRRORING", isComplete, direction, id, fd, daemonsetIdentifier, hostName) + ctx := kafkaUtil.TrafficContext{ + SourceIP: ip, + DestIP: destIp, + VxlanID: 0, + IsPending: false, + TrafficSource: "MIRRORING", + IsComplete: isComplete, + Direction: direction, + ProcessID: uint32(id >> 32), + SocketFD: fd, + DaemonsetIdentifier: daemonsetIdentifier, + HostName: hostName, + } + kafkaUtil.ParseAndProduce(receiveBuffer, sentBuffer, ctx) } diff --git a/ebpf/connections/tracker.go b/ebpf/connections/tracker.go index 0d8d8acb..3056d0be 100644 --- a/ebpf/connections/tracker.go +++ b/ebpf/connections/tracker.go @@ -108,3 +108,11 @@ func (conn *Tracker) AddCloseEvent(event structs.SocketCloseEvent) { conn.closeTimestamp 
= uint64(time.Now().UnixNano()) conn.lastAccessTimestamp = uint64(time.Now().UnixNano()) } + +func (conn *Tracker) GetSentBytes() uint64 { + return conn.sentBytes +} + +func (conn *Tracker) GetRecvBytes() uint64 { + return conn.recvBytes +} \ No newline at end of file diff --git a/ebpf/conntrack/conntrack.go b/ebpf/conntrack/conntrack.go new file mode 100644 index 00000000..f536f801 --- /dev/null +++ b/ebpf/conntrack/conntrack.go @@ -0,0 +1,365 @@ +package conntrack + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/akto-api-security/mirroring-api-logging/ebpf/structs" + "github.com/iovisor/gobpf/bcc" + "github.com/akto-api-security/mirroring-api-logging/trafficUtil/utils" +) + +var enableConnPrefill = false + +func init() { + utils.InitVar("AKTO_CONN_PREFILL", &enableConnPrefill) +} + +// SocketInfo holds parsed socket information from /proc/net/tcp +type SocketInfo struct { + Inode uint64 + LocalIP uint32 + LocalPort uint16 + RemoteIP uint32 + RemotePort uint16 +} + +// FdInfo holds file descriptor information +type FdInfo struct { + Fd uint32 + Inode uint64 +} + +// ConnectionInfo combines FD and socket information for a connection +type ConnectionInfo struct { + Fd uint32 + RemoteIP uint32 + RemotePort uint16 + LocalIP uint32 + LocalPort uint16 +} + +// GenTgidFd generates the map key matching the C function gen_tgid_fd +func GenTgidFd(tgid uint32, fd uint32) uint64 { + return (uint64(tgid) << 32) | uint64(fd) +} + +// SerializeConnInfo serializes ConnInfoT to bytes matching BPF map layout +func SerializeConnInfo(info *structs.ConnInfoT) ([]byte, error) { + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.LittleEndian, info) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// ParseProcNetTcp parses /proc//net/tcp and /proc//net/tcp6 +// Returns a map of inode -> SocketInfo +func ParseProcNetTcp(pid uint32) (map[uint64]*SocketInfo, error) { + 
sockets := make(map[uint64]*SocketInfo) + + // Parse both tcp and tcp6 + for _, proto := range []string{"tcp", "tcp6"} { + path := fmt.Sprintf("/proc/%d/net/%s", pid, proto) + file, err := os.Open(path) + if err != nil { + // Skip if file doesn't exist + continue + } + defer file.Close() + + scanner := bufio.NewScanner(file) + // Skip header line + if scanner.Scan() { + // header skipped + } + + for scanner.Scan() { + line := scanner.Text() + info, err := parseTcpLine(line, proto == "tcp6") + if err != nil { + continue + } + sockets[info.Inode] = info + } + } + + return sockets, nil +} + +// parseTcpLine parses a single line from /proc/net/tcp or tcp6 +// Format: sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode +func parseTcpLine(line string, isIPv6 bool) (*SocketInfo, error) { + fields := strings.Fields(line) + if len(fields) < 10 { + return nil, fmt.Errorf("invalid line format") + } + + // Parse local address (field 1) + localIP, localPort, err := parseAddressPort(fields[1], isIPv6) + if err != nil { + return nil, err + } + + // Parse remote address (field 2) + remoteIP, remotePort, err := parseAddressPort(fields[2], isIPv6) + if err != nil { + return nil, err + } + + // Parse inode (field 9) + inode, err := strconv.ParseUint(fields[9], 10, 64) + if err != nil { + return nil, err + } + + return &SocketInfo{ + Inode: inode, + LocalIP: localIP, + LocalPort: localPort, + RemoteIP: remoteIP, + RemotePort: remotePort, + }, nil +} + +// parseAddressPort parses "IP:PORT" hex format from /proc/net/tcp +// For IPv6, extracts the last 32 bits (matching C code behavior for IPv4-mapped addresses) +func parseAddressPort(addrPort string, isIPv6 bool) (uint32, uint16, error) { + parts := strings.Split(addrPort, ":") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid address:port format") + } + + // Parse port (always 16-bit hex) + port, err := strconv.ParseUint(parts[1], 16, 16) + if err != nil { + return 0, 0, err + } + + // Parse 
IP + var ip uint32 + if isIPv6 { + // IPv6 address is 32 hex chars, take last 8 chars (32 bits) + // This matches the C code: conn_info.ip = (in_addr.s6_addr32)[3] + ipHex := parts[0] + if len(ipHex) >= 8 { + ip64, err := strconv.ParseUint(ipHex[len(ipHex)-8:], 16, 32) + if err != nil { + return 0, 0, err + } + ip = uint32(ip64) + } + } else { + // IPv4 address is 8 hex chars, already in little-endian + ip64, err := strconv.ParseUint(parts[0], 16, 32) + if err != nil { + return 0, 0, err + } + ip = uint32(ip64) + } + + return ip, uint16(port), nil +} + +// GetSocketFds returns all socket file descriptors for a process +func GetSocketFds(pid uint32) ([]FdInfo, error) { + fdDir := fmt.Sprintf("/proc/%d/fd", pid) + entries, err := os.ReadDir(fdDir) + if err != nil { + return nil, err + } + + var fds []FdInfo + for _, entry := range entries { + fd, err := strconv.ParseUint(entry.Name(), 10, 32) + if err != nil { + continue + } + + // Read symlink to check if it's a socket + linkPath := filepath.Join(fdDir, entry.Name()) + target, err := os.Readlink(linkPath) + if err != nil { + continue + } + + // Check if it's a socket: socket:[inode] + if strings.HasPrefix(target, "socket:[") && strings.HasSuffix(target, "]") { + inodeStr := target[8 : len(target)-1] + inode, err := strconv.ParseUint(inodeStr, 10, 64) + if err != nil { + continue + } + fds = append(fds, FdInfo{ + Fd: uint32(fd), + Inode: inode, + }) + } + } + + return fds, nil +} + +// EnumerateExistingConnections returns all existing TCP connections for a PID +func EnumerateExistingConnections(pid uint32) ([]ConnectionInfo, error) { + // Get socket info by inode + sockets, err := ParseProcNetTcp(pid) + if err != nil { + return nil, err + } + + // Get FDs and their inodes + fds, err := GetSocketFds(pid) + if err != nil { + return nil, err + } + + // Match FDs to sockets + var connections []ConnectionInfo + for _, fdInfo := range fds { + if sockInfo, ok := sockets[fdInfo.Inode]; ok { + // Skip connections with no 
remote endpoint (listening sockets) + if sockInfo.RemoteIP == 0 && sockInfo.RemotePort == 0 { + continue + } + connections = append(connections, ConnectionInfo{ + Fd: fdInfo.Fd, + RemoteIP: sockInfo.RemoteIP, + RemotePort: sockInfo.RemotePort, + LocalIP: sockInfo.LocalIP, + LocalPort: sockInfo.LocalPort, + }) + } + } + return connections, nil +} + +// PopulateConnInfoWithRotation adds a connection to the BPF maps with rotation logic +func PopulateConnInfoWithRotation( + connInfoTable, connCounterTable, connInfoMapKeysTable *bcc.Table, + tgidFd uint64, + connInfo *structs.ConnInfoT, + maxMapSize int, +) error { + // Read current counter value + counterKey := make([]byte, 4) // int key = 0 + counterBytes, err := connCounterTable.Get(counterKey) + if err != nil { + // Counter might not exist yet, start at 0 + counterBytes = make([]byte, 4) + } + + counter := int32(binary.LittleEndian.Uint32(counterBytes)) + + // Reset if near limit + if counter > int32(maxMapSize-5) { + counter = 0 + } + counter++ + + // Get old key at this index and delete from conn_info_map + indexKey := make([]byte, 4) + binary.LittleEndian.PutUint32(indexKey, uint32(counter)) + + oldKeyBytes, err := connInfoMapKeysTable.Get(indexKey) + if err == nil && len(oldKeyBytes) == 8 { + // Delete old entry + connInfoTable.Delete(oldKeyBytes) + } + + // Write new tgid_fd to keys array + tgidFdBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(tgidFdBytes, tgidFd) + if err := connInfoMapKeysTable.Set(indexKey, tgidFdBytes); err != nil { + return fmt.Errorf("failed to set conn_info_map_keys: %w", err) + } + + // Serialize and write conn_info to map + connInfoBytes, err := SerializeConnInfo(connInfo) + if err != nil { + return fmt.Errorf("failed to serialize conn_info: %w", err) + } + if err := connInfoTable.Set(tgidFdBytes, connInfoBytes); err != nil { + return fmt.Errorf("failed to set conn_info_map: %w", err) + } + + // Update counter + binary.LittleEndian.PutUint32(counterKey, uint32(counter)) + if 
err := connCounterTable.Set([]byte{0, 0, 0, 0}, counterKey); err != nil { + return fmt.Errorf("failed to update conn_counter: %w", err) + } + + return nil +} + +// PopulateExistingConnections enumerates and populates all existing connections for given PIDs +func PopulateExistingConnections( + pids []uint32, + connInfoTable, connCounterTable, connInfoMapKeysTable *bcc.Table, + maxMapSize int, +) { + var totalConnFound, totalConnPopulated int + var enumerationFailures, populationFailures int + + // Don't actually prefill in C maps, finding conn is still done + // to know how many are typically open and therefore data is not captured. + if !enableConnPrefill { + slog.Debug("connection prefill disabled", "enableConnPrefill", enableConnPrefill) + } + + for _, pid := range pids { + connections, err := EnumerateExistingConnections(pid) + if err != nil { + enumerationFailures++ + continue + } + + totalConnFound += len(connections) + + if !enableConnPrefill { + continue + } + + for _, conn := range connections { + tgidFd := GenTgidFd(pid, conn.Fd) + + connInfo := &structs.ConnInfoT{ + Id: (uint64(pid) << 32) | uint64(pid), + Fd: conn.Fd, + ConnStartNs: 0, // Unknown for pre-existing connections + Port: conn.RemotePort, + Ip: conn.RemoteIP, + Ssl: false, + ReadEventsCount: 0, + WriteEventsCount: 0, + } + + err := PopulateConnInfoWithRotation( + connInfoTable, connCounterTable, connInfoMapKeysTable, + tgidFd, connInfo, maxMapSize, + ) + if err != nil { + populationFailures++ + continue + } + + totalConnPopulated++ + } + } + + slog.Info("completed connection prefill", + "pids_processed", len(pids), + "total_connections_found", totalConnFound, + "total_connections_populated", totalConnPopulated, + "enumeration_failures", enumerationFailures, + "population_failures", populationFailures, + ) +} diff --git a/ebpf/go.mod b/ebpf/go.mod index f3162c1e..41b99a47 100644 --- a/ebpf/go.mod +++ b/ebpf/go.mod @@ -1,8 +1,6 @@ module github.com/akto-api-security/mirroring-api-logging/ebpf 
-go 1.24.0 - -toolchain go1.24.3 +go 1.25.8 require ( github.com/akto-api-security/mirroring-api-logging/trafficUtil v0.0.0-00010101000000-000000000000 @@ -12,6 +10,8 @@ require ( ) require ( + github.com/bits-and-blooms/bitset v1.24.2 // indirect + github.com/bits-and-blooms/bloom/v3 v3.7.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/ebpf/go.sum b/ebpf/go.sum index b48ff900..ee3cc7e9 100644 --- a/ebpf/go.sum +++ b/ebpf/go.sum @@ -1,3 +1,7 @@ +github.com/bits-and-blooms/bitset v1.24.2 h1:M7/NzVbsytmtfHbumG+K2bremQPMJuqv1JD3vOaFxp0= +github.com/bits-and-blooms/bitset v1.24.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.7.1 h1:WXovk4TRKZttAMJfoQx6K2DM0zNIt8w+c67UqO+etV0= +github.com/bits-and-blooms/bloom/v3 v3.7.1/go.mod h1:rZzYLLje2dfzXfAkJNxQQHsKurAyK55KUnL43Euk0hU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -97,6 +101,8 @@ github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08 github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= +github.com/twmb/murmur3 v1.1.8 h1:8Yt9taO/WN3l08xErzjeschgZU2QSrwm1kclYq+0aRg= +github.com/twmb/murmur3 v1.1.8/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= diff 
--git a/ebpf/kernel/module.cc b/ebpf/kernel/module.cc index 4cee7779..a17a1251 100644 --- a/ebpf/kernel/module.cc +++ b/ebpf/kernel/module.cc @@ -7,6 +7,7 @@ #define socklen_t size_t #define MAX_MSG_SIZE 30720 +#define CHUNK_LIMIT CHUNK_SIZE_LIMIT #define LOOP_LIMIT 42 #define ARCH_TYPE 1 @@ -370,8 +371,35 @@ static __inline void process_syscall_data(struct pt_regs* ret, const struct data socket_data_event->conn_start_ns = conn_info->conn_start_ns; socket_data_event->port = conn_info->port; socket_data_event->ip = conn_info->ip; - socket_data_event->bytes_sent = is_send ? 1 : -1; socket_data_event->ssl = conn_info->ssl; + + int bytes_sent = 0; + size_t size_to_save = 0; + int i =0; + #pragma unroll + for (i = 0; i < CHUNK_LIMIT; ++i) { + const int bytes_remaining = bytes_exchanged - bytes_sent; + + if (bytes_remaining <= 0) { + break; + } + size_t current_size = (bytes_remaining > MAX_MSG_SIZE && (i != CHUNK_LIMIT - 1)) ? MAX_MSG_SIZE : bytes_remaining; + + size_t current_size_minus_1 = current_size - 1; + asm volatile("" : "+r"(current_size_minus_1) :); + current_size = current_size_minus_1 + 1; + + if (current_size > MAX_MSG_SIZE) { + current_size = MAX_MSG_SIZE; + } + + if (current_size_minus_1 < MAX_MSG_SIZE) { + bpf_probe_read(&socket_data_event->msg, current_size, args->buf + bytes_sent); + size_to_save = current_size; + } else if (current_size_minus_1 < 0x7fffffff) { + bpf_probe_read(&socket_data_event->msg, MAX_MSG_SIZE, args->buf + bytes_sent); + size_to_save = MAX_MSG_SIZE; + } if (is_send){ conn_info->writeEventsCount = (conn_info->writeEventsCount) + 1u; @@ -385,29 +413,18 @@ static __inline void process_syscall_data(struct pt_regs* ret, const struct data if(PRINT_BPF_LOGS){ bpf_trace_printk("pid: %d conn-id:%d, fd: %d", id, conn_info->id, conn_info->fd); + bpf_trace_printk("current_size: %d i:%d, bytes_exchanged: %d", current_size, i, bytes_exchanged); unsigned long tdfd = ((id & 0xffff) << 32) + conn_info->fd; bpf_trace_printk("rwc: %d tdfd: %llu 
data: %s", (socket_data_event->readEventsCount*10000 + socket_data_event->writeEventsCount%10000),tgid_fd, socket_data_event->msg); } - size_t bytes_exchanged_minus_1 = bytes_exchanged - 1; - asm volatile("" : "+r"(bytes_exchanged_minus_1) :); - bytes_exchanged = bytes_exchanged_minus_1 + 1; - - size_t size_to_save = 0; - if (bytes_exchanged_minus_1 < MAX_MSG_SIZE) { - bpf_probe_read(&socket_data_event->msg, bytes_exchanged, args->buf); - size_to_save = bytes_exchanged; - socket_data_event->msg[size_to_save] = '\\0'; - } else if (bytes_exchanged_minus_1 < 0x7fffffff) { - bpf_probe_read(&socket_data_event->msg, MAX_MSG_SIZE, args->buf); - size_to_save = MAX_MSG_SIZE; - } - - + socket_data_event->bytes_sent = is_send ? 1 : -1; socket_data_event->bytes_sent *= size_to_save; - socket_data_events.perf_submit(ret, socket_data_event, sizeof(struct socket_data_event_t) - MAX_MSG_SIZE + size_to_save); + bytes_sent += current_size; + } + } static __inline void process_syscall_data_vecs(struct pt_regs* ret, struct data_args_t* args, u64 id, bool is_send){ diff --git a/ebpf/main.go b/ebpf/main.go index 7fbf6d67..cbaeef6c 100644 --- a/ebpf/main.go +++ b/ebpf/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/binary" "fmt" "log/slog" "os" @@ -13,7 +14,6 @@ import ( "strings" "syscall" "time" - // need an unreleased version of the gobpf library, using from a specific branch, reasoning in the thread below. 
// https://stackoverflow.com/questions/73714654/not-enough-arguments-in-call-to-c2func-bcc-func-load @@ -21,6 +21,7 @@ import ( "github.com/akto-api-security/mirroring-api-logging/ebpf/bpfwrapper" "github.com/akto-api-security/mirroring-api-logging/ebpf/connections" + "github.com/akto-api-security/mirroring-api-logging/ebpf/conntrack" "github.com/akto-api-security/mirroring-api-logging/ebpf/uprobeBuilder/process" "github.com/akto-api-security/mirroring-api-logging/ebpf/uprobeBuilder/ssl" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/apiProcessor" @@ -32,6 +33,12 @@ import ( var source string = "" +func replaceBpfChunkSizeMacros() { + chunkSizeLimit := 4 + trafficUtils.InitVar("BPF_CHUNK_SIZE_LIMIT", &chunkSizeLimit) + source = strings.Replace(source, "CHUNK_SIZE_LIMIT", strconv.Itoa(chunkSizeLimit), -1) +} + func replaceBpfLogsMacros() { printBpfLogsEnv := os.Getenv("PRINT_BPF_LOGS") @@ -80,10 +87,15 @@ func main() { // Setting GC percent as 50, uses less memory overhead. // More testing needed for final release. 
// debug.SetGCPercent(50) + run() } func run() { + slog.Debug("Go version", "version", runtime.Version()) + slog.Debug("runtime.NumCPU()", "count", runtime.NumCPU()) + slog.Debug("runtime.GOMAXPROCS(0)", "procs", runtime.GOMAXPROCS(0)) + byteString, err := os.ReadFile("./kernel/module.cc") if err != nil { slog.Error("failed to read kernel module", "error", err) @@ -92,6 +104,7 @@ func run() { source = string(byteString) replaceBpfLogsMacros() + replaceBpfChunkSizeMacros() replaceMaxConnectionMapSize() replaceArchType() @@ -111,11 +124,18 @@ func run() { apiProcessor.InitCloudTrafficProcessor() kafkaUtil.InitKafka() + kafkaUtil.StartConfigConsumer() + stopCh, err := kafkaUtil.SetupPodInformer() if err != nil { slog.Error("Failed to setup pod watcher", "error", err) } + if kafkaUtil.PodInformerInstance != nil { + kubePids := kafkaUtil.PodInformerInstance.GetAllKubePids() + fillExistingConnections(bpfModule, kubePids) + } + connectionFactory := connections.NewFactory() trafficMetrics.InitTrafficMaps() @@ -202,7 +222,7 @@ func run() { trafficUtils.InitVar("AKTO_DEBUG_MEM_PROFILING", &doProfiling) if doProfiling { - ticker := time.NewTicker(time.Minute) // Create a ticker to trigger every minute + ticker := time.NewTicker(30 * time.Second) // Create a ticker to trigger every 30 seconds defer ticker.Stop() for range ticker.C { @@ -226,15 +246,70 @@ func run() { slog.Info("Stopping pod watcher") close(stopCh) } - + slog.Info("signaled to terminate") } +func fillExistingConnections(bpfModule *bcc.Module, tracedPids []uint32) { + connInfoTable := bcc.NewTable(bpfModule.TableId("conn_info_map"), bpfModule) + connCounterTable := bcc.NewTable(bpfModule.TableId("conn_counter"), bpfModule) + connInfoMapKeysTable := bcc.NewTable(bpfModule.TableId("conn_info_map_keys"), bpfModule) + + maxConnectionSizeMapSize := 131072 + trafficUtils.InitVar("TRAFFIC_MAX_CONNECTION_MAP_SIZE", &maxConnectionSizeMapSize) + + slog.Info("populating pre-existing connections", "pids", tracedPids) + 
conntrack.PopulateExistingConnections( + tracedPids, + connInfoTable, + connCounterTable, + connInfoMapKeysTable, + maxConnectionSizeMapSize, + ) +} + +// Use this when specific pids tracing is required. +func setupTracePids(bpfModule *bcc.Module) []uint32 { + kubePidsTable := bcc.NewTable(bpfModule.TableId("kubernetes_pids"), bpfModule) + var tracedPids []uint32 + if tracePids := os.Getenv("TRACE_PIDS"); tracePids != "" { + for _, pidStr := range strings.Split(tracePids, ",") { + pidStr = strings.TrimSpace(pidStr) + if pidStr == "" { + continue + } + pid, err := strconv.ParseUint(pidStr, 10, 32) + if err != nil { + slog.Error("invalid pid in TRACE_PIDS", "pid", pidStr, "error", err) + continue + } + var pidKey [4]byte + binary.LittleEndian.PutUint32(pidKey[:], uint32(pid)) + if err := kubePidsTable.Set(pidKey[:], []byte{1}); err != nil { + slog.Error("failed to add pid to kubernetes_pids map", "pid", pid, "error", err) + } else { + slog.Info("added pid to kubernetes_pids map", "pid", pid) + tracedPids = append(tracedPids, uint32(pid)) + } + } + } else { + slog.Warn("TRACE_PIDS env variable not set, no PIDs will be traced") + } + return tracedPids +} + func captureMemoryProfile() { - f, _ := os.Create("mem.prof") // Create memory profile file + timestamp := time.Now().Format("20060102_150405") + fileName := fmt.Sprintf("mem_%s.prof", timestamp) + f, err := os.Create(fileName) + if err != nil { + slog.Error("failed to create memory profile", "error", err) + return + } defer f.Close() - pprof.WriteHeapProfile(f) // Write memory profile + pprof.WriteHeapProfile(f) + slog.Info("memory profile captured", "filename", fileName) } func captureCpuProfile() { diff --git a/ebpf/structs/structs.go b/ebpf/structs/structs.go index 3a88bd09..14407e44 100644 --- a/ebpf/structs/structs.go +++ b/ebpf/structs/structs.go @@ -52,3 +52,28 @@ type SocketCloseEvent struct { ConnId ConnID Socket_open_ns uint64 } + +// ConnInfoT matches the C struct conn_info_t layout for BPF map population 
+// C struct: +// +// u64 id; +// u32 fd; +// u64 conn_start_ns; +// unsigned short port; +// u32 ip; +// bool ssl; +// u32 readEventsCount; +// u32 writeEventsCount; +type ConnInfoT struct { + Id uint64 + Fd uint32 + Padding1 [4]byte // alignment padding for conn_start_ns + ConnStartNs uint64 + Port uint16 + Padding2 [2]byte // alignment padding for ip + Ip uint32 + Ssl bool + Padding3 [3]byte // alignment padding for readEventsCount + ReadEventsCount uint32 + WriteEventsCount uint32 +} \ No newline at end of file diff --git a/ebpf/uprobeBuilder/process/processFactory.go b/ebpf/uprobeBuilder/process/processFactory.go index c83d73bb..39b26286 100644 --- a/ebpf/uprobeBuilder/process/processFactory.go +++ b/ebpf/uprobeBuilder/process/processFactory.go @@ -118,6 +118,10 @@ func (processFactory *ProcessFactory) AddNewProcessesToProbe(bpfModule *bcc.Modu // openssl probes here are being attached on dynamically linked SSL libraries only. attached, err := ssl.TryOpensslProbes(libraries, bpfModule) + if len(containers) == 0 { + containers = append(containers, "unknown") + } + if attached { p := Process{ pid: pid, @@ -158,9 +162,7 @@ func (processFactory *ProcessFactory) AddNewProcessesToProbe(bpfModule *bcc.Modu } else if err != nil { slog.Error("Node probing error", "pid", pid, "error", err) } - processFactory.unattachedProcess[pid] = true - } } } diff --git a/main.go b/main.go index aff7d276..ebc68d17 100644 --- a/main.go +++ b/main.go @@ -158,10 +158,20 @@ func (s *myStream) ReassemblyComplete() { } func tryReadFromBD(bd *bidi, isPending bool) { - - kafkaUtil.ParseAndProduce(bd.a.bytes, bd.b.bytes, - bd.key.net.Src().String(), bd.key.net.Dst().String(), bd.vxlanID, isPending, bd.source, true, 1, 0, 0, "0") - + ctx := kafkaUtil.TrafficContext{ + SourceIP: bd.key.net.Src().String(), + DestIP: bd.key.net.Dst().String(), + VxlanID: bd.vxlanID, + IsPending: isPending, + TrafficSource: bd.source, + IsComplete: true, + Direction: 1, + ProcessID: 0, + SocketFD: 0, + 
DaemonsetIdentifier: "0", + HostName: "", + } + kafkaUtil.ParseAndProduce(bd.a.bytes, bd.b.bytes, ctx) } // maybeFinish will wait until both directions are complete, then print out diff --git a/trafficUtil/go.mod b/trafficUtil/go.mod index 7debcb5e..33c88e08 100644 --- a/trafficUtil/go.mod +++ b/trafficUtil/go.mod @@ -13,7 +13,10 @@ require ( k8s.io/client-go v0.33.0 ) +require github.com/bits-and-blooms/bitset v1.24.2 // indirect + require ( + github.com/bits-and-blooms/bloom/v3 v3.7.1 github.com/davecgh/go-spew v1.1.1 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/trafficUtil/go.sum b/trafficUtil/go.sum index 4489cb35..60ac7995 100644 --- a/trafficUtil/go.sum +++ b/trafficUtil/go.sum @@ -1,3 +1,7 @@ +github.com/bits-and-blooms/bitset v1.24.2 h1:M7/NzVbsytmtfHbumG+K2bremQPMJuqv1JD3vOaFxp0= +github.com/bits-and-blooms/bitset v1.24.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.7.1 h1:WXovk4TRKZttAMJfoQx6K2DM0zNIt8w+c67UqO+etV0= +github.com/bits-and-blooms/bloom/v3 v3.7.1/go.mod h1:rZzYLLje2dfzXfAkJNxQQHsKurAyK55KUnL43Euk0hU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -96,6 +100,7 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twmb/murmur3 v1.1.8/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod 
h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= diff --git a/trafficUtil/kafkaUtil/ebpf_telemetry.go b/trafficUtil/kafkaUtil/ebpf_telemetry.go new file mode 100644 index 00000000..587cd7da --- /dev/null +++ b/trafficUtil/kafkaUtil/ebpf_telemetry.go @@ -0,0 +1,334 @@ +package kafkaUtil + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math/rand" + "os" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "github.com/segmentio/kafka-go" +) + +type MessageType string + +const ( + MessageTypeEnvReload MessageType = "ENV_RELOAD" + MessageTypeRestart MessageType = "RESTART" +) + +type TrafficAgentCommandMessage struct { + MessageType MessageType `json:"messageType"` + DaemonNames []string `json:"daemonNames"` + Env map[string]string `json:"env"` + DaemonEnvMap map[string]map[string]string `json:"daemonEnvMap"` + Timestamp int64 `json:"timestamp"` +} + +var ( + lastCPUTime float64 + lastMeasureTime time.Time + cpuMutex sync.Mutex +) + +func getEnvData() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + +func getCPUUsage() (cpuPercent float64, cpuCoresUsed float64) { + var rusage syscall.Rusage + syscall.Getrusage(syscall.RUSAGE_SELF, &rusage) + + totalCPUSec := float64(rusage.Utime.Sec+rusage.Stime.Sec) + + float64(rusage.Utime.Usec+rusage.Stime.Usec)/1000000 + + cpuMutex.Lock() + defer cpuMutex.Unlock() + + now := time.Now() + + if !lastMeasureTime.IsZero() { + elapsed := now.Sub(lastMeasureTime).Seconds() + cpuDelta := totalCPUSec - lastCPUTime + cpuPercent = (cpuDelta / elapsed) * 100 + cpuCoresUsed = cpuDelta / elapsed + } + + lastCPUTime = totalCPUSec + lastMeasureTime = now + + return cpuPercent, cpuCoresUsed +} + +func getProfilingData() map[string]interface{} { + var memStats 
runtime.MemStats + runtime.ReadMemStats(&memStats) + + allocMB := float64(memStats.Alloc) / 1024 / 1024 + sysMB := float64(memStats.Sys) / 1024 / 1024 + totalAllocMB := float64(memStats.TotalAlloc) / 1024 / 1024 + + cpuPercent, cpuCoresUsed := getCPUUsage() + + profiling := map[string]interface{}{ + "memory_used_mb": allocMB, + "memory_total_mb": sysMB, + "memory_cumulative_mb": totalAllocMB, + "cpu_percent": cpuPercent, + "cpu_cores_used": cpuCoresUsed, + "cpu_cores_total": runtime.NumCPU(), + "goroutines": runtime.NumGoroutine(), + "num_gc": memStats.NumGC, + } + + return profiling +} + +func writeEnvFile() error { + dir := "/ebpf" + finalPath := "/ebpf/.env" + tmpPath := "/ebpf/.env.tmp" + + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + var content strings.Builder + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + key := parts[0] + val := parts[1] + + // Proper shell escaping + escapedVal := strconv.Quote(val) + + content.WriteString("export ") + content.WriteString(key) + content.WriteString("=") + content.WriteString(escapedVal) + content.WriteString("\n") + } + + err := os.WriteFile(tmpPath, []byte(content.String()), 0644) + if err != nil { + slog.Error("Failed to write environment file", "path", tmpPath, "error", err) + return err + } + slog.Debug("Environment variables written to file", "path", tmpPath) + // Atomic replace + err = os.Rename(tmpPath, finalPath) + if err != nil { + slog.Error("Failed to rename environment file", "path", tmpPath, "error", err) + return err + } + slog.Debug("Environment variables renamed to file", "path", finalPath) + return nil +} + +func restartSelf() { + slog.Warn("Restarting process with new environment...") + + // Write current environment to file so shell script can source it on next restart + writeEnvFile() + + // Exit and let shell script restart with fresh process + os.Exit(0) +} + +// processCommandMessage handles a single command message +func 
processCommandMessage(command TrafficAgentCommandMessage) { + daemonPodName := getDaemonPodName() + + if command.MessageType == MessageTypeRestart { + _, ok := command.DaemonEnvMap[daemonPodName] + if !ok { + _, ok = command.DaemonEnvMap["ALL"] + } + if !ok { + slog.Debug("Restart command not for this daemon, ignoring", + "thisDaemonPodName", daemonPodName) + return + } + slog.Info("Restarting process...") + restartSelf() + return + } + + if command.MessageType == MessageTypeEnvReload { + // Resolve env vars for this daemon: prefer pod-specific entry, fall back to "ALL" + envVars, ok := command.DaemonEnvMap[daemonPodName] + if !ok { + envVars, ok = command.DaemonEnvMap["ALL"] + } + if !ok { + slog.Debug("ENV_RELOAD not targeted at this daemon, ignoring", + "thisDaemonPodName", daemonPodName) + return + } + + slog.Info("Processing ENV_RELOAD command", + "thisDaemonPodName", daemonPodName, + "envCount", len(envVars)) + + if len(envVars) == 0 { + slog.Warn("ENV_RELOAD with no environment variables provided for this daemon") + return + } + + for key, value := range envVars { + oldValue := os.Getenv(key) + if oldValue != value { + slog.Warn("Updating environment variable", + "key", key, + "oldValue", oldValue, + "newValue", value) + os.Setenv(key, value) + } + } + slog.Info("Environment variables updated, restarting process...") + restartSelf() + return + } + + slog.Warn("Unknown message type, ignoring", "messageType", command.MessageType) +} + +func StartConfigConsumer() { + kafka_url := os.Getenv("AKTO_KAFKA_BROKER_MAL") + if len(kafka_url) == 0 { + kafka_url = os.Getenv("AKTO_KAFKA_BROKER_URL") + } + + if kafka_url == "" { + slog.Warn("Kafka URL not configured, config consumer disabled") + return + } + + topic := "akto.config.updates" + groupID := fmt.Sprintf("ebpf-config-consumer-%s", getDaemonPodName()) + + slog.Info("Starting config consumer", "topic", topic, "groupID", groupID, "daemonId", uniqueDaemonsetId) + + // Create Kafka reader (consumer) + readerConfig := 
kafka.ReaderConfig{ + Brokers: []string{kafka_url}, + Topic: topic, + GroupID: groupID, + MinBytes: 1, + MaxBytes: 10e6, + CommitInterval: time.Second, + StartOffset: kafka.LastOffset, + } + + // Apply common TLS and SASL configuration + readerConfig.Dialer = getKafkaDialer() + + reader := kafka.NewReader(readerConfig) + + // Start consumer goroutine + go func() { + defer reader.Close() + + ctx := context.Background() + for { + msg, err := reader.FetchMessage(ctx) + if err != nil { + slog.Error("Error reading config update message", "error", err) + continue + } + + slog.Debug("Received command message", "value", string(msg.Value)) + + var command TrafficAgentCommandMessage + err = json.Unmarshal(msg.Value, &command) + if err != nil { + slog.Error("Failed to parse command message", "error", err) + if err := reader.CommitMessages(ctx, msg); err != nil { + slog.Error("Failed to commit unparseable message", "error", err) + } + continue + } + + if err := reader.CommitMessages(ctx, msg); err != nil { + slog.Error("Failed to commit message offset", "error", err) + } + + slog.Debug("Received command message", "value", string(msg.Value)) + processCommandMessage(command) + } + }() + + slog.Info("Config consumer started successfully") +} + +func sendHeartbeatMessage(ctx context.Context, daemonPodName, imageVersion string) { + additionalData := map[string]interface{}{ + "env": getEnvData(), + "profiling": getProfilingData(), + } + + additionalDataJSON, err := json.Marshal(additionalData) + if err != nil { + slog.Error("Failed to marshal additionalData", "error", err) + additionalDataJSON = []byte("{}") + } + + heartbeatMessage := map[string]string{ + "type": "heartbeat", + "daemonId": uniqueDaemonsetId, + "daemonPodName": daemonPodName, + "timestamp": fmt.Sprint(time.Now().Unix()), + "moduleType": moduleType, + "imageVersion": imageVersion, + "additionalData": string(additionalDataJSON), + } + + slog.Debug("Sending Kafka heartbeat", "daemonPod", daemonPodName, "imageVersion", 
imageVersion, "heartbeatMessage", heartbeatMessage) + err = ProduceHeartbeat(ctx, heartbeatMessage) + if err != nil { + slog.Error("Failed to send heartbeat to Kafka", "error", err) + } +} + +func sendKafkaHeartbeat() { + if heartbeatIntervalSeconds <= 0 { + slog.Info("Kafka heartbeat disabled", "interval", heartbeatIntervalSeconds) + return + } + + daemonPodName := getDaemonPodName() + imageVersion := getImageVersion() + + slog.Debug("Starting Kafka heartbeat routine", "interval_seconds", heartbeatIntervalSeconds, "daemonPod", daemonPodName, "daemonId", uniqueDaemonsetId) + ctx := context.Background() + + slog.Info("Sending initial heartbeat") + sendHeartbeatMessage(ctx, daemonPodName, imageVersion) + + for { + jitter := time.Duration(1+rand.Intn(5)) * time.Second + sleepDuration := time.Duration(heartbeatIntervalSeconds)*time.Second + jitter + + slog.Debug("Sleeping before next heartbeat", "base_interval", heartbeatIntervalSeconds, "jitter_seconds", jitter.Seconds(), "total_sleep", sleepDuration.Seconds()) + time.Sleep(sleepDuration) + + sendHeartbeatMessage(ctx, daemonPodName, imageVersion) + } +} diff --git a/trafficUtil/kafkaUtil/kafka.go b/trafficUtil/kafkaUtil/kafka.go index fdecb1dd..2a6c9d63 100644 --- a/trafficUtil/kafkaUtil/kafka.go +++ b/trafficUtil/kafkaUtil/kafka.go @@ -10,17 +10,21 @@ import ( "os" "strconv" "strings" + "sync" "time" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/apiProcessor" trafficpb "github.com/akto-api-security/mirroring-api-logging/trafficUtil/protobuf/traffic_payload" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/utils" + "github.com/google/uuid" "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/sasl/plain" "google.golang.org/protobuf/proto" ) var kafkaWriter *kafka.Writer +var kafkaWriterMutex sync.RWMutex var KafkaErrMsgCount = 0 var KafkaErrMsgEpoch = time.Now() var BytesInThreshold = 500 * 1024 * 1024 @@ -29,12 +33,32 @@ var useTLS = false var InsecureSkipVerify = true var 
tlsCACertPath = "./ca.crt" +var isAuthImplemented = false +var kafkaUsername = "" +var kafkaPassword = "" + +var kafkaErrorThreshold = 500 +var kafkaReconnectIntervalMinutes = -1 +var heartbeatIntervalSeconds = 60 +var uniqueDaemonsetId = uuid.New().String() +var moduleType = "TRAFFIC_COLLECTOR" + +var globalTransport *kafka.Transport +var transportOnce sync.Once + func init() { utils.InitVar("USE_TLS", &useTLS) utils.InitVar("INSECURE_SKIP_VERIFY", &InsecureSkipVerify) utils.InitVar("TLS_CA_CERT_PATH", &tlsCACertPath) + utils.InitVar("IS_AUTH_IMPLEMENTED", &isAuthImplemented) + utils.InitVar("KAFKA_USERNAME", &kafkaUsername) + utils.InitVar("KAFKA_PASSWORD", &kafkaPassword) + + utils.InitVar("KAFKA_ERROR_THRESHOLD", &kafkaErrorThreshold) + utils.InitVar("KAFKA_RECONNECT_INTERVAL_MINUTES", &kafkaReconnectIntervalMinutes) + utils.InitVar("KAFKA_HEARTBEAT_INTERVAL_SECONDS", &heartbeatIntervalSeconds) } func InitKafka() { @@ -63,19 +87,20 @@ func InitKafka() { kafka_batch_size, e := strconv.Atoi(os.Getenv("AKTO_TRAFFIC_BATCH_SIZE")) if e != nil { - utils.PrintLog("AKTO_TRAFFIC_BATCH_SIZE should be valid integer") - return + kafka_batch_size = 100 } kafka_batch_time_secs, e := strconv.Atoi(os.Getenv("AKTO_TRAFFIC_BATCH_TIME_SECS")) if e != nil { - utils.PrintLog("AKTO_TRAFFIC_BATCH_TIME_SECS should be valid integer") - return + kafka_batch_time_secs = 10 } kafka_batch_time_secs_duration := time.Duration(kafka_batch_time_secs) for { + kafkaWriterMutex.Lock() kafkaWriter = getKafkaWriter(kafka_url, kafka_batch_size, kafka_batch_time_secs_duration*time.Second) + kafkaWriterMutex.Unlock() + utils.LogMemoryStats() utils.PrintLog("logging kafka stats before pushing message") LogKafkaStats() @@ -85,16 +110,31 @@ func InitKafka() { out, _ := json.Marshal(value) ctx := context.Background() - err := ProduceStr(ctx, string(out), "testKafkaConnection", "testKafkaConnectionHost") + err := ProduceStr(ctx, string(out), "testKafkaConnection", "testKafkaConnectionHost", "") 
utils.PrintLog("logging kafka stats post pushing message") LogKafkaStats() if err != nil { slog.Error("error establishing connection with kafka, sending message failed, retrying in 2 seconds", "error", err) + kafkaWriterMutex.Lock() kafkaWriter.Close() + kafkaWriterMutex.Unlock() + if globalTransport != nil { + globalTransport.CloseIdleConnections() + } time.Sleep(time.Second * 2) } else { utils.PrintLog("connection establishing with kafka successfully") + kafkaWriterMutex.Lock() kafkaWriter.Completion = kafkaCompletion() + kafkaWriterMutex.Unlock() + + // Start periodic reconnection routine + go periodicKafkaReconnect(kafka_url, kafka_batch_size, kafka_batch_time_secs_duration*time.Second) + slog.Info("Started Kafka periodic reconnection routine", "interval_minutes", kafkaReconnectIntervalMinutes) + + // Start heartbeat routine + go sendKafkaHeartbeat() + slog.Info("Started Kafka heartbeat routine", "interval_seconds", heartbeatIntervalSeconds) break } } @@ -105,17 +145,103 @@ func kafkaCompletion() func(messages []kafka.Message, err error) { if err != nil { KafkaErrMsgCount += len(messages) slog.Error("kafka error message", "err", err, "count", KafkaErrMsgCount, "messagesCount", len(messages)) + + if KafkaErrMsgCount > kafkaErrorThreshold { + slog.Error("kafka error count exceeded threshold, restarting module", "count", KafkaErrMsgCount, "threshold", kafkaErrorThreshold) + os.Exit(1) + } } else { utils.PrintLog("kafka messages sent successfully", "messagesCount", len(messages)) } } } -func Close() { - kafkaWriter.Close() +func periodicKafkaReconnect(kafka_url string, kafka_batch_size int, kafka_batch_time_secs_duration time.Duration) { + if kafkaReconnectIntervalMinutes <= 0 { + slog.Info("Kafka reconnection disabled", "interval", kafkaReconnectIntervalMinutes) + return + } + + ticker := time.NewTicker(time.Duration(kafkaReconnectIntervalMinutes) * time.Minute) + defer ticker.Stop() + + for range ticker.C { + slog.Info("Starting periodic Kafka reconnection", 
"interval_minutes", kafkaReconnectIntervalMinutes) + + // Create new writer + newWriter := getKafkaWriter(kafka_url, kafka_batch_size, kafka_batch_time_secs_duration) + newWriter.Completion = kafkaCompletion() + + // Test the new connection + ctx := context.Background() + value := map[string]string{ + "testConnectionString": "periodicReconnect", + } + out, _ := json.Marshal(value) + testMsg := kafka.Message{ + Topic: "akto.api.logs", + Value: out, + } + + err := newWriter.WriteMessages(ctx, testMsg) + if err != nil { + slog.Error("Failed to test new Kafka connection during periodic reconnect, keeping old connection", "error", err) + newWriter.Close() + continue + } + + // Replace old writer with new one + kafkaWriterMutex.Lock() + oldWriter := kafkaWriter + kafkaWriter = newWriter + kafkaWriterMutex.Unlock() + + // Close old writer + if oldWriter != nil { + slog.Info("Closing old Kafka writer") + oldWriter.Close() + } + + slog.Info("Kafka reconnection completed successfully") + } +} + +func getDaemonPodName() string { + aktoAgentName := os.Getenv("AKTO_AGENT_NAME") + podName := os.Getenv("POD_NAME") + nodeName := os.Getenv("NODE_NAME") + + if aktoAgentName != "" { + return fmt.Sprintf("akto-tc:%s", aktoAgentName) + } + + if podName != "" && nodeName != "" { + return fmt.Sprintf("akto-tc:%s:%s", podName, nodeName) + } + + hostname := os.Getenv("HOSTNAME") + if hostname == "" { + hostname = fmt.Sprintf("daemon-%s", uniqueDaemonsetId[:8]) + } + return fmt.Sprintf("akto-tc:%s", hostname) } +func getImageVersion() string { + imageVersion := os.Getenv("AKTO_IMAGE_VERSION") + if imageVersion == "" { + imageVersion = "aktosecurity/mirror-api-logging:k8s-ebpf" + } + return imageVersion +} + +// Heartbeat and config consumer functions moved to ebpf_telemetry.go + func LogKafkaStats() { + kafkaWriterMutex.RLock() + defer kafkaWriterMutex.RUnlock() + if kafkaWriter == nil { + return + } stats := kafkaWriter.Stats() slog.Debug("Kafka Stats", "dials", stats.Dials, @@ -200,7 
+326,11 @@ func Produce(ctx context.Context, value *trafficpb.HttpResponseParam) error { Value: protoBytes, } - err = kafkaWriter.WriteMessages(ctx, msg) + kafkaWriterMutex.RLock() + writer := kafkaWriter + kafkaWriterMutex.RUnlock() + + err = writer.WriteMessages(ctx, msg) if err != nil { slog.Error("Kafka write for threat failed", "topic", topic, "error", err) return err @@ -234,6 +364,31 @@ const ( LogTypeDebug = "DEBUG" ) +func ProduceHeartbeat(ctx context.Context, heartbeatData map[string]string) error { + out, err := json.Marshal(heartbeatData) + if err != nil { + return err + } + + topic := "akto.daemonset.producer.heartbeats" + msg := kafka.Message{ + Topic: topic, + Value: []byte(string(out)), + } + + kafkaWriterMutex.RLock() + writer := kafkaWriter + kafkaWriterMutex.RUnlock() + + err = writer.WriteMessages(ctx, msg) + + if err != nil { + slog.Error("ERROR while writing heartbeat messages", "topic", topic, "error", err) + return err + } + return nil +} + func ProduceLogs(ctx context.Context, message string, logType string) error { value := map[string]string{ "message": message, @@ -250,7 +405,11 @@ func ProduceLogs(ctx context.Context, message string, logType string) error { Value: []byte(string(out)), } - err := kafkaWriter.WriteMessages(ctx, msg) + kafkaWriterMutex.RLock() + writer := kafkaWriter + kafkaWriterMutex.RUnlock() + + err := writer.WriteMessages(ctx, msg) if err != nil { slog.Error("ERROR while writing messages", "topic", topic, "error", err) @@ -259,21 +418,45 @@ func ProduceLogs(ctx context.Context, message string, logType string) error { return nil } -func ProduceStr(ctx context.Context, message string, url, reqHost string) error { - // initialize the writer with the broker addresses, and the topic +// buildCollectionDetailsHeader creates the collection_details Kafka header +// Format: "host|method|url" +// Returns nil if any parameter is empty (skip header for incomplete messages) +func buildCollectionDetailsHeader(host, method, url 
string) []kafka.Header { + if host == "" || method == "" || url == "" { + return nil + } + + headerValue := fmt.Sprintf("%s|%s|%s", host, method, url) + return []kafka.Header{ + { + Key: "collection_details", + Value: []byte(headerValue), + }, + } +} + +func ProduceStr(ctx context.Context, message string, url, reqHost, method string) error { topic := "akto.api.logs" + checkDebugUrlAndPrint(url, reqHost, "begin kafka write to akto.api.logs topic") + msg := kafka.Message{ - Topic: topic, - Value: []byte(message), + Topic: topic, + Value: []byte(message), + Headers: buildCollectionDetailsHeader(reqHost, method, url), } - err := kafkaWriter.WriteMessages(ctx, msg) + kafkaWriterMutex.RLock() + writer := kafkaWriter + kafkaWriterMutex.RUnlock() + + err := writer.WriteMessages(ctx, msg) if err != nil { slog.Error("ERROR while writing messages", "topic", topic, "error", err) + checkDebugUrlAndPrint(url, reqHost, fmt.Sprintf("Kafka write failed: %v", err)) return err } - checkDebugUrlAndPrint(url, reqHost, "Kafka write successful: "+message) + checkDebugUrlAndPrint(url, reqHost, "Kafka write successful: ") return nil } @@ -294,8 +477,45 @@ func NewTLSConfig(caPath string) (*tls.Config, error) { }, nil } -func getKafkaWriter(kafkaURL string, batchSize int, batchTimeout time.Duration) *kafka.Writer { +func getKafkaDialer() *kafka.Dialer { + dialer := &kafka.Dialer{} + + // Add TLS config if enabled + if useTLS { + tlsConfig, err := NewTLSConfig(tlsCACertPath) + if err != nil { + slog.Error("Failed to create TLS config", "error", err) + } else { + dialer.TLS = tlsConfig + } + } + // Add SASL auth if enabled + if isAuthImplemented && kafkaUsername != "" && kafkaPassword != "" { + slog.Info("Configuring SASL plain authentication", "username", kafkaUsername) + dialer.SASLMechanism = plain.Mechanism{ + Username: kafkaUsername, + Password: kafkaPassword, + } + } + + return dialer +} + +func getGlobalTransport() *kafka.Transport { + transportOnce.Do(func() { + dialer := 
getKafkaDialer() + globalTransport = &kafka.Transport{ + TLS: dialer.TLS, + SASL: dialer.SASLMechanism, + IdleTimeout: 30 * time.Second, + MetadataTTL: 60 * time.Second, + } + }) + return globalTransport +} + +func getKafkaWriter(kafkaURL string, batchSize int, batchTimeout time.Duration) *kafka.Writer { kafkaWriter := kafka.Writer{ Addr: kafka.TCP(kafkaURL), BatchSize: batchSize, @@ -308,11 +528,6 @@ func getKafkaWriter(kafkaURL string, batchSize int, batchTimeout time.Duration) Compression: kafka.Lz4, } - if useTLS { - tlsConfig, _ := NewTLSConfig(tlsCACertPath) - kafkaWriter.Transport = &kafka.Transport{ - TLS: tlsConfig, - } - } + kafkaWriter.Transport = getGlobalTransport() return &kafkaWriter } diff --git a/trafficUtil/kafkaUtil/lru_cache.go b/trafficUtil/kafkaUtil/lru_cache.go new file mode 100644 index 00000000..0c3cd16c --- /dev/null +++ b/trafficUtil/kafkaUtil/lru_cache.go @@ -0,0 +1,67 @@ +package kafkaUtil + +import ( + "container/list" + "sync" +) + +// LRUCache is a simple LRU cache for tracking recent request signatures +type LRUCache struct { + capacity int + cache map[string]*list.Element + list *list.List + mu sync.RWMutex +} + +type lruEntry struct { + key string + timeBucket uint8 // 0-255 representing time buckets +} + +// NewLRUCache creates a new LRU cache with the given capacity +func NewLRUCache(capacity int) *LRUCache { + return &LRUCache{ + capacity: capacity, + cache: make(map[string]*list.Element), + list: list.New(), + } +} + +// Get retrieves a value from the cache and moves it to the front (most recently used) +// Returns the time bucket and a boolean indicating if the key was found +func (c *LRUCache) Get(key string) (uint8, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if elem, found := c.cache[key]; found { + c.list.MoveToFront(elem) + return elem.Value.(*lruEntry).timeBucket, true + } + return 0, false +} + +// Put inserts or updates a key-value pair in the cache +// If the cache is at capacity, the least recently used entry 
is evicted +func (c *LRUCache) Put(key string, timeBucket uint8) { + c.mu.Lock() + defer c.mu.Unlock() + + if elem, found := c.cache[key]; found { + c.list.MoveToFront(elem) + elem.Value.(*lruEntry).timeBucket = timeBucket + return + } + + if c.list.Len() >= c.capacity { + // Evict oldest (least recently used) + oldest := c.list.Back() + if oldest != nil { + c.list.Remove(oldest) + delete(c.cache, oldest.Value.(*lruEntry).key) + } + } + + entry := &lruEntry{key: key, timeBucket: timeBucket} + elem := c.list.PushFront(entry) + c.cache[key] = elem +} diff --git a/trafficUtil/kafkaUtil/lru_cache_test.go b/trafficUtil/kafkaUtil/lru_cache_test.go new file mode 100644 index 00000000..7a9e1a1d --- /dev/null +++ b/trafficUtil/kafkaUtil/lru_cache_test.go @@ -0,0 +1,123 @@ +package kafkaUtil + +import ( + "testing" +) + +// TestLRUCache tests the LRU cache implementation +func TestLRUCache(t *testing.T) { + cache := NewLRUCache(3) + + // Test Put and Get + cache.Put("key1", 1) + bucket, found := cache.Get("key1") + if !found || bucket != 1 { + t.Errorf("Expected key1=1, got found=%v, bucket=%d", found, bucket) + } + + // Test Get non-existent key + _, found = cache.Get("nonexistent") + if found { + t.Errorf("Expected nonexistent key to not be found") + } + + // Test capacity and eviction + cache.Put("key2", 2) + cache.Put("key3", 3) + cache.Put("key4", 4) // Should evict key1 + + _, found = cache.Get("key1") + if found { + t.Errorf("Expected key1 to be evicted") + } + + // Check that key4 is present + bucket, found = cache.Get("key4") + if !found || bucket != 4 { + t.Errorf("Expected key4=4, got found=%v, bucket=%d", found, bucket) + } +} + +// TestLRUCacheEviction tests that LRU cache correctly evicts oldest entries +func TestLRUCacheEviction(t *testing.T) { + cache := NewLRUCache(2) + + // Add 2 entries + cache.Put("a", 1) + cache.Put("b", 2) + + // Both should be present + _, foundA := cache.Get("a") + _, foundB := cache.Get("b") + if !foundA || !foundB { + t.Errorf("Both 
entries should be present") + } + + // Add third entry (should evict oldest) + cache.Put("c", 3) + + // "a" should be evicted (least recently used) + _, foundA = cache.Get("a") + if foundA { + t.Errorf("Entry 'a' should be evicted") + } + + // "b" and "c" should still be present + _, foundB = cache.Get("b") + _, foundC := cache.Get("c") + if !foundB || !foundC { + t.Errorf("Entries 'b' and 'c' should be present") + } +} + +// TestLRUCacheUpdate tests that updating an entry moves it to front +func TestLRUCacheUpdate(t *testing.T) { + cache := NewLRUCache(2) + + cache.Put("a", 1) + cache.Put("b", 2) + + // Access "a" to move it to front + cache.Get("a") + + // Add third entry (should evict "b" since "a" was more recently accessed) + cache.Put("c", 3) + + // "b" should be evicted, not "a" + _, foundA := cache.Get("a") + _, foundB := cache.Get("b") + if !foundA { + t.Errorf("Entry 'a' should not be evicted") + } + if foundB { + t.Errorf("Entry 'b' should be evicted") + } +} + +// BenchmarkLRUCacheGet benchmarks LRU cache Get operation +func BenchmarkLRUCacheGet(b *testing.B) { + cache := NewLRUCache(10000) + + // Pre-populate + for i := 0; i < 1000; i++ { + key := "key" + string(rune(i)) + cache.Put(key, uint8(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + string(rune(i%1000)) + cache.Get(key) + } +} + +// BenchmarkLRUCachePut benchmarks LRU cache Put operation +func BenchmarkLRUCachePut(b *testing.B) { + cache := NewLRUCache(10000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := "key" + string(rune(i%1000)) + cache.Put(key, uint8(i%256)) + } +} diff --git a/trafficUtil/kafkaUtil/parser.go b/trafficUtil/kafkaUtil/parser.go index b5acdfd0..02d19c52 100644 --- a/trafficUtil/kafkaUtil/parser.go +++ b/trafficUtil/kafkaUtil/parser.go @@ -19,8 +19,241 @@ import ( trafficpb "github.com/akto-api-security/mirroring-api-logging/trafficUtil/protobuf/traffic_payload" 
"github.com/akto-api-security/mirroring-api-logging/trafficUtil/trafficMetrics" "github.com/akto-api-security/mirroring-api-logging/trafficUtil/utils" + bloomfilter "github.com/bits-and-blooms/bloom/v3" ) +// TrafficContext holds metadata about the captured traffic. +// This consolidates the many parameters previously passed to ParseAndProduce. +type TrafficContext struct { + SourceIP string + DestIP string + VxlanID int + IsPending bool + TrafficSource string + IsComplete bool + Direction int + ProcessID uint32 // Extracted from idfd >> 32 + SocketFD uint32 + DaemonsetIdentifier string + HostName string +} + +// ParsedTraffic holds the parsed HTTP requests and responses with their bodies. +type ParsedTraffic struct { + Requests []http.Request + RequestBodies []string + Responses []http.Response + ResponseBodies []string +} + +// HeaderSet holds HTTP headers in both protobuf and string map formats. +type HeaderSet struct { + Protobuf map[string]*trafficpb.StringList + StringMap map[string]string +} + +// ConvertedHeaders holds converted headers for both request and response. +type ConvertedHeaders struct { + Request HeaderSet + Response HeaderSet + DebugID string // x-debug-token value if present +} + +// PayloadInput contains the data needed to build traffic payloads. +type PayloadInput struct { + Request *http.Request + Response *http.Response + Headers ConvertedHeaders + RequestBody string + ResponseBody string + SourceIP string // IP after GetSourceIp processing + Context TrafficContext +} + +// buildProtobufPayload creates the protobuf payload for the threat client. 
+func buildProtobufPayload(input PayloadInput) *trafficpb.HttpResponseParam { + return &trafficpb.HttpResponseParam{ + Method: input.Request.Method, + Path: input.Request.URL.String(), + RequestHeaders: input.Headers.Request.Protobuf, + ResponseHeaders: input.Headers.Response.Protobuf, + RequestPayload: input.RequestBody, + ResponsePayload: input.ResponseBody, + Ip: input.SourceIP, + DestIp: input.Context.DestIP, + Time: int32(time.Now().Unix()), + StatusCode: int32(input.Response.StatusCode), + Type: string(input.Request.Proto), + Status: input.Response.Status, + AktoAccountId: fmt.Sprint(1000000), + AktoVxlanId: fmt.Sprint(input.Context.VxlanID), + IsPending: input.Context.IsPending, + Source: input.Context.TrafficSource, + } +} + +// buildJSONPayload creates the JSON map payload (legacy format, TODO: remove). +func buildJSONPayload(input PayloadInput) map[string]string { + reqHeaderString, _ := json.Marshal(input.Headers.Request.StringMap) + respHeaderString, _ := json.Marshal(input.Headers.Response.StringMap) + + return map[string]string{ + "path": input.Request.URL.String(), + "requestHeaders": string(reqHeaderString), + "responseHeaders": string(respHeaderString), + "method": input.Request.Method, + "requestPayload": input.RequestBody, + "responsePayload": input.ResponseBody, + "ip": input.Context.SourceIP, + "destIp": input.Context.DestIP, + "time": fmt.Sprint(time.Now().Unix()), + "statusCode": fmt.Sprint(input.Response.StatusCode), + "type": string(input.Request.Proto), + "status": input.Response.Status, + "akto_account_id": fmt.Sprint(1000000), + "akto_vxlan_id": fmt.Sprint(input.Context.VxlanID), + "is_pending": fmt.Sprint(input.Context.IsPending), + "source": input.Context.TrafficSource, + "direction": fmt.Sprint(input.Context.Direction), + "process_id": fmt.Sprint(input.Context.ProcessID), + "socket_id": fmt.Sprint(input.Context.SocketFD), + "daemonset_id": fmt.Sprint(input.Context.DaemonsetIdentifier), + "enable_graph": fmt.Sprint(utils.EnableGraph), 
+ } +} + +// resolvePodLabels resolves pod labels for inbound traffic and adds them to the value map. +func resolvePodLabels(value map[string]string, ctx TrafficContext, url, host string) { + + if PodInformerInstance == nil { + checkDebugUrlAndPrint(url, host, "Pod labels not resolved, PodInformerInstance is nil") + return + } + + if ctx.Direction == utils.DirectionOutbound { + checkDebugUrlAndPrint(url, host, fmt.Sprintf("Pod labels not resolved for outbound request, podName: %s, direction: %v", ctx.HostName, ctx.Direction)) + return + } + + processName := PodInformerInstance.GetProcessNameByProcessId(int32(ctx.ProcessID)) + if strings.Contains(processName, "envoy") { + checkDebugUrlAndPrint(url, host, fmt.Sprintf("Pod labels not resolved for envoy request, podName: %s, direction: %v", ctx.HostName, ctx.Direction)) + return + } + + if ctx.HostName == "" { + checkDebugUrlAndPrint(url, host, "Failed to resolve pod name, hostName is empty for processId "+fmt.Sprint(ctx.ProcessID)) + slog.Debug("Failed to resolve pod name, hostName is empty for ", "processId", ctx.ProcessID, "hostName", ctx.HostName) + return + } + + podLabels, err := PodInformerInstance.ResolvePodLabels(ctx.HostName, url, host) + if err != nil { + slog.Error("Failed to resolve pod labels", "hostName", ctx.HostName, "error", err) + checkDebugUrlAndPrint(url, host, "Error resolving pod labels "+ctx.HostName) + return + } + + value["tag"] = podLabels + checkDebugUrlAndPrint(url, host, "Pod labels found in ParseAndProduce, podLabels found "+fmt.Sprint(podLabels)+" for hostName "+ctx.HostName) +} + +func mergeInjectTags(value map[string]string) { + if len(injectTagsMap) == 0 { + return + } + + merged := map[string]string{} + for k, v := range injectTagsMap { + merged[k] = v + } + + // Parse and merge any existing tag JSON (e.g. 
from pod labels) + if existing, ok := value["tag"]; ok && existing != "" { + podLabelMap := map[string]string{} + if err := json.Unmarshal([]byte(existing), &podLabelMap); err == nil { + for k, v := range podLabelMap { + merged[k] = v // pod labels overwrite inject tags on conflict + } + } + } + + if b, err := json.Marshal(merged); err == nil { + value["tag"] = string(b) + } +} + +// convertHeaders converts HTTP headers to both protobuf and string map formats in a single pass. +func convertHeaders(req *http.Request, resp *http.Response, shouldPrint bool) ConvertedHeaders { + result := ConvertedHeaders{ + Request: HeaderSet{ + Protobuf: make(map[string]*trafficpb.StringList), + StringMap: make(map[string]string), + }, + Response: HeaderSet{ + Protobuf: make(map[string]*trafficpb.StringList), + StringMap: make(map[string]string), + }, + } + + // Convert request headers + for name, values := range req.Header { + for _, value := range values { + result.Request.Protobuf[strings.ToLower(name)] = &trafficpb.StringList{ + Values: []string{value}, + } + if shouldPrint && strings.EqualFold(name, "x-debug-token") { + result.DebugID = value + } + result.Request.StringMap[name] = value + } + } + result.Request.Protobuf["host"] = &trafficpb.StringList{Values: []string{req.Host}} + result.Request.StringMap["host"] = req.Host + + // Convert response headers + for name, values := range resp.Header { + for _, value := range values { + result.Response.Protobuf[strings.ToLower(name)] = &trafficpb.StringList{ + Values: []string{value}, + } + result.Response.StringMap[name] = value + } + } + + return result +} + +// shouldProcessRequest checks all filter conditions and returns true if the request should be processed. 
+func shouldProcessRequest(req *http.Request, reqHeaders map[string]string, ctx TrafficContext) bool { + if !IsValidMethod(req.Method) { + return false + } + + if !utils.PassesFilter(trafficMetrics.FilterHeaderValueMap, reqHeaders) { + return false + } + + if utils.IgnoreIpTraffic && utils.CheckIfIp(req.Host) { + return false + } + + if utils.IgnoreCloudMetadataCalls && req.Host == "169.254.169.254" { + return false + } + + if utils.IgnoreEnvoyProxycalls && ctx.SourceIP == utils.EnvoyProxyIp && ctx.Direction == utils.DirectionOutbound { + slog.Debug("Ignoring outbound envoy proxy call", "sourceIp", ctx.SourceIP, "url", req.URL.String(), "host", req.Host) + return false + } + + if utils.FilterPacket(reqHeaders) { + return false + } + + return true +} + var ( goodRequests = 0 badRequests = 0 @@ -29,6 +262,7 @@ var ( currentBandwidthProcessed = 0 lastSampleUpdate = time.Now().Unix() sampleMutex = sync.RWMutex{} + injectTagsMap = map[string]string{} methodsMap = map[string]bool{ "GET": true, "HEAD": true, @@ -43,14 +277,30 @@ var ( DebugStrings = []string{} EventChanBuffSize = 100000 + + // Body parsing optimization variables + lruCache *LRUCache + lruCacheCapacity = 100000 + bloomFilterCapacity = 1000000 + bloomFilterFPRate = 0.01 + timeBucketDuration = 10 * time.Minute + memSamplingEnabled = false ) +var bloomFilter *bloomfilter.BloomFilter + const ONE_MINUTE = 60 func init() { utils.InitVar("DEBUG_MODE", &debugMode) utils.InitVar("OUTPUT_BANDWIDTH_LIMIT", &outputBandwidthLimitPerMin) utils.InitVar("EVENT_CHAN_BUFF_SIZE", &EventChanBuffSize) + utils.InitVar("AKTO_MEM_SAMPLING_ENABLED", &memSamplingEnabled) + utils.InitVar("LRU_CACHE_CAPACITY", &lruCacheCapacity) + utils.InitVar("BLOOM_FILTER_CAPACITY", &bloomFilterCapacity) + utils.InitVar("BLOOM_FILTER_FP_RATE", &bloomFilterFPRate) + utils.InitVar("TIME_BUCKET_DURATION_MINUTES", &timeBucketDuration) + // convert MB to B if outputBandwidthLimitPerMin != -1 { outputBandwidthLimitPerMin = outputBandwidthLimitPerMin * 
1024 * 1024 @@ -62,12 +312,41 @@ func init() { } slog.Info("debugStrings", "DebugStrings", DebugStrings) + // Only initialize Bloom Filter and LRU Cache if memory sampling is enabled + if memSamplingEnabled { + bloomFilter = bloomfilter.NewWithEstimates(uint(bloomFilterCapacity), bloomFilterFPRate) + + // Initialize LRU Cache + lruCache = NewLRUCache(lruCacheCapacity) + + // Reset Bloom filter every 24 hours to prevent permanent false positives + go func() { + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + for range ticker.C { + bloomFilter.ClearAll() + } + }() + } + + injectTagsEnv := "" + utils.InitVar("AKTO_INJECT_TAGS", &injectTagsEnv) + if injectTagsEnv != "" { + for _, pair := range strings.Split(injectTagsEnv, ";") { + pair = strings.TrimSpace(pair) + if idx := strings.IndexByte(pair, '='); idx > 0 { + k := strings.TrimSpace(pair[:idx]) + v := strings.TrimSpace(pair[idx+1:]) + if k != "" { + injectTagsMap[k] = v + } + } + } + slog.Info("AKTO_INJECT_TAGS loaded", "tags", injectTagsMap) + } + // Start ticker to read debug URLs from file every 30 seconds go func() { - if !utils.FileLoggingEnabled { - slog.Info("File logging is not enabled, skipping debug URL file watcher") - return - } ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { @@ -129,13 +408,13 @@ func checkDebugUrlAndPrint(url string, host string, message string) { for _, debugString := range DebugStrings { if strings.Contains(url, debugString) { ctx := context.Background() - logMsg := fmt.Sprintf("%s : %s", message, url) + logMsg := fmt.Sprintf("url: %s, host: %s, message: %s", url, host, message) utils.PrintLogDebug(logMsg) go ProduceLogs(ctx, logMsg, LogTypeInfo) break } else if strings.Contains(host, debugString) { ctx := context.Background() - logMsg := fmt.Sprintf("%s : %s", message, host) + logMsg := fmt.Sprintf("url: %s, host: %s, message: %s", url, host, message) utils.PrintLogDebug(logMsg) go ProduceLogs(ctx, logMsg, LogTypeInfo) break @@ -173,22 +452,51 
@@ func IsValidMethod(method string) bool { return ok } -func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, sourceIp string, destIp string, vxlanID int, isPending bool, - trafficSource string, isComplete bool, direction int, idfd uint64, fd uint32, daemonsetIdentifier string, hostName string) { +// shouldParseBody returns true if body should be parsed for this request. +// Uses Bloom Filter + LRU Cache for memory-efficient tracking. +// Only applies optimization if memSamplingEnabled is true. +func shouldParseBody(method, host, path string) bool { + // Only apply body parsing optimization if memory sampling is enabled + if !memSamplingEnabled { + return true + } - if checkAndUpdateBandwidthProcessed(0) { - return + key := buildSignatureKey(method, host, path) + + // Step 1: Check Bloom Filter (fast, probabilistic) + if !bloomFilter.TestString(key) { + // Definitely first time seeing this signature + bloomFilter.AddString(key) + lruCache.Put(key, getTimeBucket()) + return true } - shouldPrint := debugMode && strings.Contains(string(receiveBuffer), "x-debug-token") - if shouldPrint { - slog.Debug("ParseAndProduce", "receiveBuffer", string(receiveBuffer), "sentBuffer", string(sentBuffer)) + // Step 2: Bloom filter says "maybe seen before" - check LRU for precise tracking + if timeBucket, found := lruCache.Get(key); found { + // Check if time bucket has expired + if isTimeBucketExpired(timeBucket) { + // More than 10 minutes since last parse + lruCache.Put(key, getTimeBucket()) + return true + } + // Recently parsed, skip body + return false } - reader := bufio.NewReader(bytes.NewReader(receiveBuffer)) - i := 0 + // Step 3: In Bloom but not in LRU (evicted or false positive) + // Treat as new - parse body and add to LRU + lruCache.Put(key, getTimeBucket()) + return true +} + +// parseHTTPTraffic parses HTTP requests and responses from raw byte buffers. +// Returns nil if parsing fails (errors are logged). 
+func parseHTTPTraffic(reqBuffer, respBuffer []byte, shouldPrint bool) *ParsedTraffic { + // Parse requests + reader := bufio.NewReader(bytes.NewReader(reqBuffer)) requests := []http.Request{} - requestsContent := []string{} + requestBodies := []string{} + parseBodyFlags := []bool{} // track which requests should have body parsed for { req, err := http.ReadRequest(reader) @@ -196,84 +504,136 @@ func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, sourceIp string, d break } else if err != nil { utils.PrintLog(fmt.Sprintf("HTTP-request error: %s \n", err)) - return + return nil } - body, err := io.ReadAll(req.Body) - req.Body.Close() - if err != nil { - utils.PrintLog(fmt.Sprintf("Got body err: %s\n", err)) - return + + // Determine if we should parse body for this request + parseBody := shouldParseBody(req.Method, req.Host, req.URL.Path) + + var body []byte + if parseBody { + body, err = io.ReadAll(req.Body) + if err != nil { + utils.PrintLog(fmt.Sprintf("Got body err: %s\n", err)) + body = []byte{} + } + } else { + // Skip body parsing - MUST drain from bufio.Reader to avoid corrupting next request + io.Copy(io.Discard, req.Body) + body = []byte{} + // Inject discovery-only header in REQUEST + req.Header.Set("x-akto-skip-sample-update", "true") } + req.Body.Close() requests = append(requests, *req) - requestsContent = append(requestsContent, string(body)) - i++ + requestBodies = append(requestBodies, string(body)) + parseBodyFlags = append(parseBodyFlags, parseBody) } if shouldPrint { - slog.Debug("ParseAndProduce", "count", i) + slog.Debug("parseHTTPTraffic", "requestCount", len(requests)) } + if len(requests) == 0 { - return + return nil } - reader = bufio.NewReader(bytes.NewReader(sentBuffer)) - i = 0 - + // Parse responses + reader = bufio.NewReader(bytes.NewReader(respBuffer)) responses := []http.Response{} - responsesContent := []string{} - - for { + responseBodies := []string{} + for i := 0; ; i++ { resp, err := http.ReadResponse(reader, nil) if err 
== io.EOF || err == io.ErrUnexpectedEOF { break } else if err != nil { utils.PrintLog(fmt.Sprintf("HTTP-Response error: %s\n", err)) - return + return nil } - body, err := io.ReadAll(resp.Body) - if err != nil { - utils.PrintLog(fmt.Sprintf("Got err reading resp body: %s\n", err)) - return - } - encoding := resp.Header["Content-Encoding"] - var r io.Reader - r = bytes.NewBuffer(body) - if len(encoding) > 0 && (encoding[0] == "gzip" || encoding[0] == "deflate") { - r, err = gzip.NewReader(r) + var body []byte + // Only parse response body if we parsed the corresponding request body + shouldParseRespBody := i < len(parseBodyFlags) && parseBodyFlags[i] + + if shouldParseRespBody { + body, err = io.ReadAll(resp.Body) if err != nil { - utils.PrintLog(fmt.Sprintf("HTTP-gunzip "+"Failed to gzip decode: %s", err)) - return + utils.PrintLog(fmt.Sprintf("Got err reading resp body: %s\n", err)) + body = []byte{} } - } - if err == nil { - body, err = io.ReadAll(r) - if err != nil { - utils.PrintLog(fmt.Sprintf("Failed to read decompressed body: %s\n", err)) - return + + // Handle gzip/deflate decompression + encoding := resp.Header["Content-Encoding"] + var r io.Reader + r = bytes.NewBuffer(body) + if len(encoding) > 0 && (encoding[0] == "gzip" || encoding[0] == "deflate") { + r, err = gzip.NewReader(r) + if err != nil { + utils.PrintLog(fmt.Sprintf("HTTP-gunzip "+"Failed to gzip decode: %s", err)) + body = []byte{} + } } - if _, ok := r.(*gzip.Reader); ok { - r.(*gzip.Reader).Close() + if err == nil { + body, err = io.ReadAll(r) + if err != nil { + utils.PrintLog(fmt.Sprintf("Failed to read decompressed body: %s\n", err)) + body = []byte{} + } + if _, ok := r.(*gzip.Reader); ok { + r.(*gzip.Reader).Close() + } } + } else { + // Skip response body - MUST drain from bufio.Reader + io.Copy(io.Discard, resp.Body) + body = []byte{} } + resp.Body.Close() responses = append(responses, *resp) - responsesContent = append(responsesContent, string(body)) + responseBodies = 
append(responseBodies, string(body)) + } + + if shouldPrint { + slog.Debug("parseHTTPTraffic", "responseCount", len(responses)) + } + + return &ParsedTraffic{ + Requests: requests, + RequestBodies: requestBodies, + Responses: responses, + ResponseBodies: responseBodies, + } +} - i++ +func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, ctx TrafficContext) { + + if checkAndUpdateBandwidthProcessed(0) { + return } + shouldPrint := debugMode && strings.Contains(string(receiveBuffer), "x-debug-token") if shouldPrint { + slog.Debug("ParseAndProduce", "receiveBuffer", string(receiveBuffer), "sentBuffer", string(sentBuffer)) + } - slog.Debug("ParseAndProduce", "count", i) + parsed := parseHTTPTraffic(receiveBuffer, sentBuffer, shouldPrint) + if parsed == nil { + return } + + requests := parsed.Requests + requestsContent := parsed.RequestBodies + responses := parsed.Responses + responsesContent := parsed.ResponseBodies + if len(requests) != len(responses) { if shouldPrint { - slog.Debug("Len req-res mismatch", "lenRequests", len(requests), "lenResponses", len(responses), "lenReceiveBuffer", len(receiveBuffer), "lenSentBuffer", len(sentBuffer), "isComplete", isComplete) + slog.Debug("Len req-res mismatch", "lenRequests", len(requests), "lenResponses", len(responses), "lenReceiveBuffer", len(receiveBuffer), "lenSentBuffer", len(sentBuffer), "isComplete", ctx.IsComplete) } - if isComplete { + if ctx.IsComplete { return } correctLen := len(requests) @@ -283,192 +643,68 @@ func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, sourceIp string, d responses = responses[:correctLen] requests = requests[:correctLen] + responsesContent = responsesContent[:correctLen] + requestsContent = requestsContent[:correctLen] } - i = 0 - for { - if len(requests) < i+1 { - break - } + bgCtx := context.Background() + for i := 0; i < len(requests); i++ { req := &requests[i] resp := &responses[i] - if !IsValidMethod(req.Method) { - continue - } - - id := "" - - // build req headers 
for threat client - reqHeader := make(map[string]*trafficpb.StringList) - for name, values := range req.Header { - // Loop over all values for the name. - for _, value := range values { - reqHeader[strings.ToLower(name)] = &trafficpb.StringList{ - Values: []string{value}, - } - } - } - ip := GetSourceIp(reqHeader, sourceIp) - - reqHeader["host"] = &trafficpb.StringList{ - Values: []string{req.Host}, - } - - reqHeaderStr := make(map[string]string) - for name, values := range req.Header { - // Loop over all values for the name. - for _, value := range values { - if shouldPrint && - strings.EqualFold(name, "x-debug-token") { - id = value - } - reqHeaderStr[name] = value - } - } - - reqHeaderStr["host"] = req.Host - - passes := utils.PassesFilter(trafficMetrics.FilterHeaderValueMap, reqHeaderStr) - //printLog("Req header: " + mapToString(reqHeaderStr)) - //printLog(fmt.Sprintf("passes %t", passes)) - - if !passes { - i++ - continue - } - - if utils.IgnoreIpTraffic && utils.CheckIfIp(req.Host) { - i++ - continue - } - - if utils.IgnoreCloudMetadataCalls && req.Host == "169.254.169.254" { - i++ - continue - } - - if utils.IgnoreEnvoyProxycalls && sourceIp == utils.EnvoyProxyIp && direction == utils.DirectionOutbound { - slog.Debug("Ignoring outbound envoy proxy call", "sourceIp", sourceIp, "url", req.URL.String(), "host", req.Host) - i++ - continue - } + url := req.URL.String() + checkDebugUrlAndPrint(url, req.Host, "URL,host found in ParseAndProduce") - var skipPacket = utils.FilterPacket(reqHeaderStr) + // Convert headers in a single pass (both protobuf and string map formats) + headers := convertHeaders(req, resp, shouldPrint) - if skipPacket { - i++ + // Check all filter conditions + if !shouldProcessRequest(req, headers.Request.StringMap, ctx) { continue } - // build resp headers for threat client - respHeader := make(map[string]*trafficpb.StringList) - for name, values := range resp.Header { - // Loop over all values for the name. 
- for _, value := range values { - respHeader[strings.ToLower(name)] = &trafficpb.StringList{ - Values: []string{value}, - } - } - } + // Get source IP from headers + ip := GetSourceIp(headers.Request.Protobuf, ctx.SourceIP) - // TODO: remove and use protobuf instead - respHeaderStr := make(map[string]string) - for name, values := range resp.Header { - // Loop over all values for the name. - for _, value := range values { - respHeaderStr[name] = value - } + // Build payloads + input := PayloadInput{ + Request: req, + Response: resp, + Headers: headers, + RequestBody: requestsContent[i], + ResponseBody: responsesContent[i], + SourceIP: ip, + Context: ctx, } - url := req.URL.String() - checkDebugUrlAndPrint(url, req.Host, "URL,host found in ParseAndProduce") + value := buildJSONPayload(input) - // build kafka payload for threat client - payload := &trafficpb.HttpResponseParam{ - Method: req.Method, - Path: req.URL.String(), - RequestHeaders: reqHeader, - ResponseHeaders: respHeader, - RequestPayload: requestsContent[i], - ResponsePayload: responsesContent[i], - Ip: ip, - Time: int32(time.Now().Unix()), - StatusCode: int32(resp.StatusCode), - Type: string(req.Proto), - Status: resp.Status, - AktoAccountId: fmt.Sprint(1000000), - AktoVxlanId: fmt.Sprint(vxlanID), - IsPending: isPending, - } - - reqHeaderString, _ := json.Marshal(reqHeaderStr) - respHeaderString, _ := json.Marshal(respHeaderStr) - - // TODO: remove and use protobuf instead - value := map[string]string{ - "path": req.URL.String(), - "requestHeaders": string(reqHeaderString), - "responseHeaders": string(respHeaderString), - "method": req.Method, - "requestPayload": requestsContent[i], - "responsePayload": responsesContent[i], - "ip": sourceIp, - "destIp": destIp, - "time": fmt.Sprint(time.Now().Unix()), - "statusCode": fmt.Sprint(resp.StatusCode), - "type": string(req.Proto), - "status": resp.Status, - "akto_account_id": fmt.Sprint(1000000), - "akto_vxlan_id": fmt.Sprint(vxlanID), - "is_pending": 
fmt.Sprint(isPending), - "source": trafficSource, - "direction": fmt.Sprint(direction), - "process_id": fmt.Sprint(idfd >> 32), - "socket_id": fmt.Sprint(fd), - "daemonset_id": fmt.Sprint(daemonsetIdentifier), - "enable_graph": fmt.Sprint(utils.EnableGraph), - } - - // Process id was captured from the eBPF program using bpf_get_current_pid_tgid() - // Shifting by 32 gives us the process id on host machine. - var pid = idfd >> 32 - log := fmt.Sprintf("pod direction log: direction=%v, host=%v, path=%v, sourceIp=%v, destIp=%v, socketId=%v, processId=%v, hostName=%v", - direction, - reqHeaderStr["host"], + // Debug logging + log := fmt.Sprintf("before resolving pod labels direction log: direction=%v, host=%v, path=%v, sourceIp=%v, destIp=%v, socketId=%v, processId=%v, hostName=%v", + ctx.Direction, + headers.Request.StringMap["host"], value["path"], - sourceIp, - destIp, + ctx.SourceIP, + ctx.DestIP, value["socket_id"], - pid, - hostName, + ctx.ProcessID, + ctx.HostName, ) - utils.PrintLog(log) checkDebugUrlAndPrint(url, req.Host, log) - if PodInformerInstance != nil && direction == utils.DirectionInbound { + // Resolve pod labels for inbound traffic + resolvePodLabels(value, ctx, url, req.Host) - if hostName == "" { - checkDebugUrlAndPrint(url, req.Host, "Failed to resolve pod name, hostName is empty for processId "+fmt.Sprint(pid)) - slog.Error("Failed to resolve pod name, hostName is empty for ", "processId", pid, "hostName", hostName) - } else { - podLabels, err := PodInformerInstance.ResolvePodLabels(hostName, url, req.Host) - if err != nil { - slog.Error("Failed to resolve pod labels", "hostName", hostName, "error", err) - checkDebugUrlAndPrint(url, req.Host, "Error resolving pod labels "+hostName) - } else { - value["tag"] = podLabels - checkDebugUrlAndPrint(url, req.Host, "Pod labels found in ParseAndProduce, podLabels found "+fmt.Sprint(podLabels)+" for hostName "+hostName) - slog.Debug("Pod labels", "podName", hostName, "labels", podLabels) - } - } - } else { 
- checkDebugUrlAndPrint(url, req.Host, "Pod labels not resolved, PodInformerInstance is nil or direction is not inbound, direction: "+fmt.Sprint(direction)) - } + mergeInjectTags(value) - out, _ := json.Marshal(value) - ctx := context.Background() + checkDebugUrlAndPrint(url, req.Host, "After pod labels URL,host marshalling to JSON") + out, err := json.Marshal(value) + if err != nil { + slog.Error("Failed to json marshal the payload", "error", err) + checkDebugUrlAndPrint(url, req.Host, fmt.Sprintf("json marshal payload failed %v", err)) + return + } // calculating the size of outgoing bytes and requests (1) and saving it in outgoingCounterMap // this number is the closest (slightly higher) to the actual connection transfer bytes. @@ -478,36 +714,42 @@ func ParseAndProduce(receiveBuffer []byte, sentBuffer []byte, sourceIp string, d return } - hostString := reqHeaderStr["host"] - if utils.CheckIfIpHost(hostString) { - hostString = "ip-host" - } - oc := utils.GenerateOutgoingCounter(vxlanID, sourceIp, hostString) - trafficMetrics.SubmitOutgoingTrafficMetrics(oc, outgoingBytes) + if apiProcessor.CloudProcessorInstance != nil { + apiProcessor.CloudProcessorInstance.Produce(value) - if shouldPrint { - if strings.Contains(responsesContent[i], id) { - goodRequests++ - } else { - slog.Debug("req-resp.String()", "out", string(out)) - badRequests++ - } + } else { + // Produce to kafka with collection_details header + go ProduceStr(bgCtx, string(out), url, req.Host, req.Method) - if goodRequests%100 == 0 || badRequests%100 == 0 { - slog.Debug("Good requests", "count", goodRequests, "badRequests", badRequests) + // Only if threat enabled + if utils.ThreatEnabled { + payload := buildProtobufPayload(input) + go Produce(bgCtx, payload) } } - if apiProcessor.CloudProcessorInstance != nil { - apiProcessor.CloudProcessorInstance.Produce(value) + sendMetrics(headers, ctx, outgoingBytes, shouldPrint, responsesContent, i, out) + } +} + +func sendMetrics(headers ConvertedHeaders, ctx 
TrafficContext, outgoingBytes int, shouldPrint bool, responsesContent []string, i int, out []byte) { + hostString := headers.Request.StringMap["host"] + if utils.CheckIfIpHost(hostString) { + hostString = "ip-host" + } + oc := utils.GenerateOutgoingCounter(ctx.VxlanID, ctx.SourceIP, hostString) + trafficMetrics.SubmitOutgoingTrafficMetrics(oc, outgoingBytes) + if shouldPrint { + if strings.Contains(responsesContent[i], headers.DebugID) { + goodRequests++ } else { - // Produce to kafka - // TODO : remove and use protobuf instead - go ProduceStr(ctx, string(out), url, req.Host) - go Produce(ctx, payload) + slog.Debug("req-resp.String()", "out", string(out)) + badRequests++ } - i++ + if goodRequests%10 == 0 || badRequests%10 == 0 { + slog.Debug("Good requests", "count", goodRequests, "badRequests", badRequests) + } } } diff --git a/trafficUtil/kafkaUtil/parser_test.go b/trafficUtil/kafkaUtil/parser_test.go new file mode 100644 index 00000000..c0cd8ed9 --- /dev/null +++ b/trafficUtil/kafkaUtil/parser_test.go @@ -0,0 +1,102 @@ +package kafkaUtil + +import ( + "fmt" + "testing" + bloomfilter "github.com/bits-and-blooms/bloom/v3" +) + +// TestShouldParseBodyFirstRequest tests that first request body is parsed +func TestShouldParseBodyFirstRequest(t *testing.T) { + // Reset Bloom filter for testing + bloomFilter = bloomfilter.NewWithEstimates(uint(bloomFilterCapacity), bloomFilterFPRate) + lruCache = NewLRUCache(lruCacheCapacity) + + // First request should always be parsed + if !shouldParseBody("GET", "example.com", "/api/test") { + t.Errorf("First request should always parse body") + } +} + +// TestShouldParseBodySkipRecent tests that recent requests skip body parsing +func TestShouldParseBodySkipRecent(t *testing.T) { + // Reset Bloom filter and LRU for testing + bloomFilter = bloomfilter.NewWithEstimates(uint(bloomFilterCapacity), bloomFilterFPRate) + lruCache = NewLRUCache(lruCacheCapacity) + + method, host, path := "POST", "api.example.com", "/v1/submit" + + // First 
request should parse body + shouldParse1 := shouldParseBody(method, host, path) + if !shouldParse1 { + t.Errorf("First request should parse body") + } + + // Immediate second request should skip body + shouldParse2 := shouldParseBody(method, host, path) + if shouldParse2 { + t.Errorf("Recent request should skip body, but got shouldParse=%v", shouldParse2) + } + + // Third request should also skip body + shouldParse3 := shouldParseBody(method, host, path) + if shouldParse3 { + t.Errorf("Recent request should skip body, but got shouldParse=%v", shouldParse3) + } +} + +// TestShouldParseBodyDifferentSignatures tests different signatures are tracked separately +func TestShouldParseBodyDifferentSignatures(t *testing.T) { + // Reset Bloom filter and LRU for testing + bloomFilter = bloomfilter.NewWithEstimates(uint(bloomFilterCapacity), bloomFilterFPRate) + lruCache = NewLRUCache(lruCacheCapacity) + + // First signature + should1a := shouldParseBody("GET", "example.com", "/api/users") + if !should1a { + t.Errorf("First request for signature 1 should parse body") + } + + // Skip immediate request for signature 1 + should1b := shouldParseBody("GET", "example.com", "/api/users") + if should1b { + t.Errorf("Second request for signature 1 should skip body") + } + + // Different signature - should parse body + should2a := shouldParseBody("GET", "example.com", "/api/orders") + if !should2a { + t.Errorf("First request for signature 2 should parse body") + } + + // Different method - should parse body + should3a := shouldParseBody("POST", "example.com", "/api/users") + if !should3a { + t.Errorf("First request for signature 3 (different method) should parse body") + } + + // Different host - should parse body + should4a := shouldParseBody("GET", "other.com", "/api/users") + if !should4a { + t.Errorf("First request for signature 4 (different host) should parse body") + } +} + +// BenchmarkShouldParseBody benchmarks the shouldParseBody function +func BenchmarkShouldParseBody(b 
*testing.B) { + // Reset + bloomFilter = bloomfilter.NewWithEstimates(uint(bloomFilterCapacity), bloomFilterFPRate) + lruCache = NewLRUCache(lruCacheCapacity) + + // Pre-populate with some data + for i := 0; i < 1000; i++ { + path := fmt.Sprintf("/api/test%d", i%10) + shouldParseBody("GET", "example.com", path) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + path := fmt.Sprintf("/api/test%d", i%10) + shouldParseBody("GET", "example.com", path) + } +} diff --git a/trafficUtil/kafkaUtil/podinformer.go b/trafficUtil/kafkaUtil/podinformer.go index 14a54c23..38ff26ff 100644 --- a/trafficUtil/kafkaUtil/podinformer.go +++ b/trafficUtil/kafkaUtil/podinformer.go @@ -42,14 +42,18 @@ func init() { utils.InitVar("AKTO_K8_METADATA_CAPTURE", &KubeInjectEnabled) } +type PidInfo struct { + HostName string + ProcessName string +} + type PodInformer struct { clientset *kubernetes.Clientset nodeName string podNameLabelsMap sync.Map // Maps pod names to their labels directly - pidHostNameMap map[int32]string + pidHostNameMap map[int32]PidInfo } - func SetupPodInformer() (chan struct{}, error) { if !KubeInjectEnabled { slog.Warn("AKTO_K8_METADATA_CAPTURE is not true, skipping PodInformer setup") @@ -116,43 +120,63 @@ func NewPodInformer() (*PodInformer, error) { clientset: clientset, nodeName: nodeName, podNameLabelsMap: sync.Map{}, - pidHostNameMap: make(map[int32]string), + pidHostNameMap: make(map[int32]PidInfo), }, nil } func (w *PodInformer) GetPodNameByProcessId(pid int32) string { - if hostName, ok := w.pidHostNameMap[pid]; ok { - return hostName + if info, ok := w.pidHostNameMap[pid]; ok { + return info.HostName } - slog.Warn("Hostname not found for", "processId", pid) + return "" } +func (w *PodInformer) GetProcessNameByProcessId(pid int32) string { + if info, ok := w.pidHostNameMap[pid]; ok { + return info.ProcessName + } + // slog.Debug("Process name not found for", "processId", pid) + return "" +} + +func (w *PodInformer) GetAllKubePids() []uint32 { + // TODO: should 
we ignore envoy pids ? + pids := make([]uint32, 0, len(w.pidHostNameMap)) + for pid := range w.pidHostNameMap { + pids = append(pids, uint32(pid)) + } + slog.Debug("No of kube/docker pids", "found: ", len(pids)) + return pids +} + func (w *PodInformer) BuildPidHostNameMap() { - cmd := exec.Command("sh", "-c", "for dir in /host/proc/[0-9]*; do pid=$(echo \"$dir\" | cut -d'/' -f4); if [ -f \"$dir/environ\" ]; then hostname=$(strings \"$dir/environ\" | grep '^HOSTNAME=' | cut -d'=' -f2); if [ -n \"$hostname\" ]; then echo \"$pid $hostname\"; fi; fi; done") + cmd := exec.Command("sh", "-c", "for dir in /host/proc/[0-9]*; do pid=$(basename \"$dir\"); if [ -f $dir/environ ]; then hostname=$(strings $dir/environ | grep '^HOSTNAME=' | cut -d'=' -f2); if [ -n \"$hostname\" ]; then comm=$(cat $dir/comm 2>/dev/null); echo \"$pid $comm $hostname\"; fi; fi; done | sort -k3") output, err := cmd.Output() if err != nil { slog.Error("Failed to execute shell command", "error", err) return } - slog.Debug("Shell command output for PID to Hostname mapping", "output", string(output)) + // slog.Debug("Shell command output for PID to Hostname mapping", "output", string(output)) lines := strings.Split(string(output), "\n") for _, line := range lines { parts := strings.Fields(line) - if len(parts) == 2 { + if len(parts) == 3 { pid, err := strconv.Atoi(parts[0]) if err == nil { - w.pidHostNameMap[int32(pid)] = parts[1] + w.pidHostNameMap[int32(pid)] = PidInfo{ + ProcessName: parts[1], + HostName: parts[2], + } } } } - slog.Info("PID to Hostname map built successfully", "map", w.pidHostNameMap) + // slog.Debug("PID to Hostname map built successfully", "map", w.pidHostNameMap) w.logPidHostNameMap() } func (w *PodInformer) ResolvePodLabels(podName string, url, reqHost string) (string, error) { - slog.Debug("Resolving Pod Name to labels", "podName", podName) checkDebugUrlAndPrint(url, reqHost, "Resolving Pod Name to labels for "+podName) // Step 1: Use the pod name as the key to find labels in 
podNameLabelsMap @@ -184,16 +208,16 @@ func (w *PodInformer) ResolvePodLabels(podName string, url, reqHost string) (str } func (w *PodInformer) logPidHostNameMap() { - slog.Warn("Logging PID to Hostname Map to file", "file", utils.GoPidLogFile) + // slog.Warn("Logging PID to Hostname Map to file", "file", utils.GoPidLogFile) var builder strings.Builder - fmt.Fprintf(&builder, "PID\tHostname:\n") + fmt.Fprintf(&builder, "PID\tProcessName\tHostname:\n") - for pid, hostName := range w.pidHostNameMap { - fmt.Fprintf(&builder, "%d\t%s\n", pid, hostName) + for pid, info := range w.pidHostNameMap { + fmt.Fprintf(&builder, "%d\t%s\t%s\n", pid, info.ProcessName, info.HostName) } fmt.Fprintf(&builder, "-------Total PIDs tracked: %d----------\n", len(w.pidHostNameMap)) utils.LogToSpecificFile(utils.GoPidLogFile, builder.String()) - slog.Warn("PID to Hostname Map logged", "map", w.pidHostNameMap) + // slog.Debug("PID to Hostname Map logged", "map", w.pidHostNameMap) } func (w *PodInformer) logPodLabelsMapFile() { @@ -222,7 +246,7 @@ func (w *PodInformer) logPodNameLabelsMap() { result += fmt.Sprintf("Name: %s, Labels: %s; ", key, value) return true }) - slog.Warn("Pod Name Labels Map", "map", result) + // slog.Debug("Pod Name Labels Map", "map", result) w.logPodLabelsMapFile() } @@ -332,6 +356,7 @@ func (w *PodInformer) handlePodUpdate(oldObj, newObj interface{}) { slog.Debug("Pod update:", "namespace", newPod.Namespace, "podName", newPod.Name) w.podNameLabelsMap.Delete(oldPod.Name) w.podNameLabelsMap.Store(newPod.Name, newPod.Labels) + w.BuildPidHostNameMap() } func (w *PodInformer) handlePodDelete(obj interface{}) { @@ -345,4 +370,4 @@ func (w *PodInformer) handlePodDelete(obj interface{}) { // Build the PID to Hostname map again to ensure it is up-to-date // TODO: Optimize this ? What's the rate of pod add events? 
w.BuildPidHostNameMap() -} \ No newline at end of file +} diff --git a/trafficUtil/kafkaUtil/signature_helpers.go b/trafficUtil/kafkaUtil/signature_helpers.go new file mode 100644 index 00000000..411101df --- /dev/null +++ b/trafficUtil/kafkaUtil/signature_helpers.go @@ -0,0 +1,50 @@ +package kafkaUtil + +import ( + "strings" + "time" +) + +// getTimeBucket converts current time to a uint8 bucket (0-255) +// Each bucket represents a 10-minute interval +// Wraps around every ~42 hours (256 * 10 min) +func getTimeBucket() uint8 { + minutes := time.Now().Unix() / int64(timeBucketDuration.Seconds()) + return uint8(minutes % 256) +} + +// isTimeBucketExpired checks if a time bucket is older than the interval +// Accounts for wrap-around (255 -> 0) +func isTimeBucketExpired(storedBucket uint8) bool { + currentBucket := getTimeBucket() + + // Calculate difference accounting for wrap-around + var diff int + if currentBucket >= storedBucket { + diff = int(currentBucket) - int(storedBucket) + } else { + // Wrapped around: e.g., stored=250, current=5 + diff = (256 - int(storedBucket)) + int(currentBucket) + } + + // diff >= 1 means at least one bucket boundary has been crossed; note this + // can be far less than a full 10 minutes if storedBucket was set near the + // end of its interval + return diff >= 1 +} + +// buildSignatureKey creates a unique key from method, host, and path. +// Format: "METHOD|HOST|PATH" +// Uses strings.Builder for efficient concatenation. 
+// Examples: +// - buildSignatureKey("GET", "example.com", "/api/users") -> "GET|example.com|/api/users" +// - buildSignatureKey("POST", "api.example.com", "/v1/data") -> "POST|api.example.com|/v1/data" +func buildSignatureKey(method, host, path string) string { + var sb strings.Builder + // Pre-allocate capacity: method(4) + host(20) + path(20) + separators(2) ≈ 46 + sb.Grow(len(method) + len(host) + len(path) + 2) + sb.WriteString(method) + sb.WriteByte('|') + sb.WriteString(host) + sb.WriteByte('|') + sb.WriteString(path) + return sb.String() +} diff --git a/trafficUtil/kafkaUtil/signature_helpers_test.go b/trafficUtil/kafkaUtil/signature_helpers_test.go new file mode 100644 index 00000000..f11b21e5 --- /dev/null +++ b/trafficUtil/kafkaUtil/signature_helpers_test.go @@ -0,0 +1,72 @@ +package kafkaUtil + +import ( + "testing" +) + +// TestBuildSignatureKey tests signature key generation +func TestBuildSignatureKey(t *testing.T) { + tests := []struct { + method string + host string + path string + expected string + }{ + {"GET", "example.com", "/api/users", "GET|example.com|/api/users"}, + {"POST", "api.example.com", "/v1/data", "POST|api.example.com|/v1/data"}, + {"DELETE", "localhost:8080", "/test", "DELETE|localhost:8080|/test"}, + {"PUT", "192.168.1.1", "/", "PUT|192.168.1.1|/"}, + {"PATCH", "api.service.local", "/v2/resource/123", "PATCH|api.service.local|/v2/resource/123"}, + } + + for _, tt := range tests { + result := buildSignatureKey(tt.method, tt.host, tt.path) + if result != tt.expected { + t.Errorf("buildSignatureKey(%s, %s, %s) = %s, want %s", + tt.method, tt.host, tt.path, result, tt.expected) + } + } +} + +// TestGetTimeBucket tests time bucket generation +func TestGetTimeBucket(t *testing.T) { + bucket1 := getTimeBucket() + if bucket1 < 0 || bucket1 > 255 { + t.Errorf("getTimeBucket() returned %d, expected 0-255", bucket1) + } + + // Get bucket again immediately (should be same) + bucket2 := getTimeBucket() + if bucket1 != bucket2 { + 
t.Errorf("Expected same bucket for immediate calls: %d vs %d", bucket1, bucket2) + } +} + +// TestIsTimeBucketExpired tests time bucket expiration detection +func TestIsTimeBucketExpired(t *testing.T) { + // Test recent bucket (should not be expired) + currentBucket := getTimeBucket() + if isTimeBucketExpired(currentBucket) { + t.Errorf("Current bucket should not be expired") + } + + // Test old bucket (should be expired) + oldBucket := uint8((int(currentBucket) - 5 + 256) % 256) + if !isTimeBucketExpired(oldBucket) { + t.Errorf("Old bucket should be expired") + } +} + +// TestIsTimeBucketExpiredWrapAround tests time bucket expiration with wrap-around +func TestIsTimeBucketExpiredWrapAround(t *testing.T) { + // Simulate wrap-around scenario where stored bucket is close to 255 + // and current bucket wrapped to a low number + // This is probabilistic based on actual time, so we just verify the logic works + + // Test that the same bucket is not expired + bucket := uint8(100) + if isTimeBucketExpired(bucket) && bucket == getTimeBucket() { + // Only fail if it's actually the same bucket + t.Errorf("Same bucket should not be expired") + } +} diff --git a/trafficUtil/utils/common.go b/trafficUtil/utils/common.go index 7d49d8a0..e69197da 100644 --- a/trafficUtil/utils/common.go +++ b/trafficUtil/utils/common.go @@ -37,7 +37,7 @@ var IgnoreIpTraffic = false var IgnoreCloudMetadataCalls = false var IgnoreEnvoyProxycalls = false var EnableGraph = true -var ThreatEnabled = false +var ThreatEnabled = true const EnvoyProxyIp = "127.0.0.6" @@ -76,6 +76,5 @@ func InitVar(envVarName string, targetVar interface{}) { slog.Warn("Unsupported type for targetVar", "type", v) } } else { - slog.Warn("Missing env value, using default value", "name", envVarName) } }