diff --git a/cmd/flow/main.go b/cmd/flow/main.go index 446876c9b..cdf71871a 100644 --- a/cmd/flow/main.go +++ b/cmd/flow/main.go @@ -35,6 +35,7 @@ import ( "github.com/onflow/flow-cli/internal/events" evm "github.com/onflow/flow-cli/internal/evm" "github.com/onflow/flow-cli/internal/keys" + "github.com/onflow/flow-cli/internal/mcp" "github.com/onflow/flow-cli/internal/project" "github.com/onflow/flow-cli/internal/quick" "github.com/onflow/flow-cli/internal/schedule" @@ -92,6 +93,7 @@ func main() { cmd.AddCommand(dependencymanager.Cmd) cmd.AddCommand(evm.Cmd) cmd.AddCommand(schedule.Cmd) + cmd.AddCommand(mcp.Cmd) command.InitFlags(cmd) cmd.AddGroup(&cobra.Group{ diff --git a/go.mod b/go.mod index b0a09db48..5a0d69133 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/getsentry/sentry-go v0.43.0 github.com/gosuri/uilive v0.0.4 github.com/logrusorgru/aurora/v4 v4.0.0 + github.com/mark3labs/mcp-go v0.45.0 github.com/onflow/cadence v1.10.0 github.com/onflow/cadence-tools/languageserver v1.10.0 github.com/onflow/cadence-tools/lint v1.9.0 @@ -60,11 +61,13 @@ require ( github.com/VictoriaMetrics/fastcache v1.13.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.24.4 // indirect github.com/btcsuite/btcd/btcec/v2 v2.3.4 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.0.3 // indirect + github.com/buger/jsonparser v1.1.2 // indirect github.com/c-bata/go-prompt v0.2.6 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect @@ -149,10 +152,9 @@ require ( github.com/huandu/go-clone v1.6.0 // indirect github.com/huandu/go-clone/generic v1.7.2 // indirect github.com/huin/goupnp v1.3.0 // indirect - github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 // indirect github.com/improbable-eng/grpc-web v0.15.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/invopop/jsonschema v0.7.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/ipfs/bbloom v0.0.4 // indirect github.com/ipfs/boxo v0.17.1-0.20240131173518-89bceff34bf1 // indirect github.com/ipfs/go-block-format v0.2.0 // indirect @@ -179,6 +181,7 @@ require ( github.com/lmars/go-slip10 v0.0.0-20190606092855-400ba44fee12 // indirect github.com/logrusorgru/aurora v2.0.3+incompatible // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -263,9 +266,11 @@ require ( github.com/tyler-smith/go-bip39 v1.1.0 // indirect github.com/vmihailenco/msgpack/v4 v4.3.11 // indirect github.com/vmihailenco/tagparser v0.1.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/wlynxg/anet v0.0.5 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/zeebo/blake3 v0.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect diff --git a/go.sum b/go.sum index 4bb233f75..0ee00e1cd 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= @@ -106,6 +108,8 @@ github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurT github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.3 h1:SDlJ7bAm4ewvrmZtR0DaiYbQGdKPeaaIm7bM+qRhFeU= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.3/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= +github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= +github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k= github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw= github.com/bytedance/sonic/loader v0.1.0 h1:skjHJ2Bi9ibbq3Dwzh1w42MQ7wZJrXmEZr/uqUn3f0Q= @@ -529,8 +533,6 @@ github.com/huandu/go-clone/generic v1.7.2/go.mod h1:xgd9ZebcMsBWWcBx5mVMCoqMX24g github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= -github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 h1:i462o439ZjprVSFSZLZxcsoAe592sZB1rci2Z8j4wdk= -github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0/go.mod h1:N0Wam8K1arqPXNWjMo21EXnBPOPp36vB07FNRdD2geA= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/improbable-eng/grpc-web v0.15.0 h1:BN+7z6uNXZ1tQGcNAuaU1YjsLTApzkjt2tzCixLaUPQ= github.com/improbable-eng/grpc-web v0.15.0/go.mod h1:1sy9HKV4Jt9aEs9JSnkWlRJPuPtwNr0l57L4f878wP8= @@ -538,8 +540,8 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= -github.com/invopop/jsonschema v0.7.0 h1:2vgQcBz1n256N+FpX3Jq7Y17AjYt46Ig3zIWyy770So= -github.com/invopop/jsonschema v0.7.0/go.mod h1:O9uiLokuu0+MGFlyiaqtWxwqJm41/+8Nj0lD7A36YH0= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/ipfs/bbloom v0.0.4 h1:Gi+8EGJ2y5qiD5FbsbpX/TMNcJw8gSqr7eyjHa4Fhvs= github.com/ipfs/bbloom v0.0.4/go.mod h1:cS9YprKXpoZ9lT0n/Mw/a6/aFV6DTjTLYHeA+gyqMG0= github.com/ipfs/boxo v0.17.1-0.20240131173518-89bceff34bf1 h1:5H/HYvdmbxp09+sAvdqJzyrWoyCS6OroeW9Ym06Tb+0= @@ -574,6 +576,7 @@ github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jordanschalm/lockctx v0.1.0 h1:2ZziSl5zejl5VSRUjl+UtYV94QPFQgO9bekqWPOKUQw= github.com/jordanschalm/lockctx v0.1.0/go.mod h1:qsnXMryYP9X7JbzskIn0+N40sE6XNXLr9kYRRP6rwXU= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -659,6 +662,10 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc= +github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -1075,7 +1082,6 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v0.0.0-20170601210322-f6abca593680/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -1122,6 +1128,8 @@ github.com/vmihailenco/msgpack/v4 v4.3.11 h1:Q47CePddpNGNhk4GCnAx9DDtASi2rasatE0 github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= github.com/vmihailenco/tagparser v0.1.1 h1:quXMXlA39OCbd2wAdTsGDlK9RkOk6Wuw+x37wVyIuWY= github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= @@ -1131,6 +1139,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/internal/cadence/linter.go b/internal/cadence/linter.go index 8eda846a1..bb250d34c 100644 --- a/internal/cadence/linter.go +++ b/internal/cadence/linter.go @@ -78,6 +78,18 @@ func newLinter(state *flowkit.State) *linter { func (l *linter) lintFile( filePath string, +) (diagnostics []analysis.Diagnostic, err error) { + code, readErr := l.state.ReadFile(filePath) + if readErr != nil { + return nil, readErr + } + + return l.lintCode(code, common.StringLocation(filePath)) +} + +func (l *linter) lintCode( + code []byte, + location common.Location, ) (diagnostics []analysis.Diagnostic, err error) { // Recover from panics in the Cadence checker defer func() { @@ -88,12 +100,6 @@ func (l *linter) lintFile( }() diagnostics = make([]analysis.Diagnostic, 0) - location := common.StringLocation(filePath) - - code, readErr := l.state.ReadFile(filePath) - if readErr != nil { - return nil, readErr - } codeStr := string(code) // Parse program & convert any parsing errors to diagnostics @@ -181,6 +187,13 @@ func (l *linter) lintFile( return diagnostics, nil } +// LintCode runs all registered Cadence lint analyzers on inline code. +// This is the public entry point used by the MCP server. +func LintCode(code string, state *flowkit.State) ([]analysis.Diagnostic, error) { + l := newLinter(state) + return l.lintCode([]byte(code), common.StringLocation("code.cdc")) +} + // isContractName returns true if the location string is a contract name (not a file path) func isContractName(locationString string) bool { return !strings.HasSuffix(locationString, ".cdc") diff --git a/internal/mcp/integration_test.go b/internal/mcp/integration_test.go new file mode 100644 index 000000000..8b13dddc4 --- /dev/null +++ b/internal/mcp/integration_test.go @@ -0,0 +1,91 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "context" + "os" + "testing" + + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func skipIfNoNetwork(t *testing.T) { + t.Helper() + if os.Getenv("SKIP_NETWORK_TESTS") != "" { + t.Skip("Skipping network test (SKIP_NETWORK_TESTS is set)") + } +} + +func TestIntegration_GetContractSource(t *testing.T) { + t.Parallel() + skipIfNoNetwork(t) + + mctx := &mcpContext{state: nil} + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "address": "0x1654653399040a61", + "network": "mainnet", + } + + result, err := mctx.getContractSource(context.Background(), req) + require.NoError(t, err) + assert.False(t, result.IsError) + text := result.Content[0].(mcplib.TextContent).Text + assert.Contains(t, text, "FlowToken") +} + +func TestIntegration_GetContractCode(t *testing.T) { + t.Parallel() + skipIfNoNetwork(t) + + mctx := &mcpContext{state: nil} + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "address": "0x1654653399040a61", + "contract_name": "FlowToken", + "network": "mainnet", + } + + result, err := mctx.getContractCode(context.Background(), req) + require.NoError(t, err) + assert.False(t, result.IsError) + text := result.Content[0].(mcplib.TextContent).Text + assert.Contains(t, text, "FlowToken") +} + +func TestIntegration_ExecuteScript(t *testing.T) { + t.Parallel() + skipIfNoNetwork(t) + + mctx := &mcpContext{state: nil} + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "code": `access(all) fun main(): Int { return 42 }`, + "network": "mainnet", + } + + result, err := mctx.cadenceExecuteScript(context.Background(), req) + require.NoError(t, err) + assert.False(t, result.IsError) + text := result.Content[0].(mcplib.TextContent).Text + assert.Contains(t, text, "42") +} diff --git a/internal/mcp/lsp.go b/internal/mcp/lsp.go new file mode 100644 index 000000000..38c51478b --- /dev/null +++ b/internal/mcp/lsp.go @@ -0,0 +1,315 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/onflow/cadence-tools/languageserver/integration" + "github.com/onflow/cadence-tools/languageserver/protocol" + "github.com/onflow/cadence-tools/languageserver/server" +) + +const scratchURI = protocol.DocumentURI("file:///mcp/scratch.cdc") + +// diagConn implements protocol.Conn and captures diagnostics published by the LSP server. +type diagConn struct { + mu sync.Mutex + diagnostics []protocol.Diagnostic +} + +func (c *diagConn) Notify(method string, params any) error { + if method == "textDocument/publishDiagnostics" { + switch p := params.(type) { + case *protocol.PublishDiagnosticsParams: + c.captureDiagnostics(p.URI, p.Diagnostics) + default: + // Try JSON round-trip for map types + data, err := json.Marshal(p) + if err == nil { + var pdp protocol.PublishDiagnosticsParams + if json.Unmarshal(data, &pdp) == nil { + c.captureDiagnostics(pdp.URI, pdp.Diagnostics) + } + } + } + } + return nil +} + +func (c *diagConn) ShowMessage(_ *protocol.ShowMessageParams) {} + +func (c *diagConn) ShowMessageRequest(_ *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) { + return nil, nil +} + +func (c *diagConn) LogMessage(_ *protocol.LogMessageParams) {} + +func (c *diagConn) PublishDiagnostics(params *protocol.PublishDiagnosticsParams) error { + if params != nil { + c.captureDiagnostics(params.URI, params.Diagnostics) + } + return nil +} + +func (c *diagConn) RegisterCapability(_ *protocol.RegistrationParams) error { + return nil +} + +func (c *diagConn) captureDiagnostics(uri protocol.DocumentURI, diags []protocol.Diagnostic) { + c.mu.Lock() + defer c.mu.Unlock() + if uri != "" && uri != scratchURI { + return // ignore diagnostics for unrelated documents + } + c.diagnostics = diags // replace, not append — one publish per check cycle +} + +func (c *diagConn) reset() { + c.mu.Lock() + defer c.mu.Unlock() + c.diagnostics = nil +} + +func (c *diagConn) getDiagnostics() []protocol.Diagnostic { + c.mu.Lock() + defer c.mu.Unlock() + result := make([]protocol.Diagnostic, len(c.diagnostics)) + copy(result, c.diagnostics) + return result +} + +// LSPWrapper manages an in-process cadence-tools LSP server, +// handling document lifecycle and diagnostic capture. +type LSPWrapper struct { + server *server.Server + conn *diagConn + mu sync.Mutex + docVersion int32 + docOpen bool +} + +// NewLSPWrapper creates a new LSP wrapper with an in-process Cadence language server. +func NewLSPWrapper(enableFlowClient bool) (*LSPWrapper, error) { + s, err := server.NewServer() + if err != nil { + return nil, fmt.Errorf("creating LSP server: %w", err) + } + + _, err = integration.NewFlowIntegration(s, enableFlowClient) + if err != nil { + return nil, fmt.Errorf("creating flow integration: %w", err) + } + + conn := &diagConn{} + + _, err = s.Initialize(conn, &protocol.InitializeParams{ + XInitializeParams: protocol.XInitializeParams{ + InitializationOptions: map[string]any{ + "accessCheckMode": "strict", + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("initializing LSP server: %w", err) + } + + return &LSPWrapper{ + server: s, + conn: conn, + }, nil +} + +// updateDocument sends the code to the LSP server as a virtual document. +// Must be called with w.mu held. +func (w *LSPWrapper) updateDocument(code string) error { + w.docVersion++ + version := w.docVersion + + if !w.docOpen { + w.docOpen = true + return w.server.DidOpenTextDocument(w.conn, &protocol.DidOpenTextDocumentParams{ + TextDocument: protocol.TextDocumentItem{ + URI: scratchURI, + LanguageID: "cadence", + Version: version, + Text: code, + }, + }) + } + + return w.server.DidChangeTextDocument(w.conn, &protocol.DidChangeTextDocumentParams{ + TextDocument: protocol.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: protocol.TextDocumentIdentifier{ + URI: scratchURI, + }, + Version: version, + }, + ContentChanges: []protocol.TextDocumentContentChangeEvent{ + {Text: code}, + }, + }) +} + +// Check sends code to the LSP and returns any diagnostics. +func (w *LSPWrapper) Check(code string) ([]protocol.Diagnostic, error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.conn.reset() + + if err := w.updateDocument(code); err != nil { + return nil, fmt.Errorf("updating document: %w", err) + } + + return w.conn.getDiagnostics(), nil +} + +// Hover returns hover information at the given position. +func (w *LSPWrapper) Hover(code string, line, character int) (*protocol.Hover, error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.conn.reset() + + if err := w.updateDocument(code); err != nil { + return nil, fmt.Errorf("updating document: %w", err) + } + + return w.server.Hover(w.conn, &protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: scratchURI}, + Position: protocol.Position{Line: uint32(line), Character: uint32(character)}, + }) +} + +// Definition returns the definition location for the symbol at the given position. +func (w *LSPWrapper) Definition(code string, line, character int) (*protocol.Location, error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.conn.reset() + + if err := w.updateDocument(code); err != nil { + return nil, fmt.Errorf("updating document: %w", err) + } + + return w.server.Definition(w.conn, &protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: scratchURI}, + Position: protocol.Position{Line: uint32(line), Character: uint32(character)}, + }) +} + +// Symbols returns the document symbols for the given code. +func (w *LSPWrapper) Symbols(code string) ([]*protocol.DocumentSymbol, error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.conn.reset() + + if err := w.updateDocument(code); err != nil { + return nil, fmt.Errorf("updating document: %w", err) + } + + return w.server.DocumentSymbol(w.conn, &protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: scratchURI}, + }) +} + +// Completion returns completion items at the given position. +func (w *LSPWrapper) Completion(code string, line, character int) ([]*protocol.CompletionItem, error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.conn.reset() + + if err := w.updateDocument(code); err != nil { + return nil, fmt.Errorf("updating document: %w", err) + } + + return w.server.Completion(w.conn, &protocol.CompletionParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: scratchURI}, + Position: protocol.Position{Line: uint32(line), Character: uint32(character)}, + }, + }) +} + +// formatDiagnostics formats diagnostics as human-readable text. +func formatDiagnostics(diagnostics []protocol.Diagnostic) string { + if len(diagnostics) == 0 { + return "No errors found." + } + + var b strings.Builder + for _, d := range diagnostics { + fmt.Fprintf(&b, "[%s] line %d:%d: %s\n", + d.Severity.String(), + d.Range.Start.Line+1, + d.Range.Start.Character+1, + d.Message, + ) + } + return b.String() +} + +// formatHover formats a hover result as human-readable text. +func formatHover(result *protocol.Hover) string { + if result == nil { + return "No hover information available." + } + return result.Contents.Value +} + +// formatSymbols formats document symbols as an indented tree. +// Accepts []*protocol.DocumentSymbol (from the server API). +func formatSymbols(symbols []*protocol.DocumentSymbol, indent int) string { + var b strings.Builder + prefix := strings.Repeat(" ", indent) + for _, s := range symbols { + fmt.Fprintf(&b, "%s%s %s", prefix, s.Kind.String(), s.Name) + if s.Detail != "" { + fmt.Fprintf(&b, " — %s", s.Detail) + } + b.WriteString("\n") + if len(s.Children) > 0 { + b.WriteString(formatSymbolValues(s.Children, indent+1)) + } + } + return b.String() +} + +// formatSymbolValues formats []protocol.DocumentSymbol (value type, used for Children). +func formatSymbolValues(symbols []protocol.DocumentSymbol, indent int) string { + var b strings.Builder + prefix := strings.Repeat(" ", indent) + for _, s := range symbols { + fmt.Fprintf(&b, "%s%s %s", prefix, s.Kind.String(), s.Name) + if s.Detail != "" { + fmt.Fprintf(&b, " — %s", s.Detail) + } + b.WriteString("\n") + if len(s.Children) > 0 { + b.WriteString(formatSymbolValues(s.Children, indent+1)) + } + } + return b.String() +} diff --git a/internal/mcp/lsp_test.go b/internal/mcp/lsp_test.go new file mode 100644 index 000000000..e82c61302 --- /dev/null +++ b/internal/mcp/lsp_test.go @@ -0,0 +1,131 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestWrapper(t *testing.T) *LSPWrapper { + t.Helper() + w, err := NewLSPWrapper(false) + require.NoError(t, err) + require.NotNil(t, w) + return w +} + +func TestLSPWrapper_Check_ValidCode(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + code := ` + access(all) fun hello(): String { + return "hello" + } + ` + diags, err := w.Check(code) + require.NoError(t, err) + assert.Empty(t, diags, "valid code should produce no diagnostics") +} + +func TestLSPWrapper_Check_InvalidCode(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + // Type mismatch: returning Int from a String function + code := ` + access(all) fun hello(): String { + return 42 + } + ` + diags, err := w.Check(code) + require.NoError(t, err) + assert.NotEmpty(t, diags, "type mismatch should produce diagnostics") +} + +func TestLSPWrapper_Check_SyntaxError(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + code := ` + access(all) fun hello( { + ` + diags, err := w.Check(code) + require.NoError(t, err) + assert.NotEmpty(t, diags, "syntax error should produce diagnostics") +} + +func TestLSPWrapper_Hover(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + code := ` +access(all) fun hello(): String { + return "hello" +} +` + // Hover over "String" return type — line 1 (0-based), find the position of "String" + result, err := w.Hover(code, 1, 25) + require.NoError(t, err) + // Hover may or may not return a result depending on the position; + // we just verify it doesn't error. If non-nil, it should have contents. + if result != nil { + assert.NotEmpty(t, result.Contents.Value) + } +} + +func TestLSPWrapper_Symbols(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + code := ` +access(all) contract MyContract { + access(all) fun greet(): String { + return "hi" + } +} +` + symbols, err := w.Symbols(code) + require.NoError(t, err) + require.NotEmpty(t, symbols, "contract with members should have symbols") + + // The top-level symbol should be the contract + assert.Equal(t, "MyContract", symbols[0].Name) +} + +func TestLSPWrapper_Completion(t *testing.T) { + t.Parallel() + w := newTestWrapper(t) + + // Inside a function body, the LSP should offer completions + code := ` +access(all) fun main() { + let x: String = "hello" + x. +} +` + // Position right after "x." — line 3, character 3 + items, err := w.Completion(code, 3, 3) + require.NoError(t, err) + // String methods should appear as completions + assert.NotEmpty(t, items, "should get completion items for String methods") +} diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go new file mode 100644 index 000000000..dc9e5915c --- /dev/null +++ b/internal/mcp/mcp.go @@ -0,0 +1,132 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "errors" + "fmt" + "os" + + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/spf13/afero" + "github.com/spf13/cobra" + + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/config" + "github.com/onflow/flowkit/v2/gateway" +) + +var Cmd = &cobra.Command{ + Use: "mcp", + Short: "Start the Cadence MCP server", + Long: `Start a Model Context Protocol (MCP) server for Cadence smart contract development. + +The server provides tools for checking Cadence code, inspecting types, +querying on-chain contracts, executing scripts, and reviewing code for +common issues. + +Claude Code: + claude mcp add cadence-mcp -- flow mcp + +Cursor / Claude Desktop (add to settings JSON): + { + "mcpServers": { + "cadence-mcp": { + "command": "flow", + "args": ["mcp"] + } + } + } + +Available tools: + cadence_check Check Cadence code for syntax and type errors + cadence_hover Get type info for a symbol at a position + cadence_definition Find where a symbol is defined + cadence_symbols List all symbols in Cadence code + cadence_completion Get completions at a position + get_contract_source Fetch on-chain contract manifest + get_contract_code Fetch contract source code from an address + cadence_code_review Review Cadence code for common issues + cadence_execute_script Execute a read-only Cadence script on-chain`, + Run: runMCP, +} + +func runMCP(cmd *cobra.Command, args []string) { + // Try to load flow.json for custom network configs + loader := &afero.Afero{Fs: afero.NewOsFs()} + state, err := flowkit.Load(config.DefaultPaths(), loader) + if err != nil && !errors.Is(err, config.ErrDoesNotExist) { + fmt.Fprintf(os.Stderr, "Warning: failed to load flow.json: %v\n", err) + } + + // Initialize the LSP wrapper (without flow client for MCP use). + var lsp *LSPWrapper + if w, err := NewLSPWrapper(false); err == nil { + lsp = w + } else { + fmt.Fprintf(os.Stderr, "Warning: LSP initialization failed, LSP tools will be unavailable: %v\n", err) + } + + mctx := &mcpContext{ + lsp: lsp, + state: state, + } + + s := mcpserver.NewMCPServer("cadence-mcp", "1.0.0") + registerTools(s, mctx) + + if err := mcpserver.ServeStdio(s); err != nil { + fmt.Fprintf(os.Stderr, "MCP server error: %v\n", err) + os.Exit(1) + } +} + +// resolveNetwork returns a config.Network for the given network name. +// Uses flow.json config if available, otherwise falls back to defaults. +func resolveNetwork(state *flowkit.State, network string) (*config.Network, error) { + if network == "" { + network = "mainnet" + } + + if state != nil { + net, err := state.Networks().ByName(network) + if err == nil { + return net, nil + } + } + + net, err := config.DefaultNetworks.ByName(network) + if err != nil { + return nil, fmt.Errorf("unknown network %q", network) + } + return net, nil +} + +// createGateway creates a gRPC gateway for the given network. +// Uses a secure gateway when the network has a configured key. +func createGateway(state *flowkit.State, network string) (gateway.Gateway, error) { + net, err := resolveNetwork(state, network) + if err != nil { + return nil, err + } + if net.Key != "" { + return gateway.NewSecureGrpcGateway(*net) + } + return gateway.NewGrpcGateway(*net) +} diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go new file mode 100644 index 000000000..028e07a75 --- /dev/null +++ b/internal/mcp/tools.go @@ -0,0 +1,387 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + + mcplib "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/onflow/cadence" + "github.com/onflow/flow-go-sdk" + + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/arguments" +) + +// mcpContext holds shared dependencies for all MCP tool handlers. +type mcpContext struct { + lsp *LSPWrapper + state *flowkit.State // may be nil +} + +// resolveCode extracts the required "code" parameter from the request. +func resolveCode(req mcplib.CallToolRequest) (string, error) { + return req.RequireString("code") +} + +// parseAddress parses a Flow address string and validates it is not empty. +func parseAddress(address string) (flow.Address, error) { + addr := flow.HexToAddress(address) + if addr == flow.EmptyAddress { + return flow.EmptyAddress, fmt.Errorf("invalid Flow address: %q", address) + } + return addr, nil +} + +// registerTools registers all MCP tools on the given server. +func registerTools(s *mcpserver.MCPServer, mctx *mcpContext) { + // LSP tools — only register if the LSP wrapper is available. + if mctx.lsp != nil { + s.AddTool( + mcplib.NewTool("cadence_check", + mcplib.WithDescription("Check Cadence code for syntax and type errors."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence source code to check")), + ), + mctx.cadenceCheck, + ) + + s.AddTool( + mcplib.NewTool("cadence_hover", + mcplib.WithDescription("Get type information for a symbol at a position in Cadence code."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence source code")), + + mcplib.WithNumber("line", mcplib.Required(), mcplib.Description("0-based line number")), + mcplib.WithNumber("character", mcplib.Required(), mcplib.Description("0-based column number")), + ), + mctx.cadenceHover, + ) + + s.AddTool( + mcplib.NewTool("cadence_definition", + mcplib.WithDescription("Find where a symbol is defined in Cadence code."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence source code")), + + mcplib.WithNumber("line", mcplib.Required(), mcplib.Description("0-based line number")), + mcplib.WithNumber("character", mcplib.Required(), mcplib.Description("0-based column number")), + ), + mctx.cadenceDefinition, + ) + + s.AddTool( + mcplib.NewTool("cadence_symbols", + mcplib.WithDescription("List all symbols in Cadence code."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence source code")), + ), + mctx.cadenceSymbols, + ) + + s.AddTool( + mcplib.NewTool("cadence_completion", + mcplib.WithDescription("Get completion suggestions at a position in Cadence code."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence source code")), + + mcplib.WithNumber("line", mcplib.Required(), mcplib.Description("0-based line number")), + mcplib.WithNumber("character", mcplib.Required(), mcplib.Description("0-based column number")), + ), + mctx.cadenceCompletion, + ) + } + + // Audit / network tools — always registered. + s.AddTool( + mcplib.NewTool("get_contract_source", + mcplib.WithDescription("Fetch on-chain contract manifest (names and sizes) for a Flow account"), + mcplib.WithString("address", mcplib.Required(), mcplib.Description("Flow account address (hex, with or without 0x prefix)")), + mcplib.WithString("network", mcplib.Description("Flow network to query"), mcplib.Enum("mainnet", "testnet", "emulator")), + ), + mctx.getContractSource, + ) + + s.AddTool( + mcplib.NewTool("get_contract_code", + mcplib.WithDescription("Fetch contract source code from a Flow account"), + mcplib.WithString("address", mcplib.Required(), mcplib.Description("Flow account address (hex, with or without 0x prefix)")), + mcplib.WithString("contract_name", mcplib.Description("Specific contract name to retrieve; omit for all contracts")), + mcplib.WithString("network", mcplib.Description("Flow network to query"), mcplib.Enum("mainnet", "testnet", "emulator")), + ), + mctx.getContractCode, + ) + + s.AddTool( + mcplib.NewTool("cadence_execute_script", + mcplib.WithDescription("Execute a read-only Cadence script on-chain."), + mcplib.WithString("code", mcplib.Required(), mcplib.Description("Cadence script source code")), + + mcplib.WithString("network", mcplib.Description("Flow network to execute against"), mcplib.Enum("mainnet", "testnet", "emulator")), + mcplib.WithString("arguments", mcplib.Description("JSON array of arguments as strings, e.g. [\"String:hello\", \"UFix64:1.0\"]")), + ), + mctx.cadenceExecuteScript, + ) +} + +// --------------------------------------------------------------------------- +// LSP tool handlers +// --------------------------------------------------------------------------- + +func (m *mcpContext) cadenceCheck(_ context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + + diags, err := m.lsp.Check(code) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("LSP check failed: %v", err)), nil + } + return mcplib.NewToolResultText(formatDiagnostics(diags)), nil +} + +func (m *mcpContext) cadenceHover(_ context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + line, err := req.RequireInt("line") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + character, err := req.RequireInt("character") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + + result, err := m.lsp.Hover(code, line, character) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("LSP hover failed: %v", err)), nil + } + return mcplib.NewToolResultText(formatHover(result)), nil +} + +func (m *mcpContext) cadenceDefinition(_ context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + line, err := req.RequireInt("line") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + character, err := req.RequireInt("character") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + + loc, err := m.lsp.Definition(code, line, character) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("LSP definition failed: %v", err)), nil + } + if loc == nil { + return mcplib.NewToolResultText("No definition found."), nil + } + return mcplib.NewToolResultText(fmt.Sprintf("%s line %d:%d", + loc.URI, loc.Range.Start.Line+1, loc.Range.Start.Character+1)), nil +} + +func (m *mcpContext) cadenceSymbols(_ context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + + symbols, err := m.lsp.Symbols(code) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("LSP symbols failed: %v", err)), nil + } + return mcplib.NewToolResultText(formatSymbols(symbols, 0)), nil +} + +func (m *mcpContext) cadenceCompletion(_ context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + line, err := req.RequireInt("line") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + character, err := req.RequireInt("character") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + + items, err := m.lsp.Completion(code, line, character) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("LSP completion failed: %v", err)), nil + } + + var b strings.Builder + for _, item := range items { + b.WriteString(item.Label) + if item.Detail != "" { + fmt.Fprintf(&b, " — %s", item.Detail) + } + b.WriteString("\n") + } + if b.Len() == 0 { + return mcplib.NewToolResultText("No completions available."), nil + } + return mcplib.NewToolResultText(b.String()), nil +} + +// --------------------------------------------------------------------------- +// Audit / network tool handlers +// --------------------------------------------------------------------------- + +func (m *mcpContext) getContractSource(ctx context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + address, err := req.RequireString("address") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + network := req.GetString("network", "mainnet") + + gw, err := createGateway(m.state, network) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to create gateway: %v", err)), nil + } + + addr, err := parseAddress(address) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + account, err := gw.GetAccount(ctx, addr) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to get account: %v", err)), nil + } + + type contractInfo struct { + Name string `json:"name"` + Size int `json:"size"` + } + + contracts := make([]contractInfo, 0, len(account.Contracts)) + for name, code := range account.Contracts { + contracts = append(contracts, contractInfo{Name: name, Size: len(code)}) + } + sort.Slice(contracts, func(i, j int) bool { + return contracts[i].Name < contracts[j].Name + }) + + result := struct { + Address string `json:"address"` + Contracts []contractInfo `json:"contracts"` + }{ + Address: addr.String(), + Contracts: contracts, + } + + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + return mcplib.NewToolResultText(string(data)), nil +} + +func (m *mcpContext) getContractCode(ctx context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + address, err := req.RequireString("address") + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + contractName := req.GetString("contract_name", "") + network := req.GetString("network", "mainnet") + + gw, err := createGateway(m.state, network) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to create gateway: %v", err)), nil + } + + addr, err := parseAddress(address) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + account, err := gw.GetAccount(ctx, addr) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to get account: %v", err)), nil + } + + if contractName != "" { + code, ok := account.Contracts[contractName] + if !ok { + return mcplib.NewToolResultError(fmt.Sprintf("contract %q not found on account %s", contractName, addr.String())), nil + } + return mcplib.NewToolResultText(string(code)), nil + } + + // Return all contracts. + var b strings.Builder + names := make([]string, 0, len(account.Contracts)) + for name := range account.Contracts { + names = append(names, name) + } + sort.Strings(names) + + for i, name := range names { + if i > 0 { + b.WriteString("\n\n") + } + fmt.Fprintf(&b, "// === %s ===\n%s", name, string(account.Contracts[name])) + } + if b.Len() == 0 { + return mcplib.NewToolResultText("No contracts found on this account."), nil + } + return mcplib.NewToolResultText(b.String()), nil +} + +func (m *mcpContext) cadenceExecuteScript(ctx context.Context, req mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + code, err := resolveCode(req) + if err != nil { + return mcplib.NewToolResultError(err.Error()), nil + } + network := req.GetString("network", "mainnet") + argsJSON := req.GetString("arguments", "") + + gw, err := createGateway(m.state, network) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to create gateway: %v", err)), nil + } + + var cadenceArgs []cadence.Value + if argsJSON != "" { + var argStrings []string + if jsonErr := json.Unmarshal([]byte(argsJSON), &argStrings); jsonErr != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to parse arguments JSON: %v", jsonErr)), nil + } + parsed, parseErr := arguments.ParseWithoutType(argStrings, []byte(code), "") + if parseErr != nil { + return mcplib.NewToolResultError(fmt.Sprintf("failed to parse arguments: %v", parseErr)), nil + } + cadenceArgs = parsed + } + + val, err := gw.ExecuteScript(ctx, []byte(code), cadenceArgs) + if err != nil { + return mcplib.NewToolResultError(fmt.Sprintf("script execution failed: %v", err)), nil + } + + return mcplib.NewToolResultText(val.String()), nil +} diff --git a/internal/mcp/tools_test.go b/internal/mcp/tools_test.go new file mode 100644 index 000000000..f4c5cc05e --- /dev/null +++ b/internal/mcp/tools_test.go @@ -0,0 +1,107 @@ +/* + * Flow CLI + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + "context" + "testing" + + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestContext(t *testing.T) *mcpContext { + t.Helper() + lsp, err := NewLSPWrapper(false) + require.NoError(t, err) + return &mcpContext{lsp: lsp} +} + +func TestTool_CadenceCheck_Valid(t *testing.T) { + t.Parallel() + mctx := newTestContext(t) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "code": `access(all) fun hello(): String { return "hello" }`, + } + + result, err := mctx.cadenceCheck(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + textContent := result.Content[0].(mcplib.TextContent) + assert.Contains(t, textContent.Text, "No errors found") +} + +func TestTool_CadenceCheck_Invalid(t *testing.T) { + t.Parallel() + mctx := newTestContext(t) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "code": `access(all) fun hello(): String { return 42 }`, + } + + result, err := mctx.cadenceCheck(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + textContent := result.Content[0].(mcplib.TextContent) + assert.Contains(t, textContent.Text, "Error") +} + +func TestTool_CadenceCheck_MissingCode(t *testing.T) { + t.Parallel() + mctx := newTestContext(t) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{} + + result, err := mctx.cadenceCheck(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) +} + +func TestTool_CadenceSymbols(t *testing.T) { + t.Parallel() + mctx := newTestContext(t) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "code": ` +access(all) contract MyContract { + access(all) fun greet(): String { + return "hi" + } +} +`, + } + + result, err := mctx.cadenceSymbols(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + textContent := result.Content[0].(mcplib.TextContent) + assert.Contains(t, textContent.Text, "MyContract") +}