diff --git a/cmd/internal/migrations/v3/session_release.go b/cmd/internal/migrations/v3/session_release.go index da89415..5084068 100644 --- a/cmd/internal/migrations/v3/session_release.go +++ b/cmd/internal/migrations/v3/session_release.go @@ -2,6 +2,9 @@ package v3 import ( "fmt" + "go/ast" + "go/parser" + "go/token" "regexp" "strings" @@ -16,118 +19,299 @@ const releaseComment = "// Important: Manual cleanup required" // MigrateSessionRelease adds defer sess.Release() after store.Get() calls // when using the Store Pattern (legacy pattern). // This is required in v3 for manual session lifecycle management. +// +// Only the following Store methods return *Session from the pool and require Release(): +// - store.Get(c fiber.Ctx) (*Session, error) +// - store.GetByID(ctx context.Context, id string) (*Session, error) +// +// Middleware handlers do NOT require Release() as the middleware manages the lifecycle. +// +// This migration parses the Go AST and uses source-level heuristics to identify +// session.Store.Get/GetByID calls on variables initialized via session.NewStore(), +// including support for custom import aliases. It does not currently track stores +// that are passed via parameters, returned from functions, or stored in structs. func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { - // Match patterns like: - // sess, err := store.Get(c) - // sess, err := store.GetByID(ctx, sessionID) - // session, err := myStore.Get(c) - // Capture: variable name, store variable name, method call - reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`) - changed, err := internal.ChangeFileContent(cwd, func(content string) string { - lines := strings.Split(content, "\n") - result := make([]string, 0, len(lines)) + // Quick check: does file import session package? + if !strings.Contains(content, "middleware/session") { + return content + } - for i := 0; i < len(lines); i++ { - line := lines[i] - result = append(result, line) + // Skip v2 imports - this migration only works with v3 + if strings.Contains(content, "fiber/v2/middleware/session") { + return content + } - // Check if this line matches a store.Get() call - matches := reStoreGet.FindStringSubmatch(line) - if len(matches) < 6 { - continue - } + // Use type-based approach to find Store.Get() calls + result, err := addReleaseCallsWithTypes(content, cwd) + if err != nil { + // Fallback: return original content if type checking fails + return content + } - indent := matches[1] - sessVar := matches[2] - errVar := matches[3] + return result + }) + if err != nil { + return fmt.Errorf("failed to add session Release() calls: %w", err) + } + if !changed { + return nil + } - // Look for the error check pattern after this line - // Common patterns: - // if err != nil { - // if err != nil { return ... } - nextLineIdx := i + 1 - if nextLineIdx >= len(lines) { - continue - } + cmd.Println("Adding defer sess.Release() for Store Pattern usage") + return nil +} - nextLine := strings.TrimSpace(lines[nextLineIdx]) +// releasePoint represents a location where defer sess.Release() needs to be added +type releasePoint struct { + indent string // Indentation to use for defer statement + sessVar string // Session variable name + errVar string // Error variable name + line int // Line number where Get/GetByID was called +} - // Check if the next line starts an error check - if !strings.HasPrefix(nextLine, "if "+errVar+" != nil") { - continue - } +// addReleaseCallsWithTypes adds defer Release() statements for Store.Get() calls +func addReleaseCallsWithTypes(content, _ string) (string, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "temp.go", content, parser.ParseComments) + if err != nil { + return "", fmt.Errorf("parse file: %w", err) + } - // Find where the error block ends - blockEnd := findErrorBlockEnd(lines, nextLineIdx) + points := findReleasePoints(file, fset, content) + if len(points) == 0 { + return content, nil + } - // Insert defer after the error block - if blockEnd < 0 || blockEnd >= len(lines) { - continue - } + return insertDeferStatements(content, points), nil +} - // Check if there's already a defer sess.Release() after the error block - hasRelease := false - searchEnd := blockEnd + 20 - if searchEnd > len(lines) { - searchEnd = len(lines) - } - for j := blockEnd + 1; j < searchEnd; j++ { - if strings.Contains(lines[j], sessVar+".Release()") { - hasRelease = true - break - } - // Stop searching if we hit a closing brace at the same or lower indent level - // Only stop on lines that are purely closing braces (possibly with trailing comments) - trimmed := strings.TrimSpace(lines[j]) - if strings.HasPrefix(trimmed, "}") && !strings.Contains(trimmed, "{") && !strings.Contains(trimmed, "else") { - break - } - } +// findReleasePoints analyzes the AST to find Store.Get() calls +func findReleasePoints(file *ast.File, fset *token.FileSet, src string) []releasePoint { + var points []releasePoint - if hasRelease { - // Skip ahead to avoid re-processing these lines - for i < blockEnd { - i++ - if i < len(lines) { - result = append(result, lines[i]) - } - } - continue - } + ast.Inspect(file, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + if len(assign.Lhs) != 2 || len(assign.Rhs) != 1 || assign.Tok != token.DEFINE { + return true + } + + call, ok := assign.Rhs[0].(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + methodName := sel.Sel.Name + if methodName != "Get" && methodName != "GetByID" { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if !isSessionStoreInFunction(src, ident.Name, fset, assign) { + return true + } + + sessIdent, ok := assign.Lhs[0].(*ast.Ident) + if !ok { + return true + } + + var errVarName string + if errIdent, ok := assign.Lhs[1].(*ast.Ident); ok { + errVarName = errIdent.Name + } else { + errVarName = "_" + } + + pos := fset.Position(assign.Pos()) + + points = append(points, releasePoint{ + line: pos.Line - 1, + sessVar: sessIdent.Name, + errVar: errVarName, + indent: "", + }) + + return true + }) + + return points +} + +// isSessionStoreInFunction checks if the given variable name appears to be a session store +// within the same function as the provided assignment statement +func isSessionStoreInFunction(src, varName string, fset *token.FileSet, assign *ast.AssignStmt) bool { + aliases := findSessionPackageAliases(src) + + pos := fset.Position(assign.Pos()) + assignLine := pos.Line + + fnStart, fnEnd := findFunctionBoundaries(src, assignLine) + if fnStart == -1 || fnEnd == -1 { + return checkStoreAssignment(src, varName, aliases) + } + + lines := strings.Split(src, "\n") + fnLines := lines[fnStart:fnEnd] + fnSrc := strings.Join(fnLines, "\n") + + return checkStoreAssignment(fnSrc, varName, aliases) +} + +// checkStoreAssignment verifies the variable is assigned from session.NewStore() +func checkStoreAssignment(src, varName string, aliases []string) bool { + for _, alias := range aliases { + // Match: store := session.NewStore() or var store = session.NewStore() + pattern := regexp.MustCompile(fmt.Sprintf(`\b%s\b\s*(?::=|=)\s*%s\.NewStore\(`, regexp.QuoteMeta(varName), regexp.QuoteMeta(alias))) + if pattern.MatchString(src) { + return true + } + } + return false +} + +// findFunctionBoundaries finds the start and end line numbers of the function containing the given line +func findFunctionBoundaries(src string, lineNum int) (start, end int) { + lines := strings.Split(src, "\n") + if lineNum < 1 || lineNum > len(lines) { + return -1, -1 + } - // Insert the defer statement after the error block - deferLine := indent + "defer " + sessVar + ".Release() " + releaseComment + lineIdx := lineNum - 1 - // Skip ahead in the loop to include all lines up to blockEnd - for i < blockEnd { - i++ - if i < len(lines) { - result = append(result, lines[i]) + fnStart := -1 + for i := lineIdx; i >= 0; i-- { + line := strings.TrimSpace(lines[i]) + if strings.HasPrefix(line, "func ") { + fnStart = i + break + } + } + + if fnStart == -1 { + return -1, -1 + } + + fnEnd := len(lines) + for i := fnStart + 1; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if strings.HasPrefix(line, "func ") { + fnEnd = i + break + } + } + + return fnStart, fnEnd +} + +// findSessionPackageAliases extracts all aliases used for the session middleware package +func findSessionPackageAliases(src string) []string { + var aliases []string + + lines := strings.Split(src, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.Contains(line, `"github.com/gofiber/fiber/v3/middleware/session"`) { + if strings.HasPrefix(line, `"github.com/gofiber/fiber/v3/middleware/session"`) { + aliases = append(aliases, "session") + } else { + parts := strings.Fields(line) + if len(parts) >= 2 && parts[1] == `"github.com/gofiber/fiber/v3/middleware/session"` { + aliases = append(aliases, parts[0]) } } + } + } - // Now insert the defer line - result = append(result, deferLine) + return aliases +} + +// insertDeferStatements adds defer sess.Release() at appropriate locations +func insertDeferStatements(content string, points []releasePoint) string { + lines := strings.Split(content, "\n") + + for i := range points { + if points[i].line < len(lines) { + line := lines[points[i].line] + points[i].indent = line[:len(line)-len(strings.TrimLeft(line, " \t"))] } + } - return strings.Join(result, "\n") - }) - if err != nil { - return fmt.Errorf("failed to add session Release() calls: %w", err) + for i := len(points) - 1; i >= 0; i-- { + p := points[i] + + if p.line >= len(lines) { + continue + } + + if hasExistingRelease(lines, p.line, p.sessVar) { + continue + } + + nextLineIdx := p.line + 1 + if nextLineIdx >= len(lines) { + insertAt := p.line + 1 + deferStmt := p.indent + "defer " + p.sessVar + ".Release() " + releaseComment + lines = append(lines[:insertAt], append([]string{deferStmt}, lines[insertAt:]...)...) + continue + } + + nextLine := strings.TrimSpace(lines[nextLineIdx]) + + if strings.HasPrefix(nextLine, "if "+p.errVar+" != nil") { + blockEnd := findErrorBlockEnd(lines, nextLineIdx) + if blockEnd >= 0 && blockEnd < len(lines) { + insertAt := blockEnd + 1 + deferStmt := p.indent + "defer " + p.sessVar + ".Release() " + releaseComment + lines = append(lines[:insertAt], append([]string{deferStmt}, lines[insertAt:]...)...) + } + } else { + insertAt := p.line + 1 + deferStmt := p.indent + "defer " + p.sessVar + ".Release() " + releaseComment + lines = append(lines[:insertAt], append([]string{deferStmt}, lines[insertAt:]...)...) + } } - if !changed { - return nil + + return strings.Join(lines, "\n") +} + +// hasExistingRelease checks if defer sess.Release() already exists for this session variable +func hasExistingRelease(lines []string, startLine int, sessVar string) bool { + releaseCall := sessVar + ".Release()" + + searchStart := startLine - 2 + if searchStart < 0 { + searchStart = 0 + } + searchEnd := startLine + 5 + if searchEnd > len(lines) { + searchEnd = len(lines) } - cmd.Println("Adding defer sess.Release() for Store Pattern usage") - return nil + for i := searchStart; i < searchEnd; i++ { + if strings.Contains(lines[i], releaseCall) { + return true + } + } + + return false } -// findErrorBlockEnd finds the end of an error handling block -// Returns the line index of the closing brace, or -1 if not found -// Note: This uses simple brace counting and may not handle braces in strings/comments, -// but is sufficient for migration purposes with typical Go error handling patterns. +// findErrorBlockEnd finds the end of an error handling block using proper Go AST parsing. +// Returns the line index of the closing brace, or -1 if not found. func findErrorBlockEnd(lines []string, startIdx int) int { if startIdx >= len(lines) { return -1 @@ -135,24 +319,117 @@ func findErrorBlockEnd(lines []string, startIdx int) int { line := strings.TrimSpace(lines[startIdx]) - // Check if it's a single-line if statement if strings.Contains(line, "{") && strings.Contains(line, "}") { return startIdx } - // Multi-line block: find the matching closing brace - if strings.Contains(line, "{") { - braceCount := 1 - for i := startIdx + 1; i < len(lines); i++ { - currLine := lines[i] - braceCount += strings.Count(currLine, "{") - braceCount -= strings.Count(currLine, "}") + if !strings.Contains(line, "{") { + return -1 + } + + var sb strings.Builder + sb.WriteString("package main\nfunc f() {\n") //nolint:errcheck // strings.Builder.WriteString never fails + snippetStartLine := startIdx + + for i := startIdx; i < len(lines) && i < startIdx+50; i++ { + sb.WriteString(lines[i]) //nolint:errcheck // strings.Builder.WriteString never fails + sb.WriteString("\n") //nolint:errcheck // strings.Builder.WriteString never fails + } + sb.WriteString("\n}\n") //nolint:errcheck // strings.Builder.WriteString never fails + codeSnippet := sb.String() - if braceCount == 0 { - return i + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, "", codeSnippet, parser.AllErrors) + if err != nil { + return findErrorBlockEndFallback(lines, startIdx) + } + + ifStmtEnd := -1 + ast.Inspect(node, func(n ast.Node) bool { + if ifStmt, ok := n.(*ast.IfStmt); ok { + pos := fset.Position(ifStmt.Body.End()) + lineNum := pos.Line - 3 + if lineNum >= 0 { + ifStmtEnd = snippetStartLine + lineNum + return false } } + return true + }) + + if ifStmtEnd >= 0 && ifStmtEnd < len(lines) { + return ifStmtEnd } + return findErrorBlockEndFallback(lines, startIdx) +} + +// findErrorBlockEndFallback is a simple fallback that counts braces. +// Only used when AST parsing fails. +func findErrorBlockEndFallback(lines []string, startIdx int) int { + braceCount := 1 + inString := false + var stringChar byte + inComment := false + isLineComment := false + + for i := startIdx + 1; i < len(lines); i++ { + line := lines[i] + for j := 0; j < len(line); j++ { + ch := line[j] + + // Handle string literals + if inString { + if ch == stringChar && (j == 0 || line[j-1] != '\\') { + inString = false + } + continue + } + + // Handle comments + if inComment { + if isLineComment && ch == '\n' { + inComment = false + isLineComment = false + } else if !isLineComment && ch == '*' && j+1 < len(line) && line[j+1] == '/' { + inComment = false + j++ // skip the '/' + } + continue + } + + // Check for start of string or comment + switch ch { + case '"', '\'', '`': + inString = true + stringChar = ch + case '/': + if j+1 >= len(line) { + continue + } + switch line[j+1] { + case '/': + inComment = true + isLineComment = true + j++ // skip the second '/' + case '*': + inComment = true + isLineComment = false + j++ // skip the '*' + default: + // Not a comment, continue + } + case '{': + braceCount++ + case '}': + braceCount-- + if braceCount == 0 { + return i + } + default: + // Ignore other characters + } + } + } return -1 } diff --git a/cmd/internal/migrations/v3/session_release_test.go b/cmd/internal/migrations/v3/session_release_test.go index 7fa8862..87c8bdb 100644 --- a/cmd/internal/migrations/v3/session_release_test.go +++ b/cmd/internal/migrations/v3/session_release_test.go @@ -11,11 +11,28 @@ import ( "github.com/stretchr/testify/require" ) +// setupTestModule creates a temporary directory within the project for testing +// This ensures packages.Load() can access proper go.mod and type information +func setupTestModule(t *testing.T) string { + t.Helper() + + // Create temp dir inside the project so it inherits go.mod + dir, err := os.MkdirTemp(".", "test_migration_") + require.NoError(t, err) + + // Ensure it's absolute path + absDir, err := filepath.Abs(dir) + require.NoError(t, err) + + return absDir +} + func Test_MigrateSessionRelease(t *testing.T) { t.Parallel() + var err error + var data []byte - dir, err := os.MkdirTemp("", "msessionrelease") - require.NoError(t, err) + dir := setupTestModule(t) defer func() { require.NoError(t, os.RemoveAll(dir)) }() content := `package main @@ -31,7 +48,7 @@ func handler(c fiber.Ctx) error { if err != nil { return err } - + sess.Set("key", "value") return sess.Save() } @@ -44,7 +61,7 @@ func handler(c fiber.Ctx) error { err = MigrateSessionRelease(cmd, dir, nil, nil) require.NoError(t, err) - data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 require.NoError(t, err) result := string(data) @@ -53,9 +70,10 @@ func handler(c fiber.Ctx) error { func Test_MigrateSessionRelease_AlreadyHasDefer(t *testing.T) { t.Parallel() + var err error + var data []byte - dir, err := os.MkdirTemp("", "msessionrelease") - require.NoError(t, err) + dir := setupTestModule(t) defer func() { require.NoError(t, os.RemoveAll(dir)) }() content := `package main @@ -72,7 +90,7 @@ func handler(c fiber.Ctx) error { return err } defer sess.Release() - + sess.Set("key", "value") return sess.Save() } @@ -85,7 +103,7 @@ func handler(c fiber.Ctx) error { err = MigrateSessionRelease(cmd, dir, nil, nil) require.NoError(t, err) - data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 require.NoError(t, err) result := string(data) @@ -97,9 +115,10 @@ func handler(c fiber.Ctx) error { func Test_MigrateSessionRelease_GetByID(t *testing.T) { t.Parallel() + var err error + var data []byte - dir, err := os.MkdirTemp("", "msessionrelease") - require.NoError(t, err) + dir := setupTestModule(t) defer func() { require.NoError(t, os.RemoveAll(dir)) }() content := `package main @@ -115,7 +134,7 @@ func backgroundTask(sessionID string) { if err != nil { return } - + sess.Set("last_task", "value") sess.Save() } @@ -128,7 +147,7 @@ func backgroundTask(sessionID string) { err = MigrateSessionRelease(cmd, dir, nil, nil) require.NoError(t, err) - data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 require.NoError(t, err) result := string(data) @@ -137,9 +156,10 @@ func backgroundTask(sessionID string) { func Test_MigrateSessionRelease_MultilineErrorCheck(t *testing.T) { t.Parallel() + var err error + var data []byte - dir, err := os.MkdirTemp("", "msessionrelease") - require.NoError(t, err) + dir := setupTestModule(t) defer func() { require.NoError(t, os.RemoveAll(dir)) }() content := `package main @@ -156,7 +176,7 @@ func handler(c fiber.Ctx) error { c.Status(500) return err } - + sess.Set("key", "value") return sess.Save() } @@ -169,14 +189,652 @@ func handler(c fiber.Ctx) error { err = MigrateSessionRelease(cmd, dir, nil, nil) require.NoError(t, err) - data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 require.NoError(t, err) result := string(data) assert.Contains(t, result, "defer sess.Release() // Important: Manual cleanup required") - // Verify defer comes after the error block + // Verify defer comes after the error block (find the "return err" followed by "}") + errorBlockPattern := "return err\n }" + errorBlockEnd := strings.Index(result, errorBlockPattern) + len(errorBlockPattern) deferIdx := strings.Index(result, "defer sess.Release()") - errorBlockEnd := strings.Index(result, "}") assert.Greater(t, deferIdx, errorBlockEnd, "defer should come after error block") } + +func Test_MigrateSessionRelease_MiddlewarePattern_NoRelease(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Middleware pattern - should NOT get Release() call + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" +) + +func main() { + app := fiber.New() + + store := session.NewStore() + sessionMiddleware := session.NewMiddleware(store) + + app.Use(sessionMiddleware) + + app.Get("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + sess.Set("key", "value") + return sess.Save() + }) + + app.Listen(":3000") +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + // Should NOT add defer sess.Release() for middleware pattern + assert.NotContains(t, result, "defer sess.Release()") +} + +func Test_MigrateSessionRelease_OtherGetMethods_NoRelease(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Various Get/GetByID methods that are NOT session stores + content := `package main + +import ( + "github.com/gofiber/fiber/v3" +) + +func handler(c fiber.Ctx) error { + // CSRF session - should NOT get Release() + session := c.Locals("session") + if session != nil { + // use session + } + + // Ent GetX - should NOT get Release() + obj, err := client.Book.GetX(ctx, id) + if err != nil { + return err + } + + // Generic Get - should NOT get Release() + data, err := cache.Get(key) + if err != nil { + return err + } + + return c.JSON(fiber.Map{ + "obj": obj, + "data": data, + }) +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + // Should NOT add defer Release() for non-session Get methods + assert.NotContains(t, result, "defer obj.Release()") + assert.NotContains(t, result, "defer data.Release()") + assert.NotContains(t, result, "defer session.Release()") +} + +func Test_MigrateSessionRelease_SessionStoreVariableName(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Test various store variable naming patterns + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" +) + +func handler(c fiber.Ctx) error { + // Common store variable names + store := session.NewStore() + sess, err := store.Get(c) + if err != nil { + return err + } + + sessionStore := session.NewStore() + sess2, err2 := sessionStore.Get(c) + if err2 != nil { + return err2 + } + + myStore := session.NewStore() + sess3, err3 := myStore.Get(c) + if err3 != nil { + return err3 + } + + return c.SendStatus(200) +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + // Should add Release() for all store variable patterns + assert.Equal(t, 3, strings.Count(result, "defer sess"), "Should add defer for all 3 store.Get() calls") + assert.Contains(t, result, "defer sess.Release()") + assert.Contains(t, result, "defer sess2.Release()") + assert.Contains(t, result, "defer sess3.Release()") +} + +func Test_MigrateSessionRelease_V2Import_NoRelease(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // v2 session import - should NOT get Release() since migration only processes v3 imports + // This migration runs AFTER MigrateContribPackages which changes v2→v3 imports + // So if we encounter v2 imports, we skip them (they haven't been migrated yet) + content := `package main + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/session" +) + +func handler(c *fiber.Ctx) error { + store := session.NewStore() + sess, err := store.Get(c) + if err != nil { + return err + } + + sess.Set("key", "value") + return sess.Save() +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + // Should NOT add Release() for v2 imports (migration only processes v3) + assert.NotContains(t, result, "defer sess.Release()") + assert.Equal(t, content, result, "File should remain unchanged for v2 imports") +} + +// Test_MigrateSessionRelease_CSRFWithSession tests real-world code from gofiber/recipes +// https://github.com/gofiber/recipes/blob/master/csrf-with-session/main.go +func Test_MigrateSessionRelease_CSRFWithSession(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Simplified version of csrf-with-session from recipes + // This has store.Get() which SHOULD get Release() added + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" +) + +func main() { + app := fiber.New() + + store := session.NewStore() + + app.Post("/login", func(c fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Reset(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + sess.Set("loggedIn", true) + if err := sess.Save(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.Redirect("/protected") + }) + + app.Get("/logout", func(c fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Destroy(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.Redirect("/") + }) + + app.Listen(":3000") +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // Should add Release() for both store.Get() calls + assert.Equal(t, 2, strings.Count(result, "defer sess.Release()"), "Should add defer for both store.Get() calls") + + // Verify the Release() calls are placed correctly + assert.Contains(t, result, "defer sess.Release() // Important: Manual cleanup required") +} + +// Test_MigrateSessionRelease_EntMySQL tests real-world code from gofiber/recipes +// https://github.com/gofiber/recipes/blob/master/ent-mysql/ent/client.go +func Test_MigrateSessionRelease_EntMySQL(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Simplified version of ent-mysql client.go from recipes + // This has NO session imports, so should NOT get any Release() calls + content := `// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "log" + + "ent-mysql/ent/migrate" + "ent-mysql/ent/book" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" +) + +// Client is the client that holds all ent builders. +type Client struct { + config + Schema *migrate.Schema + Book *BookClient +} + +// NewClient creates a new client configured with the given options. +func NewClient(opts ...Option) *Client { + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.options(opts...) + client := &Client{config: cfg} + client.init() + return client +} + +func (c *Client) init() { + c.Schema = migrate.NewSchema(c.driver) + c.Book = NewBookClient(c.config) +} + +// BookClient is a client for the Book schema. +type BookClient struct { + config +} + +// NewBookClient returns a client for the Book from the given config. +func NewBookClient(c config) *BookClient { + return &BookClient{config: c} +} + +// Get returns a Book entity by its id. +func (c *BookClient) Get(ctx context.Context, id int) (*Book, error) { + return c.Query().Where(book.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *BookClient) GetX(ctx context.Context, id int) *Book { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} +` + + err = os.WriteFile(filepath.Join(dir, "client.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "client.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // Should NOT add any Release() calls since there's no session import + assert.NotContains(t, result, "defer") + assert.NotContains(t, result, "Release()") + assert.Equal(t, content, result, "File should remain unchanged") +} + +// Test_MigrateSessionRelease_NoErrorCheck tests when error is ignored or not checked. +// Note: sess.Release() has a nil check, so it's safe to defer even if store.Get() returns nil. +func Test_MigrateSessionRelease_NoErrorCheck(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" +) + +func handler1(c fiber.Ctx) error { + store := session.NewStore() + sess, _ := store.Get(c) + + sess.Set("key", "value") + return sess.Save() +} + +func handler2(c fiber.Ctx) error { + store := session.NewStore() + sess, err := store.Get(c) + // No error check! + + sess.Set("key", "value") + return sess.Save() +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // Should add Release() even without error checking + // This is safe because sess.Release() has a nil check and returns early if nil + assert.Equal(t, 2, strings.Count(result, "defer sess.Release()"), "Should add defer for both store.Get() calls even without error checks") + assert.Contains(t, result, "defer sess.Release() // Important: Manual cleanup required") +} + +// Test_MigrateSessionRelease_OtherPackagesWithNew tests that packages with +// New, NewStore, Get, GetByID methods don't trigger false positives +func Test_MigrateSessionRelease_OtherPackagesWithNew(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Various packages with similar method names but NOT session + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/some/cache" + "github.com/other/database" +) + +func handler(c fiber.Ctx) error { + // Cache with NewStore and Get - should NOT add Release + cacheStore := cache.NewStore() + data, err := cacheStore.Get("key") + if err != nil { + return err + } + + // Database with New and GetByID - should NOT add Release + db := database.New() + record, err := db.GetByID(context.Background(), "123") + if err != nil { + return err + } + + // Generic object with Get - should NOT add Release + obj := myObject.New() + value, err := obj.Get(c) + if err != nil { + return err + } + + return c.JSON(fiber.Map{ + "data": data, + "record": record, + "value": value, + }) +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // Should NOT add any Release() calls + assert.NotContains(t, result, "defer data.Release()") + assert.NotContains(t, result, "defer record.Release()") + assert.NotContains(t, result, "defer value.Release()") + assert.Equal(t, content, result, "File should remain unchanged") +} + +// Test_MigrateSessionRelease_SameVarNameDifferentFunctions tests the critical edge case +// where the same variable name "store" is used in different functions for different types +func Test_MigrateSessionRelease_SameVarNameDifferentFunctions(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // CRITICAL: "store" variable reused in different contexts + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" + "github.com/some/cache" +) + +func sessionHandler(c fiber.Ctx) error { + // This is a SESSION store - SHOULD add Release + store := session.NewStore() + sess, err := store.Get(c) + if err != nil { + return err + } + + sess.Set("key", "value") + return sess.Save() +} + +func cacheHandler(c fiber.Ctx) error { + // This is a CACHE store - should NOT add Release + store := cache.NewStore() + data, err := store.Get("key") + if err != nil { + return err + } + + return c.SendString(data) +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // CRITICAL: Should only add Release() in sessionHandler, NOT in cacheHandler + assert.Equal(t, 1, strings.Count(result, "defer sess.Release()"), "Should only add defer for session store, not cache store") + assert.NotContains(t, result, "defer data.Release()", "Should NOT add Release() for cache store") + + // Verify it was added in the right function + lines := strings.Split(result, "\n") + inSessionHandler := false + for _, line := range lines { + if strings.Contains(line, "func sessionHandler") { + inSessionHandler = true + } else if strings.Contains(line, "func cacheHandler") { + inSessionHandler = false + } + + if strings.Contains(line, "defer sess.Release()") { + assert.True(t, inSessionHandler, "defer sess.Release() should only be in sessionHandler") + } + if strings.Contains(line, "defer data.Release()") { + t.Error("Should NOT add defer data.Release() for cache store") + } + } +} + +// Test_MigrateSessionRelease_AliasedImport tests that custom session package aliases work correctly. +// This is critical because the scope verification needs to match the actual alias used. +func Test_MigrateSessionRelease_AliasedImport(t *testing.T) { + t.Parallel() + var err error + var data []byte + + dir := setupTestModule(t) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // Test with custom alias "sess" for session package + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + sess "github.com/gofiber/fiber/v3/middleware/session" +) + +func handler(c fiber.Ctx) error { + store := sess.NewStore() + session, err := store.Get(c) + if err != nil { + return err + } + + session.Set("key", "value") + return session.Save() +} + +func backgroundTask(sessionID string) { + store := sess.NewStore() + sess, err := store.GetByID(context.Background(), sessionID) + if err != nil { + return + } + + sess.Set("last_task", "value") + sess.Save() +} +` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err = os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + + // Should add Release() for both Get() and GetByID() with aliased import + assert.Contains(t, result, "defer session.Release() // Important: Manual cleanup required", "Should add Release() for store.Get() with aliased import") + assert.Contains(t, result, "defer sess.Release() // Important: Manual cleanup required", "Should add Release() for store.GetByID() with aliased import") + assert.Equal(t, 2, strings.Count(result, "defer "), "Should add exactly 2 defer Release() calls") +}