Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions pkg/ai/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
messages := toAnthropicMessages(req.Messages)
tools := AnthropicToolDefs()

maxIterations := 10
maxIterations := 100
for i := 0; i < maxIterations; i++ {
stream := a.anthropicClient.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
Model: anthropic.Model(a.model),
Expand Down Expand Up @@ -91,8 +91,9 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
sendEvent(SSEEvent{
Event: "tool_result",
Data: map[string]interface{}{
"tool": toolName,
"result": result,
"tool": toolName,
"result": result,
"is_error": true,
},
})
toolResults = append(toolResults, anthropic.NewToolResultBlock(tc.ID, "Tool error: "+result, true))
Expand All @@ -113,8 +114,9 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
sendEvent(SSEEvent{
Event: "tool_result",
Data: map[string]interface{}{
"tool": toolName,
"result": result,
"tool": toolName,
"result": result,
"is_error": isError,
},
})

Expand Down
12 changes: 7 additions & 5 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu

tools := OpenAIToolDefs()

maxIterations := 10
maxIterations := 100
for i := 0; i < maxIterations; i++ {
stream := a.openaiClient.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
Model: a.model,
Expand Down Expand Up @@ -99,8 +99,9 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
sendEvent(SSEEvent{
Event: "tool_result",
Data: map[string]interface{}{
"tool": toolName,
"result": result,
"tool": toolName,
"result": result,
"is_error": true,
},
})
messages = append(messages, openai.ToolMessage("Tool error: "+result, tc.ID))
Expand All @@ -121,8 +122,9 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
sendEvent(SSEEvent{
Event: "tool_result",
Data: map[string]interface{}{
"tool": toolName,
"result": result,
"tool": toolName,
"result": result,
"is_error": isError,
},
})

Expand Down
164 changes: 146 additions & 18 deletions pkg/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import (
"github.com/zxh326/kite/pkg/rbac"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime/schema"
k8stypes "k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/discovery"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/yaml"
Expand Down Expand Up @@ -226,7 +228,7 @@ type resourceInfo struct {
ClusterScoped bool
}

func resolveResourceInfo(kind string) resourceInfo {
func resolveStaticResourceInfo(kind string) resourceInfo {
switch strings.ToLower(strings.TrimSpace(kind)) {
case "pod", "pods":
return resourceInfo{Kind: "Pod", Resource: "pods", Version: "v1"}
Expand Down Expand Up @@ -264,6 +266,8 @@ func resolveResourceInfo(kind string) resourceInfo {
return resourceInfo{Kind: "NetworkPolicy", Resource: "networkpolicies", Group: "networking.k8s.io", Version: "v1"}
case "storageclass", "storageclasses", "sc":
return resourceInfo{Kind: "StorageClass", Resource: "storageclasses", Group: "storage.k8s.io", Version: "v1", ClusterScoped: true}
case "customresourcedefinition", "customresourcedefinitions", "crd", "crds":
return resourceInfo{Kind: "CustomResourceDefinition", Resource: "customresourcedefinitions", Group: "apiextensions.k8s.io", Version: "v1", ClusterScoped: true}
case "event", "events":
return resourceInfo{Kind: "Event", Resource: "events", Version: "v1"}
default:
Expand All @@ -284,6 +288,131 @@ func resolveResourceInfo(kind string) resourceInfo {
}
}

func resolveResourceInfo(ctx context.Context, cs *cluster.ClientSet, kind string) resourceInfo {
if info, ok := resolveResourceInfoFromDiscovery(ctx, cs, kind, ""); ok {
return info
}
return resolveStaticResourceInfo(kind)
}

func resolveResourceInfoForObject(ctx context.Context, cs *cluster.ClientSet, obj *unstructured.Unstructured) resourceInfo {
if info, ok := resolveResourceInfoFromDiscovery(ctx, cs, obj.GetKind(), obj.GetAPIVersion()); ok {
return info
}
return resolveStaticResourceInfo(obj.GetKind())
}

func resolveResourceInfoFromDiscovery(ctx context.Context, cs *cluster.ClientSet, kind, apiVersion string) (resourceInfo, bool) {
input := strings.ToLower(strings.TrimSpace(kind))
if input == "" || cs == nil || cs.K8sClient == nil || cs.K8sClient.ClientSet == nil {
return resourceInfo{}, false
}
if ctx != nil {
select {
case <-ctx.Done():
return resourceInfo{}, false
default:
}
}
discoveryClient := cs.K8sClient.ClientSet.Discovery()

if gv, ok := parseGroupVersion(apiVersion); ok {
resourceList, err := discoveryClient.ServerResourcesForGroupVersion(gv.String())
if err != nil {
klog.V(2).Infof("AI tool discovery failed for %s: %v", gv.String(), err)
} else if info, found := findResourceInfoInList(input, gv, resourceList.APIResources); found {
return info, true
}
}

resourceLists, err := discoveryClient.ServerPreferredResources()
if err != nil && !discovery.IsGroupDiscoveryFailedError(err) {
klog.V(2).Infof("AI tool preferred discovery failed: %v", err)
return resourceInfo{}, false
}

for _, resourceList := range resourceLists {
if resourceList == nil {
continue
}
gv, err := schema.ParseGroupVersion(resourceList.GroupVersion)
if err != nil {
continue
}
if info, found := findResourceInfoInList(input, gv, resourceList.APIResources); found {
return info, true
}
}

return resourceInfo{}, false
}

func parseGroupVersion(apiVersion string) (schema.GroupVersion, bool) {
apiVersion = strings.TrimSpace(apiVersion)
if apiVersion == "" {
return schema.GroupVersion{}, false
}
gv, err := schema.ParseGroupVersion(apiVersion)
if err != nil {
return schema.GroupVersion{}, false
}
return gv, true
}

func findResourceInfoInList(input string, gv schema.GroupVersion, apiResources []metav1.APIResource) (resourceInfo, bool) {
group := strings.ToLower(gv.Group)
for _, apiResource := range apiResources {
if strings.Contains(apiResource.Name, "/") {
continue
}
if !resourceMatchesInput(input, group, apiResource) {
continue
}
return resourceInfo{
Kind: apiResource.Kind,
Resource: apiResource.Name,
Group: gv.Group,
Version: gv.Version,
ClusterScoped: !apiResource.Namespaced,
}, true
}
return resourceInfo{}, false
}

func resourceMatchesInput(input, group string, apiResource metav1.APIResource) bool {
candidates := make([]string, 0, 3+len(apiResource.ShortNames))
if kind := strings.ToLower(strings.TrimSpace(apiResource.Kind)); kind != "" {
candidates = append(candidates, kind)
}
if name := strings.ToLower(strings.TrimSpace(apiResource.Name)); name != "" {
candidates = append(candidates, name)
}
if singular := strings.ToLower(strings.TrimSpace(apiResource.SingularName)); singular != "" {
candidates = append(candidates, singular)
}
for _, shortName := range apiResource.ShortNames {
if shortName = strings.ToLower(strings.TrimSpace(shortName)); shortName != "" {
candidates = append(candidates, shortName)
}
}

for _, candidate := range candidates {
if input == candidate {
return true
}
if !strings.HasSuffix(candidate, "s") && input == candidate+"s" {
return true
}
if group != "" && input == candidate+"."+group {
return true
}
if group != "" && !strings.HasSuffix(candidate, "s") && input == candidate+"s."+group {
return true
}
}
return false
}

func (r resourceInfo) GVK() schema.GroupVersionKind {
return schema.GroupVersionKind{Group: r.Group, Version: r.Version, Kind: r.Kind}
}
Expand All @@ -299,8 +428,7 @@ func normalizeNamespace(r resourceInfo, namespace string) string {
return namespace
}

func buildObjectForKind(kind string) *unstructured.Unstructured {
resource := resolveResourceInfo(kind)
func buildObjectForResource(resource resourceInfo) *unstructured.Unstructured {
obj := &unstructured.Unstructured{}
obj.SetGroupVersionKind(resource.GVK())
return obj
Expand Down Expand Up @@ -356,15 +484,15 @@ func permissionNamespace(resource resourceInfo, namespace string) string {
return namespace
}

func requiredToolPermissions(toolName string, args map[string]interface{}) ([]toolPermission, error) {
func requiredToolPermissions(ctx context.Context, cs *cluster.ClientSet, toolName string, args map[string]interface{}) ([]toolPermission, error) {
switch toolName {
case "get_resource":
kind, err := getRequiredString(args, "kind")
if err != nil {
return nil, err
}
namespace, _ := args["namespace"].(string)
resource := resolveResourceInfo(kind)
resource := resolveResourceInfo(ctx, cs, kind)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbGet),
Expand All @@ -376,7 +504,7 @@ func requiredToolPermissions(toolName string, args map[string]interface{}) ([]to
return nil, err
}
namespace, _ := args["namespace"].(string)
resource := resolveResourceInfo(kind)
resource := resolveResourceInfo(ctx, cs, kind)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbGet),
Expand Down Expand Up @@ -407,7 +535,7 @@ func requiredToolPermissions(toolName string, args map[string]interface{}) ([]to
if err != nil {
return nil, err
}
resource := resolveResourceInfo(obj.GetKind())
resource := resolveResourceInfoForObject(ctx, cs, obj)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbCreate),
Expand All @@ -418,7 +546,7 @@ func requiredToolPermissions(toolName string, args map[string]interface{}) ([]to
if err != nil {
return nil, err
}
resource := resolveResourceInfo(obj.GetKind())
resource := resolveResourceInfoForObject(ctx, cs, obj)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbUpdate),
Expand All @@ -433,7 +561,7 @@ func requiredToolPermissions(toolName string, args map[string]interface{}) ([]to
return nil, err
}
namespace, _ := args["namespace"].(string)
resource := resolveResourceInfo(kind)
resource := resolveResourceInfo(ctx, cs, kind)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbUpdate),
Expand All @@ -448,7 +576,7 @@ func requiredToolPermissions(toolName string, args map[string]interface{}) ([]to
return nil, err
}
namespace, _ := args["namespace"].(string)
resource := resolveResourceInfo(kind)
resource := resolveResourceInfo(ctx, cs, kind)
return []toolPermission{{
Resource: resource.Resource,
Verb: string(common.VerbDelete),
Expand Down Expand Up @@ -480,7 +608,7 @@ func AuthorizeTool(c *gin.Context, cs *cluster.ClientSet, toolName string, args
return "Error: authenticated user not found in context", true
}

permissions, err := requiredToolPermissions(toolName, args)
permissions, err := requiredToolPermissions(c.Request.Context(), cs, toolName, args)
if err != nil {
return "Error: " + err.Error(), true
}
Expand Down Expand Up @@ -533,8 +661,8 @@ func executeGetResource(ctx context.Context, cs *cluster.ClientSet, args map[str
}
namespace, _ := args["namespace"].(string)

resource := resolveResourceInfo(kind)
obj := buildObjectForKind(kind)
resource := resolveResourceInfo(ctx, cs, kind)
obj := buildObjectForResource(resource)
key := k8stypes.NamespacedName{
Name: name,
Namespace: normalizeNamespace(resource, namespace),
Expand Down Expand Up @@ -596,7 +724,7 @@ func executeListResources(ctx context.Context, cs *cluster.ClientSet, args map[s
namespace, _ := args["namespace"].(string)
labelSelector, _ := args["label_selector"].(string)

resource := resolveResourceInfo(kind)
resource := resolveResourceInfo(ctx, cs, kind)
namespace = normalizeNamespace(resource, namespace)
list := &unstructured.UnstructuredList{}
list.SetGroupVersionKind(resource.ListGVK())
Expand Down Expand Up @@ -1072,8 +1200,8 @@ func executePatchResource(ctx context.Context, cs *cluster.ClientSet, args map[s
return "Error: patch must be valid JSON", true
}

resource := resolveResourceInfo(kind)
obj := buildObjectForKind(kind)
resource := resolveResourceInfo(ctx, cs, kind)
obj := buildObjectForResource(resource)

key := k8stypes.NamespacedName{
Name: name,
Expand Down Expand Up @@ -1104,8 +1232,8 @@ func executeDeleteResource(ctx context.Context, cs *cluster.ClientSet, args map[
}
namespace, _ := args["namespace"].(string)

resource := resolveResourceInfo(kind)
obj := buildObjectForKind(kind)
resource := resolveResourceInfo(ctx, cs, kind)
obj := buildObjectForResource(resource)

key := k8stypes.NamespacedName{
Name: name,
Expand Down
Loading