diff --git a/tools/gocc/main.go b/tools/gocc/main.go index e443d15..0c54f5b 100644 --- a/tools/gocc/main.go +++ b/tools/gocc/main.go @@ -35,8 +35,7 @@ import ( ) // TODO: manual SSA to detect a unique name for the majic lock -var majicLockName = "optiLock" -var majicLockID = 0 +const optiLockName = "optiLock" var usesMap map[*ast.Ident]types.Object var typesMap map[ast.Expr]types.TypeAndValue @@ -49,13 +48,6 @@ var isSingleFile bool var profileProvided bool = false var hotFuncMap map[string]bool -// this map will mark which lock position and paths are in lambda function -// since it will have different namescope -var lockInLambdaFunc map[token.Pos]bool - -// this map will keep the all blkstmt position whose parent is funclit -var blkstmtMap map[token.Pos]bool - // for any lock, lockInfo stores the positions where it locks and (defer) unlocks // isValue indicates whether this lock is lock object or pointer in source code type lockInfo struct { @@ -73,9 +65,6 @@ var lockUnlockSameBB = 0 var lockDeferUnlockSameBB = 0 var lockUnlockPairDifferentBB = 0 var lockDeferUnlockPairDifferentBB = 0 -var unsafeLock = 0 -var unpaired = 0 -var paired = 0 var mPkg map[string]int var mBBSafety map[int]bool @@ -85,9 +74,6 @@ var mapFuncSafety map[*ssa.Function]bool var allLockVals map[ssa.Value]bool var allUnlockVals map[ssa.Value]bool var lockAliasMap map[ssa.Value][]ssa.Value -var pkgName map[string]bool -var tokenToName map[token.Pos]string -var pathToEndNodePos map[ast.Node]token.Pos var writeOutput bool = true var mCallGraph *callgraph.Graph @@ -96,9 +82,7 @@ var outputPath string // generate the name of the packages that violates HTM func initBlockList() []string { - // return []string{"sync", "os", "io", "fmt", "runtime"} - // remove sync to allow nested lock - return []string{"os", "io", "fmt", "runtime"} + return []string{"sync", "os", "io", "fmt", "runtime"} } func isMutexValue(s string) bool { @@ -126,22 +110,6 @@ func isLockPointer(rcv ssa.Value) bool { } else { // Pointer type } - case *ssa.Global: - allcTmp := rcv.(*ssa.Global) - if isMutexValue(allcTmp.Type().String()) { - // is a value - isValue = true - } else { - // is a pointer - } - case *ssa.FreeVar: - allcTmp := rcv.(*ssa.FreeVar) - if isMutexValue(allcTmp.Type().String()) { - // is a value - isValue = true - } else { - // is a pointer - } default: // Pointer type } @@ -150,26 +118,33 @@ func isLockPointer(rcv ssa.Value) bool { // count lock, unlock, and defer unlock number func countLockNumber(ssaF *ssa.Function, lockType string, lock string, unlock string) bool { - if ssaF != nil && ssaF.Blocks != nil { - for _, blk := range ssaF.Blocks { - for _, ins := range blk.Instrs { - if call, ok := ins.(*ssa.Call); ok { - if !call.Call.IsInvoke() && call.Call.StaticCallee() != nil { - calleeName := call.Call.StaticCallee().Name() - callRcv := call.Call.Value - if callRcv != nil && strings.Contains(callRcv.String(), lockType) && calleeName == lock { - allLockVals[call.Call.Args[0]] = true - } else if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { - allUnlockVals[call.Call.Args[0]] = true - } - } - } else if call, ok := ins.(*ssa.Defer); ok { - if call.Call.StaticCallee() != nil { - callRcv := call.Call.Value - if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { - allUnlockVals[call.Call.Args[0]] = true - } - } + if ssaF == nil || ssaF.Blocks == nil { + return localCheck(ssaF.Blocks) + } + + for _, blk := range ssaF.Blocks { + for _, ins := range blk.Instrs { + if call, ok := ins.(*ssa.Call); ok { + if call.Call.IsInvoke() || call.Call.StaticCallee() == nil { + continue + } + calleeName := call.Call.StaticCallee().Name() + callRcv := call.Call.Value + if callRcv != nil && strings.Contains(callRcv.String(), lockType) && calleeName == lock { + numLock++ + allLockVals[call.Call.Args[0]] = true + } else if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { + numUnlock++ + allUnlockVals[call.Call.Args[0]] = true + } + } else if call, ok := ins.(*ssa.Defer); ok { + if call.Call.StaticCallee() == nil { + continue + } + callRcv := call.Call.Value + if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { + numDeferUnlock++ + allUnlockVals[call.Call.Args[0]] = true } } } @@ -214,9 +189,6 @@ func checkBasicBlockInCriticalSection(blks []*ssa.BasicBlock) bool { func checkBlockList(rcv *ssa.Value) bool { for _, name := range blockList { if strings.Contains((*rcv).String(), name) { - if *rcv != nil { - fmt.Println((*rcv).String()) - } return true } } @@ -227,17 +199,21 @@ func checkBlockList(rcv *ssa.Value) bool { func localCheck(blks []*ssa.BasicBlock) bool { for _, blk := range blks { for _, ins := range blk.Instrs { - if call, ok := ins.(*ssa.Call); ok { - callRcv := call.Call.Value - if checkBlockList(&callRcv) { - return false - } - } else if call, ok := ins.(*ssa.Defer); ok { - callRcv := call.Call.Value - if checkBlockList(&callRcv) { - return false - } + var callRcv ssa.Value + switch ins.(type) { + case *ssa.Call: + call, _ := ins.(*ssa.Call) + callRcv = call.Call.Value + case *ssa.Defer: + call, _ := ins.(*ssa.Defer) + callRcv = call.Call.Value + default: + continue } + if checkBlockList(&callRcv) { + return false + } + } } return true @@ -245,42 +221,34 @@ func localCheck(blks []*ssa.BasicBlock) bool { // check if single insturction is violating HTM or not func checkInst(ins ssa.Instruction) bool { - // no go func() in the critical section - if _, ok := ins.(*ssa.Go); ok { + var callRcv ssa.Value + + switch ins.(type) { + case *ssa.Call: + call, _ := ins.(*ssa.Call) + callRcv = call.Call.Value + case *ssa.Defer: + call, _ := ins.(*ssa.Defer) + callRcv = call.Call.Value + + default: + return true + } + + if checkBlockList(&callRcv) { return false } - if call, ok := ins.(*ssa.Call); ok { - callRcv := call.Call.Value - if checkBlockList(&callRcv) { - return false - } - if mCallGraph != nil { - nodeInGraph, ok := mCallGraph.Nodes[curNode] - if ok { - for _, edge := range nodeInGraph.Out { - if edge.Site.Common().Value.String() == callRcv.String() { - if mapFuncSafety[edge.Callee.Func] == false { - return false - } - } - } - } - } - } else if call, ok := ins.(*ssa.Defer); ok { - callRcv := call.Call.Value - if checkBlockList(&callRcv) { - return false - } - if mCallGraph != nil { - nodeInGraph, ok := mCallGraph.Nodes[curNode] - if ok { - for _, edge := range nodeInGraph.Out { - if edge.Site.Common().Value.String() == callRcv.String() { - if mapFuncSafety[edge.Callee.Func] == false { - return false - } - } - } + if mCallGraph == nil { + return true + } + nodeInGraph, ok := mCallGraph.Nodes[curNode] + if !ok { + return true + } + for _, edge := range nodeInGraph.Out { + if edge.Site.Common().Value.String() == callRcv.String() { + if mapFuncSafety[edge.Callee.Func] == false { + return false } } } @@ -411,14 +379,6 @@ func normalizeFunctionName(name string) string { return fName } -// return if the ssa function is in the given input package so that we can transform -func inFile(ssaF *ssa.Function) bool { - if _, ok := pkgName[ssaF.Pkg.Pkg.Name()]; ok { - return true - } - return false -} - // Find all paired locks in given function. // return a set of lockInfo for rewrite // currently it checks 4 patterns: @@ -430,10 +390,6 @@ func lockAnalysis(ssaF *ssa.Function, lockType string, lock string, unlock strin // lockMap to return from this function var lockMap = map[string]lockInfo{} - if ssaF == nil || inFile(ssaF) == false { - return lockMap - } - // store all lock/unlocks instructions in sets setLock := make(map[*ssa.CallCommon]bool) setUnlock := make(map[*ssa.CallCommon]bool) @@ -449,242 +405,191 @@ func lockAnalysis(ssaF *ssa.Function, lockType string, lock string, unlock strin setUnlockIndex := make(map[*ssa.CallCommon]int) setDeferUnlockIndex := make(map[*ssa.CallCommon]int) - if ssaF != nil && ssaF.Blocks != nil { - // detect lock/unlock and its declaration - for _, blk := range ssaF.Blocks { - for idx, ins := range blk.Instrs { - if call, ok := ins.(*ssa.Call); ok { - if !call.Call.IsInvoke() && call.Call.StaticCallee() != nil { - calleeName := call.Call.StaticCallee().Name() - callRcv := call.Call.Value - // lock/unlock detected - if callRcv != nil && strings.Contains(callRcv.String(), lockType) && calleeName == lock { - setLock[&call.Call] = true - mapLockToBB[&call.Call] = blk - setLockIndex[&call.Call] = idx - } else if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { - mapUnlockToBB[&call.Call] = blk - setUnlock[&call.Call] = true - setUnlockIndex[&call.Call] = idx - } - } - } else if call, ok := ins.(*ssa.Defer); ok { - if call.Call.StaticCallee() != nil { - callRcv := call.Call.Value - if callRcv != nil && strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { - mapDeferUnlockToBB[&call.Call] = blk - setDeferUnlock[&call.Call] = true - setDeferUnlockIndex[&call.Call] = idx - } - } - } - } - } + if ssaF == nil || ssaF.Blocks == nil { + return lockMap + } - // start the lock analysis - // lock and unlock in the same block, add instruction in between as critical section - for ul := range setUnlock { - // ulRcv := ul.Args[0].String() - for l := range setLock { - if setLockIndex[l] > setUnlockIndex[ul] { + // detect lock/unlock and its declaration + for _, blk := range ssaF.Blocks { + for idx, ins := range blk.Instrs { + if call, ok := ins.(*ssa.Call); ok { + if call.Call.IsInvoke() || call.Call.StaticCallee() == nil { continue } - lRcv := l.Args[0].String() - if isSameLock(l.Args[0], ul.Args[0]) && mapLockToBB[l] == mapUnlockToBB[ul] { - // check critical section - paired++ - if checkInstructionBetween(mapLockToBB[l], setLockIndex[l], setUnlockIndex[ul]) { - lockUnlockSameBB++ - setUnlock[ul] = false - setLock[l] = false - addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) - } else { - unsafeLock++ - } + calleeName := call.Call.StaticCallee().Name() + callRcv := call.Call.Value + // lock/unlock detected + if callRcv == nil { + continue } - } - } - - // lock and defer unlock in the same block - for ul := range setDeferUnlock { - // ulRcv := ul.Args[0].String() - for l := range setLock { - lRcv := l.Args[0].String() - if isSameLock(l.Args[0], ul.Args[0]) && mapLockToBB[l] == mapDeferUnlockToBB[ul] { - paired++ - // critical section is all reachable blocks from this block plus the remaining instruction after unlock - lstBlk := reachableBlks(mapLockToBB[l]) - blkInstNum := len(mapLockToBB[l].Instrs) - safetyInBB := false - // add this since defer unlock can be used before lock - if setLockIndex[l] < setDeferUnlockIndex[ul] { - safetyInBB = checkInstructionBetween(mapLockToBB[l], setLockIndex[l], setDeferUnlockIndex[ul]) && checkInstructionBetween(mapLockToBB[l], setDeferUnlockIndex[ul], blkInstNum) - } else { - safetyInBB = checkInstructionBetween(mapLockToBB[l], setLockIndex[l], blkInstNum) - } - if safetyInBB && checkBasicBlockInCriticalSection(lstBlk) { - lockDeferUnlockSameBB++ - setDeferUnlock[ul] = false - setLock[l] = false - addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) - } else { - unsafeLock++ - } + if strings.Contains(callRcv.String(), lockType) && calleeName == lock { + setLock[&call.Call] = true + mapLockToBB[&call.Call] = blk + setLockIndex[&call.Call] = idx + } else if strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { + mapUnlockToBB[&call.Call] = blk + setUnlock[&call.Call] = true + setUnlockIndex[&call.Call] = idx } - } - } - - postDomMap := postDom.PostDominators(ssaF) - // lock and and unlock in different bb, but dominates and post-dominates each other - for ul := range setUnlock { - // ulRcv := ul.Args[0].String() - for l := range setLock { - lRcv := l.Args[0].String() - lockBB, ok := mapLockToBB[l] - if ok == false { - break + } else if call, ok := ins.(*ssa.Defer); ok { + if call.Call.StaticCallee() == nil { + continue } - unlockBB, ok := mapUnlockToBB[ul] - if ok == false { - break + callRcv := call.Call.Value + if callRcv == nil { + continue } - if isSameLock(l.Args[0], ul.Args[0]) && lockBB.Dominates(unlockBB) { - paired++ - lkPD := postDomMap[mapLockToBB[l].Index] - if domContains(lkPD, mapUnlockToBB[ul].Index) { - // critical section is the bb between lock/unlock plus the remaining instruction in lock and unlock block - lkBlkInstNum := len(mapLockToBB[l].Instrs) - lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) - ulkBlkRemainInst := checkInstructionBetween(mapUnlockToBB[ul], -1, setUnlockIndex[ul]) - bbInBetween := basicBlockInBetween(mapLockToBB[l], mapUnlockToBB[ul]) - if lkBlkRemainInst && ulkBlkRemainInst && checkBasicBlockInCriticalSection(bbInBetween) { - lockUnlockPairDifferentBB++ - setUnlock[ul] = false - setLock[l] = false - addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) - } else { - unsafeLock++ - } - } + if strings.Contains(callRcv.String(), lockType) && call.Call.StaticCallee().Name() == unlock { + mapDeferUnlockToBB[&call.Call] = blk + setDeferUnlock[&call.Call] = true + setDeferUnlockIndex[&call.Call] = idx } } } + } - // lock and defer unlock in different bb, but dominates and post-dominate each other - for ul := range setDeferUnlock { - // ulRcv := ul.Args[0].String() - for l := range setLock { - lRcv := l.Args[0].String() - lockBB, ok := mapLockToBB[l] - if ok == false { - break - } - unlockBB, ok := mapUnlockToBB[ul] - if ok == false { - break - } - if isSameLock(l.Args[0], ul.Args[0]) && lockBB.Dominates(unlockBB) { - paired++ - lkPD := postDomMap[mapLockToBB[l].Index] - if domContains(lkPD, mapDeferUnlockToBB[ul].Index) { - // critical section is all block that can be reached by defer unlock and everything in between the lock and unlock - lkBlkInstNum := len(mapLockToBB[l].Instrs) - ulkBlkInstNumber := len(mapDeferUnlockToBB[ul].Instrs) - lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) - ulkBlkRemainInst := checkInstructionBetween(mapDeferUnlockToBB[ul], -1, setDeferUnlockIndex[ul]) && checkInstructionBetween(mapDeferUnlockToBB[ul], setDeferUnlockIndex[ul], ulkBlkInstNumber) - bbInBetween := basicBlockInBetween(mapLockToBB[l], mapUnlockToBB[ul]) - bbReachable := reachableBlks(mapDeferUnlockToBB[ul]) - if lkBlkRemainInst && ulkBlkRemainInst && checkBasicBlockInCriticalSection(bbReachable) && checkBasicBlockInCriticalSection(bbInBetween) { - lockDeferUnlockPairDifferentBB++ - setDeferUnlock[ul] = false - setLock[l] = false - addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) - } else { - unsafeLock++ - } - } - } else if isSameLock(l.Args[0], ul.Args[0]) && unlockBB.Dominates(lockBB) { - paired++ - // add this since defer unlock can happen before lock - ulkPD := postDomMap[mapLockToBB[ul].Index] - if domContains(ulkPD, mapDeferUnlockToBB[l].Index) { - // critical section is all block that can be reached by defer unlock and everything in between the lock and unlock - lkBlkInstNum := len(mapLockToBB[l].Instrs) - lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) - bbReachable := reachableBlks(mapLockToBB[l]) - if lkBlkRemainInst && checkBasicBlockInCriticalSection(bbReachable) { - lockDeferUnlockPairDifferentBB++ - setDeferUnlock[ul] = false - setLock[l] = false - addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) - } else { - unsafeLock++ - } - } - } + // start the lock analysis + // lock and unlock in the same block, add instruction in between as critical section + for ul := range setUnlock { + // ulRcv := ul.Args[0].String() + for l := range setLock { + lRcv := l.Args[0].String() + if !isSameLock(l.Args[0], ul.Args[0]) || (mapLockToBB[l] != mapUnlockToBB[ul]) { + continue + } + if !checkInstructionBetween(mapLockToBB[l], setLockIndex[l], setUnlockIndex[ul]) { + continue } + // check critical section + lockUnlockSameBB++ + setUnlock[ul] = false + setLock[l] = false + addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) } } - for _, v := range setLock { - if v == true { - unpaired++ - } - } - for _, v := range setUnlock { - if v == true { - unpaired++ - } - } - for _, v := range setDeferUnlock { - if v == true { - unpaired++ + // lock and defer unlock in the same block + for ul := range setDeferUnlock { + // ulRcv := ul.Args[0].String() + for l := range setLock { + lRcv := l.Args[0].String() + if !isSameLock(l.Args[0], ul.Args[0]) || (mapLockToBB[l] != mapDeferUnlockToBB[ul]) { + continue + } + + // critical section is all reachable blocks from this block plus the remaining instruction after unlock + lstBlk := reachableBlks(mapLockToBB[l]) + blkInstNum := len(mapLockToBB[l].Instrs) + safetyInBB := false + // add this since defer unlock can be used before lock + if setLockIndex[l] < setDeferUnlockIndex[ul] { + safetyInBB = checkInstructionBetween(mapLockToBB[l], setLockIndex[l], setDeferUnlockIndex[ul]) && checkInstructionBetween(mapLockToBB[l], setDeferUnlockIndex[ul], blkInstNum) + } else { + safetyInBB = checkInstructionBetween(mapLockToBB[l], setLockIndex[l], blkInstNum) + } + if !safetyInBB || !checkBasicBlockInCriticalSection(lstBlk) { + continue + } + lockDeferUnlockSameBB++ + setDeferUnlock[ul] = false + setLock[l] = false + addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) } } - numLock += len(setLock) - numUnlock += len(setUnlock) - numDeferUnlock += len(setDeferUnlock) - return lockMap -} - -func pathContains(replacePath [][]ast.Node, curPos token.Pos) bool { - for _, path := range replacePath { - for _, node := range path { - if node.Pos() == curPos { - return true + postDomMap := postDom.PostDominators(ssaF) + // lock and and unlock in different bb, but dominates and post-dominates each other + for ul := range setUnlock { + // ulRcv := ul.Args[0].String() + for l := range setLock { + lRcv := l.Args[0].String() + lockBB, ok := mapLockToBB[l] + if ok == false { + break + } + unlockBB, ok := mapUnlockToBB[ul] + if ok == false { + break + } + if !isSameLock(l.Args[0], ul.Args[0]) || (!lockBB.Dominates(unlockBB)) { + continue + } + lkPD := postDomMap[mapLockToBB[l].Index] + if !domContains(lkPD, mapUnlockToBB[ul].Index) { + continue + } + // critical section is the bb between lock/unlock plus the remaining instruction in lock and unlock block + lkBlkInstNum := len(mapLockToBB[l].Instrs) + lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) + ulkBlkRemainInst := checkInstructionBetween(mapUnlockToBB[ul], -1, setUnlockIndex[ul]) + bbInBetween := basicBlockInBetween(mapLockToBB[l], mapUnlockToBB[ul]) + if lkBlkRemainInst && ulkBlkRemainInst && checkBasicBlockInCriticalSection(bbInBetween) { + lockUnlockPairDifferentBB++ + setUnlock[ul] = false + setLock[l] = false + addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) } } } - return false -} -func getPosName(path [][]ast.Node, curPos token.Pos, posToID map[token.Pos]string) string { - for _, p := range path { - for _, node := range p { - if node.Pos() == curPos { - for _, n := range p { - if str, ok := posToID[pathToEndNodePos[n]]; ok { - return str - } + // lock and defer unlock in different bb, but dominates and post-dominate each other + for ul := range setDeferUnlock { + // ulRcv := ul.Args[0].String() + for l := range setLock { + lRcv := l.Args[0].String() + lockBB, ok := mapLockToBB[l] + if ok == false { + break + } + unlockBB, ok := mapUnlockToBB[ul] + if ok == false { + break + } + if isSameLock(l.Args[0], ul.Args[0]) && lockBB.Dominates(unlockBB) { + lkPD := postDomMap[mapLockToBB[l].Index] + if !domContains(lkPD, mapDeferUnlockToBB[ul].Index) { + continue + } + // critical section is all block that can be reached by defer unlock and everything in between the lock and unlock + lkBlkInstNum := len(mapLockToBB[l].Instrs) + ulkBlkInstNumber := len(mapDeferUnlockToBB[ul].Instrs) + lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) + ulkBlkRemainInst := checkInstructionBetween(mapDeferUnlockToBB[ul], -1, setDeferUnlockIndex[ul]) && checkInstructionBetween(mapDeferUnlockToBB[ul], setDeferUnlockIndex[ul], ulkBlkInstNumber) + bbInBetween := basicBlockInBetween(mapLockToBB[l], mapUnlockToBB[ul]) + bbReachable := reachableBlks(mapDeferUnlockToBB[ul]) + if lkBlkRemainInst && ulkBlkRemainInst && checkBasicBlockInCriticalSection(bbReachable) && checkBasicBlockInCriticalSection(bbInBetween) { + lockDeferUnlockPairDifferentBB++ + setDeferUnlock[ul] = false + setLock[l] = false + addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) + } + } else if isSameLock(l.Args[0], ul.Args[0]) && unlockBB.Dominates(lockBB) { + // add this since defer unlock can happen before lock + ulkPD := postDomMap[mapLockToBB[ul].Index] + if !domContains(ulkPD, mapDeferUnlockToBB[l].Index) { + continue + } + // critical section is all block that can be reached by defer unlock and everything in between the lock and unlock + lkBlkInstNum := len(mapLockToBB[l].Instrs) + lkBlkRemainInst := checkInstructionBetween(mapLockToBB[l], setLockIndex[l], lkBlkInstNum) + bbReachable := reachableBlks(mapLockToBB[l]) + if lkBlkRemainInst && checkBasicBlockInCriticalSection(bbReachable) { + lockDeferUnlockPairDifferentBB++ + setDeferUnlock[ul] = false + setLock[l] = false + addLockPairToLockInfo(lockMap, lRcv, l, ul, ssaF) } } } } - return "" + return lockMap } -func reversePathContains(replacePath [][]ast.Node, curPos token.Pos) bool { +func pathContains(replacePath [][]ast.Node, curPos token.Pos) bool { for _, path := range replacePath { - _PATHLOOP: - for i := 0; i < len(path); i++ { - switch n := path[i].(type) { - case *ast.BlockStmt: - if n.Pos() == curPos { - return true - } else if _, ok := blkstmtMap[n.Pos()]; ok { - break _PATHLOOP - } + for _, node := range path { + if node.Pos() == curPos { + return true } } } @@ -701,34 +606,16 @@ func singlePathContains(singlePath []ast.Node, curPos token.Pos) bool { } // adds context variable definition at the beginning of the function's statement list -func addContextInitStmt(stmtsList *[]ast.Stmt, sigPos token.Pos, count int) { - for i := 0; i < count; i++ { - newStmt := ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(majicLockName + strconv.Itoa(i))}, - TokPos: sigPos, // use concrete position to avoid being split by a comment leading to syntax error - Tok: token.DEFINE, - Rhs: []ast.Expr{ast.NewIdent("rtm.OptiLock{}")}} - var newStmtsList []ast.Stmt - newStmtsList = append(newStmtsList, &newStmt) - newStmtsList = append(newStmtsList, (*stmtsList)...) - *stmtsList = newStmtsList - } -} - -func collectBlkstmt(f ast.Node, pkg *packages.Package) { - postFunc := func(c *astutil.Cursor) bool { - node := c.Node() - switch node.(type) { - case *ast.BlockStmt: - { - if _, ok := c.Parent().(*ast.FuncLit); ok && c.Name() == "Body" { - blkstmtMap[c.Node().Pos()] = true - } - } - } - return true - } - astutil.Apply(f, nil, postFunc) +func addContextInitStmt(stmtsList *[]ast.Stmt, sigPos token.Pos) { + newStmt := ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(optiLockName)}, + TokPos: sigPos, // use concrete position to avoid being split by a comment leading to syntax error + Tok: token.DEFINE, + Rhs: []ast.Expr{ast.NewIdent("rtm.OptiLock{}")}} + var newStmtsList []ast.Stmt + newStmtsList = append(newStmtsList, &newStmt) + newStmtsList = append(newStmtsList, (*stmtsList)...) + *stmtsList = newStmtsList } // given the function f from the pkg, replace all the valid lock/unlock with htm library @@ -743,11 +630,9 @@ func collectBlkstmt(f ast.Node, pkg *packages.Package) { // foo := Foo{} // foo.Lock() // 4. Promoted field pointer -func rewriteAST(f ast.Node, pkg *packages.Package, replacePathRWMutex, insertPathRWMutex, replacePathMutex, insertPathMutex, lambdaPath, normalPath *[][]ast.Node, posToID map[token.Pos]string) ast.Node { +func rewriteAST(f ast.Node, pkg *packages.Package, replacePathRWMutex, insertPathRWMutex, replacePathMutex, insertPathMutex *[][]ast.Node) ast.Node { fmt.Println(" Rewriting field accesses in the file...") blkmap := make(map[*ast.BlockStmt]bool) - addImport := false - optilockNumber := len(posToID) / 2 postFunc := func(c *astutil.Cursor) bool { node := c.Node() switch n := node.(type) { @@ -757,379 +642,402 @@ func rewriteAST(f ast.Node, pkg *packages.Package, replacePathRWMutex, insertPat switch { case pathContains(*replacePathMutex, n.Pos()): { - if se, ok := n.Fun.(*ast.SelectorExpr); ok { - if se.Sel.Name == "Lock" || se.Sel.Name == "Unlock" { - lockType := typesMap[se.X].Type.String() - // branch 1: receiver is lock pointer - if strings.Contains(lockType, "*sync.Mutex") { - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{se.X}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 4: receiver is promoted field pointer - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newSel := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("Mutex"), - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newSel}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } + se, ok := n.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + if !isLockUnlock(se.Sel.Name) { + return true + } + + lockType := typesMap[se.X].Type.String() + // branch 1: receiver is lock pointer + if lockType == "*sync.Mutex" { + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{se.X}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, } + + c.Replace(call) + return true + } + // else branch 4: receiver is promoted field pointer + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newSel := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("Mutex"), } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newSel}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + c.Replace(call) + return true } case pathContains(*insertPathMutex, n.Pos()): { // when the lock is a sync.Mutex - if se, ok := n.Fun.(*ast.SelectorExpr); ok { - if se.Sel.Name == "Lock" || se.Sel.Name == "Unlock" { - lockType := typesMap[se.X].Type.String() - if lockType == "sync.Mutex" { - // branch 2: receiver is lock object, need to take its address - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: se.X, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 3: receiver is some promoted field object - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newSel := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("Mutex"), - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: newSel, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } + se, ok := n.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + if !isLockUnlock(se.Sel.Name) { + return true + } + + lockType := typesMap[se.X].Type.String() + if lockType == "sync.Mutex" { + // branch 2: receiver is lock object, need to take its address + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newX := &ast.UnaryExpr{ + Op: token.AND, + X: se.X, } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + + c.Replace(call) + return true + } + // else branch 3: receiver is some promoted field object + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newSel := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("Mutex"), + } + newX := &ast.UnaryExpr{ + Op: token.AND, + X: newSel, } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + c.Replace(call) + return true } case pathContains(*replacePathRWMutex, n.Pos()): { + + // when the lock is a sync.Mutex + se, ok := n.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + // when the lock is *sync.RWMutex - if se, ok := n.Fun.(*ast.SelectorExpr); ok { - // change RWMutex.Lock()/Unlock() to majicLock.WLock()/WUnlock() - if se.Sel.Name == "Lock" || se.Sel.Name == "Unlock" { - lockType := typesMap[se.X].Type.String() - // branch 1: receiver is a lock pointer - if lockType == "*sync.RWMutex" { - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: &ast.Ident{ - Name: "W" + se.Sel.Name, - NamePos: se.Sel.NamePos, - }, - } - - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{se.X}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 3: promoted field pointer - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: &ast.Ident{ - Name: "W" + se.Sel.Name, - NamePos: se.Sel.NamePos, - }, - } - newX := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("RWMutex"), - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) + // change RWMutex.Lock()/Unlock() to majicLock.WLock()/WUnlock() + + if !isLockUnlockRLockRUnlock(se.Sel.Name) { + return true + } + + if isLockUnlock(se.Sel.Name) { + lockType := typesMap[se.X].Type.String() + // branch 1: receiver is a lock pointer + if lockType == "*sync.RWMutex" { + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: &ast.Ident{ + Name: "W" + se.Sel.Name, + NamePos: se.Sel.NamePos, + }, } - } else if se.Sel.Name == "RLock" || se.Sel.Name == "RUnlock" { - // this changes RLock/RUnlock of RWMutex - lockType := typesMap[se.X].Type.String() - - if lockType == "*sync.RWMutex" { - // branch 1: receiver is lock pointer - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: &ast.Ident{ - Name: se.Sel.Name, - NamePos: se.Sel.NamePos, - }, - } - - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{se.X}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 4: lock is called on promoted field pointer - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*replacePathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newX := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("RWMutex"), - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) + + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{se.X}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, } + + c.Replace(call) + return true + } + // else branch 3: promoted field pointer + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: &ast.Ident{ + Name: "W" + se.Sel.Name, + NamePos: se.Sel.NamePos, + }, + } + newX := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("RWMutex"), } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + + c.Replace(call) + return true } + // else se.Sel.Name == "RLock" || se.Sel.Name == "RUnlock" { + // this changes RLock/RUnlock of RWMutex + lockType := typesMap[se.X].Type.String() + + if lockType == "*sync.RWMutex" { + // branch 1: receiver is lock pointer + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: &ast.Ident{ + Name: se.Sel.Name, + NamePos: se.Sel.NamePos, + }, + } + + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{se.X}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + c.Replace(call) + return true + } + // else // branch 4: lock is called on promoted field pointer + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newX := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("RWMutex"), + } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + c.Replace(call) + return true } case pathContains(*insertPathRWMutex, n.Pos()): { + // when the lock is a sync.Mutex + se, ok := n.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if !isLockUnlockRLockRUnlock(se.Sel.Name) { + return true + } + // when the lock is sync.RWMutex - if se, ok := n.Fun.(*ast.SelectorExpr); ok { - // change RWMutex.Lock()/Unlock() to majicLock.WLock()/WUnlock() - if se.Sel.Name == "Lock" || se.Sel.Name == "Unlock" { - lockType := typesMap[se.X].Type.String() - if lockType == "sync.RWMutex" { - // branch 2: receiver is a lock value - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: &ast.Ident{ - Name: "W" + se.Sel.Name, - NamePos: se.Sel.NamePos, - }, - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: se.X, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 3: promoted field object - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: &ast.Ident{ - Name: "W" + se.Sel.Name, - NamePos: se.Sel.NamePos, - }, - } - newSel := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("RWMutex"), - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: newSel, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) + // change RWMutex.Lock()/Unlock() to majicLock.WLock()/WUnlock() + if isLockUnlock(se.Sel.Name) { + lockType := typesMap[se.X].Type.String() + if lockType == "sync.RWMutex" { + // branch 2: receiver is a lock value + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: &ast.Ident{ + Name: "W" + se.Sel.Name, + NamePos: se.Sel.NamePos, + }, } - } else if se.Sel.Name == "RLock" || se.Sel.Name == "RUnlock" { - // this changes RLock/RUnlock of RWMutex - lockType := typesMap[se.X].Type.String() - if lockType == "sync.RWMutex" { - // branch 2, lock is called on lock value - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: se.X, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) - } else { - // branch 3: lock is called on promoted field value - fun := &ast.SelectorExpr{ - X: &ast.Ident{ - Name: majicLockName + getPosName(*insertPathRWMutex, n.Pos(), posToID), - NamePos: se.X.Pos(), - }, - Sel: se.Sel, - } - newSel := &ast.SelectorExpr{ - X: se.X, - Sel: ast.NewIdent("RWMutex"), - } - newX := &ast.UnaryExpr{ - Op: token.AND, - X: newSel, - } - call := &ast.CallExpr{ - Fun: fun, - Lparen: token.NoPos, - Args: []ast.Expr{newX}, - Ellipsis: token.NoPos, - Rparen: token.NoPos, - } - - c.Replace(call) + newX := &ast.UnaryExpr{ + Op: token.AND, + X: se.X, } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + + c.Replace(call) + return true } + // else branch 3: promoted field object + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: &ast.Ident{ + Name: "W" + se.Sel.Name, + NamePos: se.Sel.NamePos, + }, + } + newSel := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("RWMutex"), + } + newX := &ast.UnaryExpr{ + Op: token.AND, + X: newSel, + } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + + c.Replace(call) + return true + } + // else se.Sel.Name == "RLock" || se.Sel.Name == "RUnlock" { + // this changes RLock/RUnlock of RWMutex + lockType := typesMap[se.X].Type.String() + if lockType == "sync.RWMutex" { + // branch 2, lock is called on lock value + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newX := &ast.UnaryExpr{ + Op: token.AND, + X: se.X, + } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + + c.Replace(call) + return true } + // else branch 3: lock is called on promoted field value + fun := &ast.SelectorExpr{ + X: &ast.Ident{ + Name: optiLockName, + NamePos: se.X.Pos(), + }, + Sel: se.Sel, + } + newSel := &ast.SelectorExpr{ + X: se.X, + Sel: ast.NewIdent("RWMutex"), + } + newX := &ast.UnaryExpr{ + Op: token.AND, + X: newSel, + } + call := &ast.CallExpr{ + Fun: fun, + Lparen: token.NoPos, + Args: []ast.Expr{newX}, + Ellipsis: token.NoPos, + Rparen: token.NoPos, + } + c.Replace(call) + return true } } } case *ast.ImportSpec: { - if addImport == false { - newImport := &ast.ImportSpec{ - Doc: n.Doc, - Name: ast.NewIdent("rtm"), - Path: &ast.BasicLit{ - ValuePos: n.Path.ValuePos, - Kind: n.Path.Kind, - Value: strconv.Quote("github.com/lollllcat/GOCC/tools/gocc/rtmlib"), - }, - Comment: n.Comment, - EndPos: n.EndPos, - } - c.InsertAfter(newImport) - addImport = true + if n.Path.Value != "\"sync\"" { + return true + } + newImport := &ast.ImportSpec{ + Doc: n.Doc, + Name: ast.NewIdent("rtm"), + Path: &ast.BasicLit{ + ValuePos: n.Path.ValuePos, + Kind: n.Path.Kind, + Value: strconv.Quote("github.com/uber-research/GOCC/tools/gocc/rtmlib"), + }, + Comment: n.Comment, + EndPos: n.EndPos, } + c.InsertAfter(newImport) + return true } case *ast.BlockStmt: { // purpose of this part is to add the majicLock declarations to the funtions that we transform blkStmt := c.Node().(*ast.BlockStmt) // not added before - if _, ok := blkmap[blkStmt]; !ok { - if fd, ok := c.Parent().(*ast.FuncDecl); ok && c.Name() == "Body" { - // containLocks := pathContains(*replacePathMutex, n.Pos()) || pathContains(*replacePathRWMutex, n.Pos()) || pathContains(*insertPathRWMutex, n.Pos()) || pathContains(*insertPathMutex, n.Pos()) - containLocks := pathContains(*normalPath, n.Pos()) - if containLocks { - addContextInitStmt(&(fd.Body.List), fd.Name.NamePos, optilockNumber) - blkmap[blkStmt] = true - } - } else if fl, ok := c.Parent().(*ast.FuncLit); ok && c.Name() == "Body" { - // containLocks := pathContains(*replacePathMutex, n.Pos()) || pathContains(*replacePathRWMutex, n.Pos()) || pathContains(*insertPathRWMutex, n.Pos()) || pathContains(*insertPathMutex, n.Pos()) - containLocks := reversePathContains(*lambdaPath, c.Node().Pos()) - if containLocks { - addContextInitStmt(&(fl.Body.List), fl.Body.Lbrace, optilockNumber) - blkmap[blkStmt] = true - } + if _, ok := blkmap[blkStmt]; ok { + return true + + } + if fd, ok := c.Parent().(*ast.FuncDecl); ok && c.Name() == "Body" { + containLocks := pathContains(*replacePathMutex, n.Pos()) || pathContains(*replacePathRWMutex, n.Pos()) || pathContains(*insertPathRWMutex, n.Pos()) || pathContains(*insertPathMutex, n.Pos()) + if containLocks { + addContextInitStmt(&(fd.Body.List), fd.Name.NamePos) + blkmap[blkStmt] = true } } } + return true } return true } @@ -1137,26 +1045,27 @@ func rewriteAST(f ast.Node, pkg *packages.Package, replacePathRWMutex, insertPat } func writeAST(f ast.Node, sourceFilePath string, pkg *packages.Package, filename string) { - if writeOutput { - fmt.Println(" Writing output to ", filename) - info, err := os.Stat(filename) - if err != nil { - panic(err) - } - fSize := info.Size() - os.Remove(filename) - output, err := os.Create(filename) - if err != nil { - panic(err) - } - defer output.Close() + if !writeOutput { + return + } + fmt.Println(" Writing output to ", filename) + info, err := os.Stat(filename) + if err != nil { + panic(err) + } + fSize := info.Size() + os.Remove(filename) + output, err := os.Create(filename) + if err != nil { + panic(err) + } + defer output.Close() - w := bufio.NewWriterSize(output, int(2*fSize)) - if err := format.Node(w, pkg.Fset, f); err != nil { - panic(err) - } - w.Flush() + w := bufio.NewWriterSize(output, int(2*fSize)) + if err := format.Node(w, pkg.Fset, f); err != nil { + panic(err) } + w.Flush() } func argContains(args []string, target string) bool { @@ -1168,25 +1077,41 @@ func argContains(args []string, target string) bool { return false } +func isLockUnlockRLockRUnlock(name string) bool { + return isLockUnlock(name) || isRLockRUnlock(name) +} + +func isLockUnlock(name string) bool { + return name == "Lock" || name == "Unlock" +} + +func isRLockRUnlock(name string) bool { + return name == "RLock" || name == "RUnlock" +} + func getLockName(f ast.Node, pkg *packages.Package, path []ast.Node) string { var str string postFunc := func(c *astutil.Cursor) bool { node := c.Node() switch n := node.(type) { case *ast.CallExpr: - if singlePathContains(path, n.Pos()) { - if se, ok := n.Fun.(*ast.SelectorExpr); ok { - if se.Sel.Name == "Lock" || se.Sel.Name == "Unlock" || se.Sel.Name == "RLock" || se.Sel.Name == "RUnlock" { - var selectorExpr bytes.Buffer - err := printer.Fprint(&selectorExpr, pkg.Fset, se.X) - if err != nil { - log.Fatalf("failed printing %s", err) - } - str = selectorExpr.String() - return true - } - } + if !singlePathContains(path, n.Pos()) { + return true + } + se, ok := n.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + if !isLockUnlockRLockRUnlock(se.Sel.Name) { + return true + } + var selectorExpr bytes.Buffer + err := printer.Fprint(&selectorExpr, pkg.Fset, se.X) + if err != nil { + log.Fatalf("failed printing %s", err) } + str = selectorExpr.String() + return true } return true } @@ -1209,21 +1134,23 @@ func checkSameLock(l lockInfo, pkgs []*packages.Package) bool { // all lock() name should be the same for _, l := range l.lockPosition { lockPath, ok := astutil.PathEnclosingInterval(file, l, l) - if ok { - lName := getLockName(file, pkg, lockPath) - if lockName != lName { - return false - } + if !ok { + continue + } + lName := getLockName(file, pkg, lockPath) + if lockName != lName { + return false } } // all unlock() name should be the same for _, ul := range l.unlockPosition { unlockPath, ok := astutil.PathEnclosingInterval(file, ul, ul) - if ok { - ulName := getLockName(file, pkg, unlockPath) - if ulName != lockName { - return false - } + if !ok { + continue + } + ulName := getLockName(file, pkg, unlockPath) + if ulName != lockName { + return false } } } @@ -1248,13 +1175,7 @@ func main() { // this map stores the hot function from the profiling hotFuncMap = make(map[string]bool) - lockInLambdaFunc = make(map[token.Pos]bool) - blkstmtMap = make(map[token.Pos]bool) - mPkg = make(map[string]int) - pkgName = make(map[string]bool) - tokenToName = make(map[token.Pos]string) - pathToEndNodePos = make(map[ast.Node]token.Pos) // lock positions to rewrite replacedRWMutexPtr := make(map[token.Pos]bool) @@ -1274,8 +1195,6 @@ func main() { syntheticPtr := flag.Bool("synthetic", false, "set true if the synthetic main from transformer is used") - rewriteTestFile := flag.Bool("rewriteTest", false, "set true if you want to change testing file") - flag.Parse() if inputFile == "" { @@ -1330,14 +1249,10 @@ func main() { panic("something wrong during loading!") } - for _, pkg := range pkgs { - pkgName[pkg.Name] = true - } - prog, ssapkgs := ssautil.AllPackages(pkgs, ssa.NaiveForm|ssa.GlobalDebug) // prog, ssapkgs := ssautil.AllPackages(pkgs, ssa.GlobalDebug) libbuilder.BuildPackages(prog, ssapkgs, true, true) - mCallGraph := libcg.BuildRtaCG(prog, true) + mCallGraph := libcg.BuildRtaCG(prog, false) // TODO: first pass on optimized form and second pass on naive form to check if it is a value or object @@ -1365,8 +1280,7 @@ func main() { curNode = node postDom.GetExit(node) lockInfoMap := lockAnalysis(node, lockType, lockFuncName, unlockFuncName) - isLambda := strings.Contains(node.RelString(node.Pkg.Pkg), "$") - for lkName, value := range lockInfoMap { + for _, value := range lockInfoMap { // if this is not a same lock, don't replace. // TODO: fix this if checkSameLock(value, pkgs) == false { @@ -1379,33 +1293,17 @@ func main() { // a RWMutex pointer not a value for _, item := range value.lockPosition { replacedRWMutexPtr[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } for _, item := range value.unlockPosition { replacedRWMutexPtr[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } } else { // a RWMutex value for _, item := range value.lockPosition { replacedRWMutexVal[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } for _, item := range value.unlockPosition { replacedRWMutexVal[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } } } else { @@ -1413,33 +1311,17 @@ func main() { // a Mutex pointer not a value for _, item := range value.lockPosition { replacedMutexPtr[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } for _, item := range value.unlockPosition { replacedMutexPtr[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } } else { // a Mutex value for _, item := range value.lockPosition { replacedMutexVal[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } for _, item := range value.unlockPosition { replacedMutexVal[item] = true - tokenToName[item] = lkName - if isLambda { - lockInLambdaFunc[item] = true - } } } } @@ -1451,101 +1333,36 @@ func main() { usesMap = pkg.TypesInfo.Uses typesMap = pkg.TypesInfo.Types for _, file := range pkg.Syntax { - nameToID := make(map[string]string) - posToID := make(map[token.Pos]string) - id := 0 - // don't rewrite testing file by default - if *rewriteTestFile == false { - fName := pkg.Fset.Position(file.Pos()).Filename - if strings.HasSuffix(fName, "_test.go") { - // fmt.Printf("%v is skipped since it is a testing file\n", fName) - continue - } - } - // replacePath means the lock is a pointer so we can replace it directly // insertPath indicates the lock is a value and we need to take the address of it. replacePathRWMutex := make([][]ast.Node, 0) insertPathRWMutex := make([][]ast.Node, 0) replacePathMutex := make([][]ast.Node, 0) insertPathMutex := make([][]ast.Node, 0) - lambdaPath := make([][]ast.Node, 0) - normalPath := make([][]ast.Node, 0) + for l := range replacedRWMutexPtr { path, ok := astutil.PathEnclosingInterval(file, l, l) if ok { - pathToEndNodePos[path[0]] = l replacePathRWMutex = append(replacePathRWMutex, path) - if _, ok := lockInLambdaFunc[l]; ok { - lambdaPath = append(lambdaPath, path) - } else { - normalPath = append(normalPath, path) - } - if val, ok := nameToID[tokenToName[l]]; ok { - posToID[l] = val - } else { - nameToID[tokenToName[l]] = strconv.Itoa(id) - posToID[l] = strconv.Itoa(id) - id++ - } } } for l := range replacedRWMutexVal { path, ok := astutil.PathEnclosingInterval(file, l, l) if ok { - pathToEndNodePos[path[0]] = l insertPathRWMutex = append(insertPathRWMutex, path) - if _, ok := lockInLambdaFunc[l]; ok { - lambdaPath = append(lambdaPath, path) - } else { - normalPath = append(normalPath, path) - } - if val, ok := nameToID[tokenToName[l]]; ok { - posToID[l] = val - } else { - nameToID[tokenToName[l]] = strconv.Itoa(id) - posToID[l] = strconv.Itoa(id) - id++ - } } } for l := range replacedMutexPtr { path, ok := astutil.PathEnclosingInterval(file, l, l) if ok { - pathToEndNodePos[path[0]] = l replacePathMutex = append(replacePathMutex, path) - if _, ok := lockInLambdaFunc[l]; ok { - lambdaPath = append(lambdaPath, path) - } else { - normalPath = append(normalPath, path) - } - if val, ok := nameToID[tokenToName[l]]; ok { - posToID[l] = val - } else { - nameToID[tokenToName[l]] = strconv.Itoa(id) - posToID[l] = strconv.Itoa(id) - id++ - } } } for l := range replacedMutexVal { path, ok := astutil.PathEnclosingInterval(file, l, l) if ok { - pathToEndNodePos[path[0]] = l insertPathMutex = append(insertPathMutex, path) - if _, ok := lockInLambdaFunc[l]; ok { - lambdaPath = append(lambdaPath, path) - } else { - normalPath = append(normalPath, path) - } - if val, ok := nameToID[tokenToName[l]]; ok { - posToID[l] = val - } else { - nameToID[tokenToName[l]] = strconv.Itoa(id) - posToID[l] = strconv.Itoa(id) - id++ - } } } @@ -1555,8 +1372,7 @@ func main() { } if numberOfLocksToChange > 0 && writeOutput { fmt.Printf("Number of locks to rewrite %v\n", numberOfLocksToChange) - collectBlkstmt(file, pkg) - ast := rewriteAST(file, pkg, &replacePathRWMutex, &insertPathRWMutex, &replacePathMutex, &insertPathMutex, &lambdaPath, &normalPath, posToID) + ast := rewriteAST(file, pkg, &replacePathRWMutex, &insertPathRWMutex, &replacePathMutex, &insertPathMutex) filename := prog.Fset.Position(ast.Pos()).Filename writeAST(ast, inputFile, pkg, filename) } @@ -1581,9 +1397,6 @@ func main() { fmt.Fprintln(w, "Defer lock pairs in same BB: ", lockDeferUnlockSameBB) fmt.Fprintln(w, "Lock pairs dominate each other: ", lockUnlockPairDifferentBB) fmt.Fprintln(w, "Defer Lock pairs dominate each other: ", lockDeferUnlockPairDifferentBB) - fmt.Fprintln(w, "Unsafe lock instructions that are dropped: ", unsafeLock) - fmt.Fprintln(w, "Unpaired locks: ", unpaired) - fmt.Fprintln(w, "Paired locks: ", paired) w.Flush()