diff --git a/internal/mcp/server.go b/internal/mcp/server.go
index 6461f0e..0f2b5f5 100644
--- a/internal/mcp/server.go
+++ b/internal/mcp/server.go
@@ -9,6 +9,7 @@ import (
 	"os"
 	"strings"
 
 	"github.com/supermodeltools/cli/internal/analyze"
 	"github.com/supermodeltools/cli/internal/api"
 	"github.com/supermodeltools/cli/internal/build"
+	"github.com/supermodeltools/cli/internal/memorygraph"
@@ -113,6 +114,80 @@ var tools = []tool{
 			},
 		},
 	},
+	{
+		Name:        "upsert_memory_node",
+		Description: "Upsert a typed knowledge node into the persistent memory graph.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"type":    {Type: "string", Description: "Node type: fact, concept, entity, event, procedure, context."},
+				"label":   {Type: "string", Description: "Short unique label for the node."},
+				"content": {Type: "string", Description: "Full content body of the node."},
+			},
+			Required: []string{"type", "label", "content"},
+		},
+	},
+	{
+		Name:        "create_relation",
+		Description: "Create a directed weighted edge between two memory graph nodes.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"source_id": {Type: "string", Description: "ID of the source node."},
+				"target_id": {Type: "string", Description: "ID of the target node."},
+				"relation":  {Type: "string", Description: "Relation type, e.g. related_to, depends_on, part_of."},
+				"weight":    {Type: "number", Description: "Edge weight between 0 and 1 (default 1.0)."},
+			},
+			Required: []string{"source_id", "target_id", "relation"},
+		},
+	},
+	{
+		Name:        "search_memory_graph",
+		Description: "Score and retrieve nodes from the memory graph matching a query, with optional one-hop neighbor expansion.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"query":     {Type: "string", Description: "Search query string."},
+				"max_depth": {Type: "integer", Description: "Max BFS depth for neighbor expansion (default 1)."},
+				"top_k":     {Type: "integer", Description: "Maximum number of direct results to return (default 5)."},
+			},
+			Required: []string{"query"},
+		},
+	},
+	{
+		Name:        "retrieve_with_traversal",
+		Description: "BFS traversal from a start node up to maxDepth, returning visited nodes with decayed relevance scores.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"start_node_id": {Type: "string", Description: "ID of the node to start traversal from."},
+				"max_depth":     {Type: "integer", Description: "Maximum BFS depth (default 3)."},
+			},
+			Required: []string{"start_node_id"},
+		},
+	},
+	{
+		Name:        "prune_stale_links",
+		Description: "Remove edges below a weight threshold and orphaned nodes from the memory graph.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"threshold": {Type: "number", Description: "Minimum edge weight to retain (default 0.1)."},
+			},
+		},
+	},
+	{
+		Name:        "add_interlinked_context",
+		Description: "Bulk-insert nodes and optionally auto-create similarity edges (Jaccard ≥ 0.72) between them.",
+		InputSchema: toolSchema{
+			Type: "object",
+			Properties: map[string]schemaProp{
+				"items":     {Type: "array", Description: "Array of {type, label, content, metadata} node objects to insert."},
+				"auto_link": {Type: "boolean", Description: "If true, auto-create similarity edges between inserted nodes."},
+			},
+			Required: []string{"items"},
+		},
+	},
 }
 
 // --- Server ------------------------------------------------------------------
@@ -226,7 +301,55 @@ func (s *server) callTool(ctx context.Context, name string, args map[string]any)
 		return s.toolGetGraph(ctx, args)
+	case "upsert_memory_node":
+		return memorygraph.ToolUpsertMemoryNode(memorygraph.UpsertMemoryNodeOptions{
+			RootDir: s.dir,
+			Type:    memorygraph.NodeType(strArg(args, "type")),
+			Label:   strArg(args, "label"),
+			Content: strArg(args, "content"),
+		})
+	case "create_relation":
+		// No local default needed: CreateRelation itself falls back to 1.0
+		// for non-positive weights.
+		return memorygraph.ToolCreateRelation(memorygraph.CreateRelationOptions{
+			RootDir:  s.dir,
+			SourceID: strArg(args, "source_id"),
+			TargetID: strArg(args, "target_id"),
+			Relation: memorygraph.RelationType(strArg(args, "relation")),
+			Weight:   floatArg(args, "weight"),
+		})
+	case "search_memory_graph":
+		// Default top_k to 5 as advertised in the tool schema.
+		topK := intArg(args, "top_k")
+		if topK == 0 {
+			topK = 5
+		}
+		return memorygraph.ToolSearchMemoryGraph(memorygraph.SearchMemoryGraphOptions{
+			RootDir:  s.dir,
+			Query:    strArg(args, "query"),
+			MaxDepth: intArg(args, "max_depth"),
+			TopK:     topK,
+		})
+	case "retrieve_with_traversal":
+		return memorygraph.ToolRetrieveWithTraversal(memorygraph.RetrieveWithTraversalOptions{
+			RootDir:     s.dir,
+			StartNodeID: strArg(args, "start_node_id"),
+			MaxDepth:    intArg(args, "max_depth"),
+		})
+	case "prune_stale_links":
+		return memorygraph.ToolPruneStaleLinks(memorygraph.PruneStaleLinksOptions{
+			RootDir:   s.dir,
+			Threshold: floatArg(args, "threshold"),
+		})
+	case "add_interlinked_context":
+		items, err := parseInterlinkedItems(args)
+		if err != nil {
+			return "", fmt.Errorf("add_interlinked_context: invalid items: %w", err)
+		}
+		return memorygraph.ToolAddInterlinkedContext(memorygraph.AddInterlinkedContextOptions{
+			RootDir:  s.dir,
+			Items:    items,
+			AutoLink: boolArg(args, "auto_link"),
+		})
 	default:
 		return "", fmt.Errorf("unknown tool: %s", name)
 	}
 }
@@ -497,3 +620,42 @@ func intArg(args map[string]any, key string) int {
 	v, _ := args[key].(float64)
 	return int(v)
 }
+
+// strArg extracts a string argument, returning "" when absent or mistyped.
+func strArg(args map[string]any, key string) string {
+	v, _ := args[key].(string)
+	return v
+}
+
+// floatArg extracts a numeric argument (JSON numbers decode as float64).
+func floatArg(args map[string]any, key string) float64 {
+	v, _ := args[key].(float64)
+	return v
+}
+
+// boolArg extracts a boolean argument, returning false when absent or mistyped.
+// NOTE(review): the original patch called boolArg without defining it anywhere;
+// drop this definition if server.go already provides one.
+func boolArg(args map[string]any, key string) bool {
+	v, _ := args[key].(bool)
+	return v
+}
+
+// parseInterlinkedItems re-encodes the raw args["items"] array and decodes it
+// into the strongly-typed slice expected by ToolAddInterlinkedContext.
+// NOTE(review): uses encoding/json — verify "encoding/json" is present in this
+// file's import block before applying.
+func parseInterlinkedItems(args map[string]any) ([]memorygraph.InterlinkedItem, error) {
+	raw, ok := args["items"]
+	if !ok || raw == nil {
+		return nil, fmt.Errorf("missing required field \"items\"")
+	}
+	b, err := json.Marshal(raw)
+	if err != nil {
+		return nil, err
+	}
+	var items []memorygraph.InterlinkedItem
+	if err := json.Unmarshal(b, &items); err != nil {
+		return nil, err
+	}
+	if len(items) == 0 {
+		return nil, fmt.Errorf("\"items\" must be a non-empty array")
+	}
+	return items, nil
+}
diff --git a/internal/memorygraph/memory_graph.go b/internal/memorygraph/memory_graph.go
new file mode 100644
index 0000000..01e82ac
--- /dev/null
+++ b/internal/memorygraph/memory_graph.go
@@ -0,0 +1,700 @@
+// Package memorygraph implements a persistent memory graph for interlinked RAG.
+// Nodes represent typed knowledge units; edges are weighted, typed relations.
+// The graph is stored as a single JSON file under rootDir/.supermodel/memory-graph.json
+// and is safe for concurrent reads within a process (writes hold a mutex).
+package memorygraph
+
+import (
+	"encoding/json"
+	"fmt"
+	"math"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+)
+
+// --- Types -------------------------------------------------------------------
+
+// NodeType classifies what kind of knowledge a node represents.
+type NodeType string
+
+const (
+	NodeTypeFact      NodeType = "fact"
+	NodeTypeConcept   NodeType = "concept"
+	NodeTypeEntity    NodeType = "entity"
+	NodeTypeEvent     NodeType = "event"
+	NodeTypeProcedure NodeType = "procedure"
+	NodeTypeContext   NodeType = "context"
+)
+
+// RelationType classifies the semantic relationship between two nodes.
+type RelationType string
+
+const (
+	RelationRelatedTo    RelationType = "related_to"
+	RelationDependsOn    RelationType = "depends_on"
+	RelationPartOf       RelationType = "part_of"
+	RelationLeadsTo      RelationType = "leads_to"
+	RelationContrasts    RelationType = "contrasts"
+	RelationSimilarTo    RelationType = "similar_to"
+	RelationInstantiates RelationType = "instantiates"
+)
+
+// Node is a single knowledge unit in the memory graph.
+type Node struct {
+	ID          string            `json:"id"`
+	Type        NodeType          `json:"type"`
+	Label       string            `json:"label"`
+	Content     string            `json:"content"`
+	Metadata    map[string]string `json:"metadata,omitempty"`
+	AccessCount int               `json:"accessCount"`
+	CreatedAt   time.Time         `json:"createdAt"`
+	UpdatedAt   time.Time         `json:"updatedAt"`
+}
+
+// Edge is a directed, weighted relation between two nodes.
+type Edge struct {
+	ID       string            `json:"id"`
+	Source   string            `json:"source"`
+	Target   string            `json:"target"`
+	Relation RelationType      `json:"relation"`
+	Weight   float64           `json:"weight"`
+	Metadata map[string]string `json:"metadata,omitempty"`
+}
+
+// TraversalResult is a node reached during graph traversal, enriched with
+// path context and a relevance score.
+type TraversalResult struct {
+	Node           Node
+	Depth          int
+	RelevanceScore float64
+	PathRelations  []string // relation labels along the path from the start node
+}
+
+// GraphStats summarises the current state of the graph.
+type GraphStats struct {
+	Nodes int
+	Edges int
+}
+
+// graphData is the on-disk format.
+type graphData struct {
+	Nodes []Node `json:"nodes"`
+	Edges []Edge `json:"edges"`
+}
+
+// --- Storage -----------------------------------------------------------------
+
+const graphFile = ".supermodel/memory-graph.json"
+
+var (
+	mu    sync.RWMutex
+	cache = map[string]*graphData{} // rootDir → loaded graph
+)
+
+// graphPath returns the on-disk location of the graph file for rootDir.
+func graphPath(rootDir string) string {
+	return filepath.Join(rootDir, graphFile)
+}
+
+// load reads the graph for rootDir from disk (or returns the in-memory cache).
+// NOTE(review): the returned *graphData aliases the shared cache entry, and
+// callers (GetGraphStats, SearchGraph, ...) read g.Nodes/g.Edges AFTER the
+// lock is released while writers mutate the same slices under mu — this looks
+// like a data race under -race. Confirm whether read paths should copy the
+// data or hold mu.RLock for the whole read.
+func load(rootDir string) (*graphData, error) {
+	mu.RLock()
+	if g, ok := cache[rootDir]; ok {
+		mu.RUnlock()
+		return g, nil
+	}
+	mu.RUnlock()
+
+	mu.Lock()
+	defer mu.Unlock()
+
+	// Double-checked locking.
+	if g, ok := cache[rootDir]; ok {
+		return g, nil
+	}
+
+	path := graphPath(rootDir)
+	g := &graphData{}
+
+	// A missing file is not an error: it simply means an empty graph.
+	data, err := os.ReadFile(path)
+	if err != nil && !os.IsNotExist(err) {
+		return nil, fmt.Errorf("memorygraph: read %s: %w", path, err)
+	}
+	if err == nil {
+		if err := json.Unmarshal(data, g); err != nil {
+			return nil, fmt.Errorf("memorygraph: parse %s: %w", path, err)
+		}
+	}
+
+	cache[rootDir] = g
+	return g, nil
+}
+
+// save persists g to disk. Caller must hold mu (write lock).
+// NOTE(review): the write is not atomic (no temp file + rename); a crash
+// mid-write can leave memory-graph.json truncated/corrupt.
+func save(rootDir string, g *graphData) error {
+	path := graphPath(rootDir)
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		return fmt.Errorf("memorygraph: mkdir: %w", err)
+	}
+	data, err := json.MarshalIndent(g, "", "  ")
+	if err != nil {
+		return fmt.Errorf("memorygraph: marshal: %w", err)
+	}
+	if err := os.WriteFile(path, data, 0o644); err != nil {
+		return fmt.Errorf("memorygraph: write: %w", err)
+	}
+	return nil
+}
+
+// nodeID derives a stable deterministic ID from type+label.
+// NOTE(review): labels that differ only by spaces vs underscores (or case)
+// collide on the same ID — e.g. "A B" and "a_b" map to the same node.
+func nodeID(t NodeType, label string) string {
+	return fmt.Sprintf("%s:%s", t, strings.ToLower(strings.ReplaceAll(label, " ", "_")))
+}
+
+// edgeID derives a stable deterministic ID from endpoints+relation.
+func edgeID(source, target string, relation RelationType) string {
+	return fmt.Sprintf("%s--%s-->%s", source, relation, target)
+}
+
+// --- Core operations ---------------------------------------------------------
+
+// UpsertNode creates or updates a node identified by (type, label).
+// If the node already exists its content and metadata are updated in-place.
+func UpsertNode(rootDir string, t NodeType, label, content string, metadata map[string]string) (*Node, error) {
+	mu.Lock()
+	defer mu.Unlock()
+
+	g, err := loadLocked(rootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	id := nodeID(t, label)
+	now := time.Now().UTC()
+
+	// Update path: existing node keeps its identity and CreatedAt, but
+	// content/metadata are replaced wholesale and AccessCount is bumped.
+	for i := range g.Nodes {
+		if g.Nodes[i].ID == id {
+			g.Nodes[i].Content = content
+			g.Nodes[i].Metadata = metadata
+			g.Nodes[i].UpdatedAt = now
+			g.Nodes[i].AccessCount++
+			node := g.Nodes[i]
+			if err := save(rootDir, g); err != nil {
+				return nil, err
+			}
+			return &node, nil
+		}
+	}
+
+	node := Node{
+		ID:        id,
+		Type:      t,
+		Label:     label,
+		Content:   content,
+		Metadata:  metadata,
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+	g.Nodes = append(g.Nodes, node)
+	if err := save(rootDir, g); err != nil {
+		return nil, err
+	}
+	return &node, nil
+}
+
+// CreateRelation adds a directed edge between two existing nodes.
+// Returns nil if either node ID is not found.
+func CreateRelation(rootDir, sourceID, targetID string, relation RelationType, weight float64, metadata map[string]string) (*Edge, error) {
+	mu.Lock()
+	defer mu.Unlock()
+
+	g, err := loadLocked(rootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	if !nodeExists(g, sourceID) || !nodeExists(g, targetID) {
+		return nil, nil //nolint:nilnil // caller checks nil to detect missing nodes
+	}
+
+	// Non-positive weights fall back to the documented default of 1.0.
+	if weight <= 0 {
+		weight = 1.0
+	}
+
+	id := edgeID(sourceID, targetID, relation)
+
+	// Upsert: update weight if edge already exists.
+	for i := range g.Edges {
+		if g.Edges[i].ID == id {
+			g.Edges[i].Weight = weight
+			g.Edges[i].Metadata = metadata
+			edge := g.Edges[i]
+			if err := save(rootDir, g); err != nil {
+				return nil, err
+			}
+			return &edge, nil
+		}
+	}
+
+	edge := Edge{
+		ID:       id,
+		Source:   sourceID,
+		Target:   targetID,
+		Relation: relation,
+		Weight:   weight,
+		Metadata: metadata,
+	}
+	g.Edges = append(g.Edges, edge)
+	if err := save(rootDir, g); err != nil {
+		return nil, err
+	}
+	return &edge, nil
+}
+
+// GetGraphStats returns a snapshot of node and edge counts.
+// NOTE(review): reads the shared graph after load() releases its lock — see
+// the race note on load.
+func GetGraphStats(rootDir string) (GraphStats, error) {
+	g, err := load(rootDir)
+	if err != nil {
+		return GraphStats{}, err
+	}
+	return GraphStats{Nodes: len(g.Nodes), Edges: len(g.Edges)}, nil
+}
+
+// PruneResult reports what was removed during a prune pass.
+type PruneResult struct {
+	Removed   int // NOTE(review): mixes units — counts removed edges AND removed nodes
+	Remaining int // NOTE(review): remaining EDGE count only; nodes are not included
+}
+
+// PruneStaleLinks removes edges whose weight falls below threshold and then
+// removes any nodes that have become fully orphaned (no edges in or out).
+// NOTE(review): a node that never had any edges (e.g. freshly upserted and not
+// yet linked) is also treated as orphaned and deleted by this pass — confirm
+// this is intended before running prune on a young graph.
+func PruneStaleLinks(rootDir string, threshold float64) (PruneResult, error) {
+	// Non-positive thresholds fall back to the documented default of 0.1.
+	if threshold <= 0 {
+		threshold = 0.1
+	}
+
+	mu.Lock()
+	defer mu.Unlock()
+
+	g, err := loadLocked(rootDir)
+	if err != nil {
+		return PruneResult{}, err
+	}
+
+	removed := 0
+
+	// Remove weak edges.
+	// In-place filter reusing g.Edges' backing array.
+	live := g.Edges[:0]
+	for _, e := range g.Edges {
+		if e.Weight >= threshold {
+			live = append(live, e)
+		} else {
+			removed++
+		}
+	}
+	g.Edges = live
+
+	// Remove orphaned nodes (no remaining edges reference them).
+	connected := make(map[string]bool, len(g.Nodes))
+	for _, e := range g.Edges {
+		connected[e.Source] = true
+		connected[e.Target] = true
+	}
+	liveNodes := g.Nodes[:0]
+	for _, n := range g.Nodes {
+		if connected[n.ID] {
+			liveNodes = append(liveNodes, n)
+		} else {
+			removed++
+		}
+	}
+	g.Nodes = liveNodes
+
+	if err := save(rootDir, g); err != nil {
+		return PruneResult{}, err
+	}
+	return PruneResult{Removed: removed, Remaining: len(g.Edges)}, nil
+}
+
+// InterlinkResult is returned by AddInterlinkedContext.
+type InterlinkResult struct {
+	Nodes []Node
+	Edges []Edge
+}
+
+// AddInterlinkedContext bulk-upserts a set of nodes and, when autoLink is true,
+// automatically creates similarity edges between pairs whose content overlaps
+// above a fixed threshold (0.72 cosine-approximated via token Jaccard).
+func AddInterlinkedContext(rootDir string, items []struct {
+	Type     NodeType
+	Label    string
+	Content  string
+	Metadata map[string]string
+}, autoLink bool) (*InterlinkResult, error) {
+	result := &InterlinkResult{}
+
+	// NOTE(review): every UpsertNode/CreateRelation call below re-acquires the
+	// global lock and rewrites the whole JSON file; with n items and autoLink
+	// enabled this is O(n^2) full-file saves, and the batch is not atomic.
+	for _, item := range items {
+		n, err := UpsertNode(rootDir, item.Type, item.Label, item.Content, item.Metadata)
+		if err != nil {
+			return nil, err
+		}
+		result.Nodes = append(result.Nodes, *n)
+	}
+
+	if !autoLink || len(result.Nodes) < 2 {
+		return result, nil
+	}
+
+	// Pairwise similarity over the inserted nodes only (not the whole graph).
+	const similarityThreshold = 0.72
+	for i := 0; i < len(result.Nodes); i++ {
+		for j := i + 1; j < len(result.Nodes); j++ {
+			sim := jaccardSimilarity(result.Nodes[i].Content, result.Nodes[j].Content)
+			if sim >= similarityThreshold {
+				edge, err := CreateRelation(rootDir,
+					result.Nodes[i].ID, result.Nodes[j].ID,
+					RelationSimilarTo, sim, nil)
+				if err != nil {
+					return nil, err
+				}
+				// edge is nil when an endpoint is missing; skip silently.
+				if edge != nil {
+					result.Edges = append(result.Edges, *edge)
+				}
+			}
+		}
+	}
+	return result, nil
+}
+
+// --- Traversal ---------------------------------------------------------------
+
+// SearchResult is returned by SearchGraph.
+type SearchResult struct {
+	Direct     []TraversalResult
+	Neighbors  []TraversalResult
+	TotalNodes int
+	TotalEdges int
+}
+
+// SearchGraph finds nodes whose label or content matches query, then expands
+// one hop to collect linked neighbors. Results are scored by relevance.
+// NOTE(review): maxDepth is accepted (and defaulted to 2) but the expansion
+// below is hard-coded to exactly one hop — the parameter is effectively
+// unused. The MCP tool schema also advertises "default 1"; reconcile.
+func SearchGraph(rootDir, query string, maxDepth, topK int, edgeFilter []RelationType) (*SearchResult, error) {
+	// Empty query: return graph totals only, no matches.
+	if strings.TrimSpace(query) == "" {
+		g, err := load(rootDir)
+		if err != nil {
+			return nil, err
+		}
+		return &SearchResult{TotalNodes: len(g.Nodes), TotalEdges: len(g.Edges)}, nil
+	}
+	if maxDepth <= 0 {
+		maxDepth = 2
+	}
+	if topK <= 0 {
+		topK = 10
+	}
+
+	g, err := load(rootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	queryLower := strings.ToLower(query)
+	nodeByID := indexNodes(g)
+	adjOut := buildAdjacency(g, edgeFilter) // nodeID → []Edge
+
+	// Score all nodes against the query.
+	type scored struct {
+		node  Node
+		score float64
+	}
+	var candidates []scored
+	for _, n := range g.Nodes {
+		score := scoreNode(n, queryLower)
+		if score > 0 {
+			candidates = append(candidates, scored{node: n, score: score})
+		}
+	}
+	sort.Slice(candidates, func(i, j int) bool {
+		return candidates[i].score > candidates[j].score
+	})
+	if len(candidates) > topK {
+		candidates = candidates[:topK]
+	}
+
+	result := &SearchResult{
+		TotalNodes: len(g.Nodes),
+		TotalEdges: len(g.Edges),
+	}
+	directIDs := make(map[string]bool)
+	for _, c := range candidates {
+		result.Direct = append(result.Direct, TraversalResult{
+			Node:           c.node,
+			Depth:          0,
+			RelevanceScore: c.score,
+			PathRelations:  []string{},
+		})
+		directIDs[c.node.ID] = true
+	}
+
+	// Expand one hop of neighbors.
+	neighborIDs := make(map[string]bool)
+	for _, c := range candidates {
+		for _, edge := range adjOut[c.node.ID] {
+			if directIDs[edge.Target] || neighborIDs[edge.Target] {
+				continue
+			}
+			if n, ok := nodeByID[edge.Target]; ok {
+				neighborIDs[edge.Target] = true
+				// Neighbor score: own match score, damped by edge weight and 0.5.
+				score := scoreNode(n, queryLower) * edge.Weight * 0.5
+				result.Neighbors = append(result.Neighbors, TraversalResult{
+					Node:           n,
+					Depth:          1,
+					RelevanceScore: score,
+					PathRelations:  []string{string(edge.Relation)},
+				})
+			}
+		}
+	}
+	sort.Slice(result.Neighbors, func(i, j int) bool {
+		return result.Neighbors[i].RelevanceScore > result.Neighbors[j].RelevanceScore
+	})
+
+	// Bump access counts for returned nodes.
+	// NOTE(review): fire-and-forget goroutine with no way to wait or cancel;
+	// the error is dropped and tests may observe the file mid-write.
+	go func() { _ = bumpAccess(rootDir, directIDs) }()
+
+	return result, nil
+}
+
+// RetrieveWithTraversal performs a BFS/depth-limited walk starting from
+// startNodeID, returning all reachable nodes up to maxDepth hops away.
+// edgeFilter restricts which relation types are followed; nil follows all.
+// Returns (nil, nil) when startNodeID does not exist.
+func RetrieveWithTraversal(rootDir, startNodeID string, maxDepth int, edgeFilter []RelationType) ([]TraversalResult, error) {
+	if maxDepth <= 0 {
+		maxDepth = 2
+	}
+
+	g, err := load(rootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	nodeByID := indexNodes(g)
+	startNode, ok := nodeByID[startNodeID]
+	if !ok {
+		return nil, nil
+	}
+
+	adjOut := buildAdjacency(g, edgeFilter)
+
+	type queueItem struct {
+		nodeID        string
+		depth         int
+		pathRelations []string
+		score         float64
+	}
+
+	visited := map[string]bool{startNodeID: true}
+	queue := []queueItem{{nodeID: startNodeID, depth: 0, pathRelations: []string{}, score: 1.0}}
+	var results []TraversalResult
+
+	// The start node itself is always the first result, at full relevance.
+	results = append(results, TraversalResult{
+		Node:           startNode,
+		Depth:          0,
+		RelevanceScore: 1.0,
+		PathRelations:  []string{},
+	})
+
+	for len(queue) > 0 {
+		item := queue[0]
+		queue = queue[1:]
+
+		if item.depth >= maxDepth {
+			continue
+		}
+
+		for _, edge := range adjOut[item.nodeID] {
+			if visited[edge.Target] {
+				continue
+			}
+			visited[edge.Target] = true
+
+			n, ok := nodeByID[edge.Target]
+			if !ok {
+				continue
+			}
+
+			// Decay relevance with depth and edge weight.
+			score := item.score * edge.Weight * math.Pow(0.8, float64(item.depth+1))
+			// Fresh backing array per path so sibling branches don't alias.
+			pathRels := append(append([]string(nil), item.pathRelations...), string(edge.Relation))
+
+			results = append(results, TraversalResult{
+				Node:           n,
+				Depth:          item.depth + 1,
+				RelevanceScore: score,
+				PathRelations:  pathRels,
+			})
+			queue = append(queue, queueItem{
+				nodeID:        edge.Target,
+				depth:         item.depth + 1,
+				pathRelations: pathRels,
+				score:         score,
+			})
+		}
+	}
+
+	// Sort by depth first, then descending relevance.
+	sort.Slice(results, func(i, j int) bool {
+		if results[i].Depth != results[j].Depth {
+			return results[i].Depth < results[j].Depth
+		}
+		return results[i].RelevanceScore > results[j].RelevanceScore
+	})
+
+	// NOTE(review): fire-and-forget access-count bump, same caveats as in
+	// SearchGraph.
+	go func() {
+		ids := make(map[string]bool, len(results))
+		for _, r := range results {
+			ids[r.Node.ID] = true
+		}
+		_ = bumpAccess(rootDir, ids)
+	}()
+
+	return results, nil
+}
+
+// --- Internal helpers --------------------------------------------------------
+
+// loadLocked loads the graph assuming the caller already holds mu (write lock).
+// It reads directly from the in-memory cache or from disk without acquiring
+// any additional locks (caller already holds the write lock).
+func loadLocked(rootDir string) (*graphData, error) {
+	if g, ok := cache[rootDir]; ok {
+		return g, nil
+	}
+	path := graphPath(rootDir)
+	g := &graphData{}
+	// A missing file means an empty graph, mirroring load().
+	data, err := os.ReadFile(path)
+	if err != nil && !os.IsNotExist(err) {
+		return nil, fmt.Errorf("memorygraph: read %s: %w", path, err)
+	}
+	if err == nil {
+		if err := json.Unmarshal(data, g); err != nil {
+			return nil, fmt.Errorf("memorygraph: parse %s: %w", path, err)
+		}
+	}
+	cache[rootDir] = g
+	return g, nil
+}
+
+// nodeExists reports whether a node with the given ID is in the graph.
+// NOTE(review): linear scan — O(n) per call; fine for small graphs.
+func nodeExists(g *graphData, id string) bool {
+	for _, n := range g.Nodes {
+		if n.ID == id {
+			return true
+		}
+	}
+	return false
+}
+
+// indexNodes builds a nodeID → Node lookup map (values are copies).
+func indexNodes(g *graphData) map[string]Node {
+	m := make(map[string]Node, len(g.Nodes))
+	for _, n := range g.Nodes {
+		m[n.ID] = n
+	}
+	return m
+}
+
+// buildAdjacency returns a map of nodeID → outgoing edges, optionally filtered
+// by relation type.
+func buildAdjacency(g *graphData, filter []RelationType) map[string][]Edge {
+	allowed := make(map[RelationType]bool, len(filter))
+	for _, r := range filter {
+		allowed[r] = true
+	}
+
+	adj := make(map[string][]Edge)
+	for _, e := range g.Edges {
+		// Empty filter means "follow every relation type".
+		if len(filter) > 0 && !allowed[e.Relation] {
+			continue
+		}
+		adj[e.Source] = append(adj[e.Source], e)
+	}
+	return adj
+}
+
+// scoreNode scores a node against a lower-cased query string.
+// Label matches are weighted more heavily than content matches.
+func scoreNode(n Node, queryLower string) float64 {
+	labelLower := strings.ToLower(n.Label)
+	contentLower := strings.ToLower(n.Content)
+
+	var score float64
+	if strings.Contains(labelLower, queryLower) {
+		score += 1.0
+	}
+	if strings.Contains(contentLower, queryLower) {
+		// Partial overlap: proportion of query tokens found in content.
+		score += tokenOverlap(queryLower, contentLower) * 0.6
+	}
+	// Popularity bias: nodes accessed frequently are slightly preferred.
+	if n.AccessCount > 0 {
+		score += math.Log1p(float64(n.AccessCount)) * 0.05
+	}
+	return score
+}
+
+// tokenOverlap returns the fraction of query tokens present in text.
+// NOTE(review): uses substring containment, so "cat" matches "catalog";
+// confirm whether whole-token matching was intended.
+func tokenOverlap(query, text string) float64 {
+	queryTokens := strings.Fields(query)
+	if len(queryTokens) == 0 {
+		return 0
+	}
+	found := 0
+	for _, t := range queryTokens {
+		if strings.Contains(text, t) {
+			found++
+		}
+	}
+	return float64(found) / float64(len(queryTokens))
+}
+
+// jaccardSimilarity approximates cosine similarity via token-set Jaccard.
+func jaccardSimilarity(a, b string) float64 {
+	ta := tokenSet(a)
+	tb := tokenSet(b)
+	if len(ta) == 0 || len(tb) == 0 {
+		return 0
+	}
+	intersection := 0
+	for t := range ta {
+		if tb[t] {
+			intersection++
+		}
+	}
+	return float64(intersection) / float64(len(ta)+len(tb)-intersection)
+}
+
+// tokenSet returns the set of lower-cased whitespace-separated tokens in s.
+func tokenSet(s string) map[string]bool {
+	m := make(map[string]bool)
+	for _, t := range strings.Fields(strings.ToLower(s)) {
+		m[t] = true
+	}
+	return m
+}
+
+// bumpAccess increments the AccessCount for each node in ids and persists.
+// NOTE(review): silently no-ops when the graph was never loaded into cache;
+// called from fire-and-forget goroutines in SearchGraph/RetrieveWithTraversal,
+// so failures here are invisible to callers.
+func bumpAccess(rootDir string, ids map[string]bool) error {
+	mu.Lock()
+	defer mu.Unlock()
+
+	g, ok := cache[rootDir]
+	if !ok {
+		return nil
+	}
+	for i := range g.Nodes {
+		if ids[g.Nodes[i].ID] {
+			g.Nodes[i].AccessCount++
+		}
+	}
+	return save(rootDir, g)
+}
diff --git a/internal/memorygraph/peek.go b/internal/memorygraph/peek.go
new file mode 100644
index 0000000..c068557
--- /dev/null
+++ b/internal/memorygraph/peek.go
@@ -0,0 +1,223 @@
+package memorygraph
+
+import (
+	"fmt"
+	"strings"
+	"time"
+)
+
+// NodePeek is a full snapshot of a node and all its edges, returned by Peek.
+type NodePeek struct {
+	Node     Node
+	EdgesOut []EdgePeek // edges where this node is the source
+	EdgesIn  []EdgePeek // edges where this node is the target
+}
+
+// EdgePeek is a human-readable summary of a single edge and its peer node.
+type EdgePeek struct {
+	Edge      Edge
+	PeerID    string
+	PeerLabel string
+	PeerType  NodeType
+}
+
+// PeekOptions controls what Peek returns.
+type PeekOptions struct {
+	RootDir string
+	// NodeID takes priority if set.
+	NodeID string
+	// Label is used for lookup when NodeID is empty (first match wins).
+	Label string
+}
+
+// Peek returns a full NodePeek for the requested node: its content, metadata,
+// access stats, and every inbound/outbound edge with peer labels resolved.
+// Returns nil if the node cannot be found.
+func Peek(opts PeekOptions) (*NodePeek, error) {
+	g, err := load(opts.RootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	nodeByID := indexNodes(g)
+
+	// Resolve target node: ID lookup first, then case-insensitive label scan.
+	var target *Node
+	if opts.NodeID != "" {
+		if n, ok := nodeByID[opts.NodeID]; ok {
+			target = &n
+		}
+	} else if opts.Label != "" {
+		labelLower := strings.ToLower(opts.Label)
+		for i := range g.Nodes {
+			if strings.ToLower(g.Nodes[i].Label) == labelLower {
+				n := g.Nodes[i]
+				target = &n
+				break
+			}
+		}
+	}
+
+	if target == nil {
+		return nil, nil //nolint:nilnil // caller checks nil to detect not-found
+	}
+
+	peek := &NodePeek{Node: *target}
+
+	// Classify every edge touching the target as inbound or outbound.
+	// Missing peers yield a zero-valued Node (empty label/type).
+	for _, e := range g.Edges {
+		switch {
+		case e.Source == target.ID:
+			peer := nodeByID[e.Target]
+			peek.EdgesOut = append(peek.EdgesOut, EdgePeek{
+				Edge:      e,
+				PeerID:    e.Target,
+				PeerLabel: peer.Label,
+				PeerType:  peer.Type,
+			})
+		case e.Target == target.ID:
+			peer := nodeByID[e.Source]
+			peek.EdgesIn = append(peek.EdgesIn, EdgePeek{
+				Edge:      e,
+				PeerID:    e.Source,
+				PeerLabel: peer.Label,
+				PeerType:  peer.Type,
+			})
+		}
+	}
+
+	return peek, nil
+}
+
+// PeekList returns a lightweight summary of every node in the graph —
+// ID, type, label, access count, age, and edge degree — sorted by access
+// count descending. Useful for scanning the graph before pruning.
+func PeekList(rootDir string) ([]NodePeek, error) {
+	g, err := load(rootDir)
+	if err != nil {
+		return nil, err
+	}
+
+	nodeByID := indexNodes(g)
+
+	edgesOut := make(map[string][]EdgePeek, len(g.Nodes))
+	edgesIn := make(map[string][]EdgePeek, len(g.Nodes))
+
+	for _, e := range g.Edges {
+		peer := nodeByID[e.Target]
+		edgesOut[e.Source] = append(edgesOut[e.Source], EdgePeek{
+			Edge: e, PeerID: e.Target, PeerLabel: peer.Label, PeerType: peer.Type,
+		})
+
+		peer = nodeByID[e.Source]
+		edgesIn[e.Target] = append(edgesIn[e.Target], EdgePeek{
+			Edge: e, PeerID: e.Source, PeerLabel: peer.Label, PeerType: peer.Type,
+		})
+	}
+
+	peeks := make([]NodePeek, 0, len(g.Nodes))
+	for _, n := range g.Nodes {
+		peeks = append(peeks, NodePeek{
+			Node:     n,
+			EdgesOut: edgesOut[n.ID],
+			EdgesIn:  edgesIn[n.ID],
+		})
+	}
+
+	// Sort by access count desc, then label asc for stable output.
+	sortNodePeeks(peeks)
+
+	return peeks, nil
+}
+
+// FormatPeek renders a NodePeek as a human-readable block suitable for
+// display in a terminal or MCP tool response.
+func FormatPeek(p *NodePeek) string {
+	if p == nil {
+		return "❌ Node not found."
+	}
+	n := p.Node
+	age := time.Since(n.CreatedAt).Round(time.Hour)
+
+	var b strings.Builder
+	fmt.Fprintf(&b, "┌─ [%s] %s\n", n.Type, n.Label)
+	fmt.Fprintf(&b, "│  ID: %s\n", n.ID)
+	fmt.Fprintf(&b, "│  Accessed: %dx │ Age: %s │ Updated: %s\n",
+		n.AccessCount,
+		age,
+		n.UpdatedAt.Format("2006-01-02 15:04"),
+	)
+	if len(n.Metadata) > 0 {
+		fmt.Fprintf(&b, "│  Metadata: %s\n", formatMetadata(n.Metadata))
+	}
+	// Continue the left gutter on every content line.
+	fmt.Fprintf(&b, "│\n│  Content:\n│  %s\n",
+		strings.ReplaceAll(n.Content, "\n", "\n│  "))
+
+	if len(p.EdgesOut) > 0 {
+		fmt.Fprintf(&b, "│\n│  Out (%d):\n", len(p.EdgesOut))
+		for _, ep := range p.EdgesOut {
+			fmt.Fprintf(&b, "│   ──[%s w:%.2f]──▶ [%s] %s\n",
+				ep.Edge.Relation, ep.Edge.Weight, ep.PeerType, ep.PeerLabel)
+		}
+	}
+	if len(p.EdgesIn) > 0 {
+		fmt.Fprintf(&b, "│\n│  In (%d):\n", len(p.EdgesIn))
+		for _, ep := range p.EdgesIn {
+			fmt.Fprintf(&b, "│   [%s] %s ──[%s w:%.2f]──▶\n",
+				ep.PeerType, ep.PeerLabel, ep.Edge.Relation, ep.Edge.Weight)
+		}
+	}
+	b.WriteString("└─")
+	return b.String()
+}
+
+// FormatPeekList renders a PeekList result as a compact table with one node
+// per line, suitable for scanning before a prune pass.
+func FormatPeekList(peeks []NodePeek) string {
+	if len(peeks) == 0 {
+		return "Graph is empty."
+	}
+	var b strings.Builder
+	fmt.Fprintf(&b, "%-12s %-10s %-32s %7s %4s %4s\n",
+		"TYPE", "ID (short)", "LABEL", "ACCESSED", "OUT", "IN")
+	b.WriteString(strings.Repeat("─", 76) + "\n")
+	for _, p := range peeks {
+		// NOTE(review): byte-index slicing below can split a multi-byte UTF-8
+		// rune when IDs/labels contain non-ASCII characters — consider
+		// rune-aware truncation.
+		shortID := p.Node.ID
+		if len(shortID) > 10 {
+			shortID = shortID[:10] + "…"
+		}
+		label := p.Node.Label
+		if len(label) > 32 {
+			label = label[:31] + "…"
+		}
+		fmt.Fprintf(&b, "%-12s %-10s %-32s %7dx %4d %4d\n",
+			p.Node.Type, shortID, label,
+			p.Node.AccessCount, len(p.EdgesOut), len(p.EdgesIn))
+	}
+	fmt.Fprintf(&b, "\n%d node(s) total\n", len(peeks))
+	return b.String()
+}
+
+// --- helpers -----------------------------------------------------------------
+
+// formatMetadata renders metadata as space-separated key=value pairs.
+// NOTE(review): map iteration order is random in Go, so the output order is
+// nondeterministic — sort the keys if stable output matters.
+func formatMetadata(m map[string]string) string {
+	parts := make([]string, 0, len(m))
+	for k, v := range m {
+		parts = append(parts, k+"="+v)
+	}
+	return strings.Join(parts, " ")
+}
+
+// sortNodePeeks orders by access count descending, then label ascending.
+func sortNodePeeks(peeks []NodePeek) {
+	// Insertion sort is fine for the typical small N here.
+	for i := 1; i < len(peeks); i++ {
+		for j := i; j > 0; j-- {
+			a, b := peeks[j-1], peeks[j]
+			if a.Node.AccessCount < b.Node.AccessCount ||
+				(a.Node.AccessCount == b.Node.AccessCount && a.Node.Label > b.Node.Label) {
+				peeks[j-1], peeks[j] = peeks[j], peeks[j-1]
+			} else {
+				break
+			}
+		}
+	}
+}
diff --git a/internal/memorygraph/tools.go b/internal/memorygraph/tools.go
new file mode 100644
index 0000000..9185945
--- /dev/null
+++ b/internal/memorygraph/tools.go
@@ -0,0 +1,245 @@
+// Package memorygraph — MCP tool wrappers.
+// Each exported Tool* function is the Go equivalent of the TypeScript tool
+// functions in tools/memory-tools.ts and follows the same output format so
+// that callers get identical text responses.
+package memorygraph
+
+import (
+	"fmt"
+	"strings"
+)
+
+// --- Tool option structs ------------------------------------------------------
+
+// UpsertMemoryNodeOptions mirrors UpsertMemoryNodeOptions in the TS source.
+type UpsertMemoryNodeOptions struct { + RootDir string + Type NodeType + Label string + Content string + Metadata map[string]string +} + +// CreateRelationOptions mirrors CreateRelationOptions in the TS source. +type CreateRelationOptions struct { + RootDir string + SourceID string + TargetID string + Relation RelationType + Weight float64 + Metadata map[string]string +} + +// SearchMemoryGraphOptions mirrors SearchMemoryGraphOptions in the TS source. +type SearchMemoryGraphOptions struct { + RootDir string + Query string + MaxDepth int + TopK int + EdgeFilter []RelationType +} + +// PruneStaleLinksOptions mirrors PruneStaleLinksOptions in the TS source. +type PruneStaleLinksOptions struct { + RootDir string + Threshold float64 +} + +// InterlinkedItem is a single entry for AddInterlinkedContext. +type InterlinkedItem struct { + Type NodeType + Label string + Content string + Metadata map[string]string +} + +// AddInterlinkedContextOptions mirrors AddInterlinkedContextOptions in the TS source. +type AddInterlinkedContextOptions struct { + RootDir string + Items []InterlinkedItem + AutoLink bool +} + +// RetrieveWithTraversalOptions mirrors RetrieveWithTraversalOptions in the TS source. +type RetrieveWithTraversalOptions struct { + RootDir string + StartNodeID string + MaxDepth int + EdgeFilter []RelationType +} + +// --- Formatters -------------------------------------------------------------- + +func formatTraversalResult(r TraversalResult) string { + content := r.Node.Content + if len(content) > 120 { + content = content[:120] + "..." 
+	}
+	lines := []string{
+		fmt.Sprintf("  [%s] %s (depth: %d, score: %.2f)", r.Node.Type, r.Node.Label, r.Depth, r.RelevanceScore),
+		fmt.Sprintf("  Content: %s", content),
+	}
+	if len(r.PathRelations) > 1 {
+		lines = append(lines, fmt.Sprintf("  Path: %s", strings.Join(r.PathRelations, " "))) // NOTE(review): plain-space separator — confirm it matches the TS tools/memory-tools.ts join
+	}
+	lines = append(lines, fmt.Sprintf("  ID: %s | Accessed: %dx", r.Node.ID, r.Node.AccessCount))
+	return strings.Join(lines, "\n")
+}
+
+// --- Tool implementations ----------------------------------------------------
+
+// ToolUpsertMemoryNode creates or updates a memory node and returns a
+// human-readable summary including updated graph stats.
+func ToolUpsertMemoryNode(opts UpsertMemoryNodeOptions) (string, error) {
+	node, err := UpsertNode(opts.RootDir, opts.Type, opts.Label, opts.Content, opts.Metadata) // persists the node; errors propagate unchanged to the caller
+	if err != nil {
+		return "", err
+	}
+	stats, err := GetGraphStats(opts.RootDir) // re-read stats after the write so the summary reflects the new totals
+	if err != nil {
+		return "", err
+	}
+	return strings.Join([]string{
+		fmt.Sprintf("✅ Memory node upserted: %s", node.Label),
+		fmt.Sprintf("   ID: %s", node.ID),
+		fmt.Sprintf("   Type: %s", node.Type),
+		fmt.Sprintf("   Access count: %d", node.AccessCount),
+		fmt.Sprintf("\nGraph: %d nodes, %d edges", stats.Nodes, stats.Edges),
+	}, "\n"), nil
+}
+
+// ToolCreateRelation adds a directed edge between two existing nodes.
+func ToolCreateRelation(opts CreateRelationOptions) (string, error) {
+	edge, err := CreateRelation(opts.RootDir, opts.SourceID, opts.TargetID, opts.Relation, opts.Weight, opts.Metadata)
+	if err != nil {
+		return "", err
+	}
+	if edge == nil { // nil edge with nil error signals "node(s) missing" — reported as text, not as a Go error
+		return fmt.Sprintf("❌ Failed: one or both node IDs not found (source: %s, target: %s)",
+			opts.SourceID, opts.TargetID), nil
+	}
+	stats, err := GetGraphStats(opts.RootDir)
+	if err != nil {
+		return "", err
+	}
+	return strings.Join([]string{
+		fmt.Sprintf("✅ Relation created: %s --[%s]--> %s", opts.SourceID, edge.Relation, opts.TargetID),
+		fmt.Sprintf("   Edge ID: %s", edge.ID),
+		fmt.Sprintf("   Weight: %.2f", edge.Weight),
+		fmt.Sprintf("\nGraph: %d nodes, %d edges", stats.Nodes, stats.Edges),
+	}, "\n"), nil
+}
+
+// ToolSearchMemoryGraph searches the graph and returns direct matches plus
+// one-hop neighbors, formatted identically to the TypeScript version.
+func ToolSearchMemoryGraph(opts SearchMemoryGraphOptions) (string, error) {
+	result, err := SearchGraph(opts.RootDir, opts.Query, opts.MaxDepth, opts.TopK, opts.EdgeFilter)
+	if err != nil {
+		return "", err
+	}
+	if len(result.Direct) == 0 { // no direct hits: neighbors are never reported on their own
+		return fmt.Sprintf("No memory nodes found for: %q\nGraph has %d nodes, %d edges.",
+			opts.Query, result.TotalNodes, result.TotalEdges), nil
+	}
+
+	sections := []string{
+		fmt.Sprintf("Memory Graph Search: %q", opts.Query),
+		fmt.Sprintf("Graph: %d nodes, %d edges\n", result.TotalNodes, result.TotalEdges),
+		"Direct Matches:",
+	}
+	for _, hit := range result.Direct {
+		sections = append(sections, formatTraversalResult(hit))
+	}
+	if len(result.Neighbors) > 0 { // one-hop expansion results, appended after the direct section
+		sections = append(sections, "\nLinked Neighbors:")
+		for _, neighbor := range result.Neighbors {
+			sections = append(sections, formatTraversalResult(neighbor))
+		}
+	}
+	return strings.Join(sections, "\n"), nil
+}
+
+// ToolPruneStaleLinks removes weak edges and orphaned nodes.
+func ToolPruneStaleLinks(opts PruneStaleLinksOptions) (string, error) {
+	result, err := PruneStaleLinks(opts.RootDir, opts.Threshold)
+	if err != nil {
+		return "", err
+	}
+	return strings.Join([]string{
+		"🧹 Pruning complete",
+		fmt.Sprintf("   Removed: %d stale links/orphan nodes", result.Removed),
+		fmt.Sprintf("   Remaining edges: %d", result.Remaining),
+	}, "\n"), nil
+}
+
+// ToolAddInterlinkedContext bulk-upserts nodes and optionally auto-links them
+// by content similarity (threshold ≥ 0.72).
+func ToolAddInterlinkedContext(opts AddInterlinkedContextOptions) (string, error) {
+	items := make([]struct { // re-shape InterlinkedItem into the anonymous struct type AddInterlinkedContext expects
+		Type     NodeType
+		Label    string
+		Content  string
+		Metadata map[string]string
+	}, len(opts.Items))
+	for i, it := range opts.Items {
+		items[i].Type = it.Type
+		items[i].Label = it.Label
+		items[i].Content = it.Content
+		items[i].Metadata = it.Metadata // shares the caller's map; not copied — TODO confirm callee does not mutate
+	}
+
+	result, err := AddInterlinkedContext(opts.RootDir, items, opts.AutoLink)
+	if err != nil {
+		return "", err
+	}
+
+	sections := []string{fmt.Sprintf("✅ Added %d interlinked nodes", len(result.Nodes))}
+	if len(result.Edges) > 0 {
+		sections = append(sections, fmt.Sprintf("   Auto-linked: %d similarity edges (threshold ≥ 0.72)", len(result.Edges)))
+	} else {
+		sections = append(sections, "   No auto-links above threshold")
+	}
+	sections = append(sections, "\nNodes:")
+	for _, n := range result.Nodes {
+		sections = append(sections, fmt.Sprintf("  [%s] %s → %s", n.Type, n.Label, n.ID))
+	}
+	if len(result.Edges) > 0 {
+		sections = append(sections, "\nEdges:")
+		for _, e := range result.Edges {
+			sections = append(sections, fmt.Sprintf("  %s --[%s w:%.2f]--> %s",
+				e.Source, e.Relation, e.Weight, e.Target))
+		}
+	}
+
+	stats, err := GetGraphStats(opts.RootDir) // totals re-read after the bulk insert
+	if err != nil {
+		return "", err
+	}
+	sections = append(sections, fmt.Sprintf("\nGraph total: %d nodes, %d edges", stats.Nodes, stats.Edges))
+	return strings.Join(sections, "\n"), nil
+}
+
+// ToolRetrieveWithTraversal starts a BFS from startNodeID
 and returns all
+// reachable nodes up to maxDepth, formatted with path context and scores.
+func ToolRetrieveWithTraversal(opts RetrieveWithTraversalOptions) (string, error) {
+	results, err := RetrieveWithTraversal(opts.RootDir, opts.StartNodeID, opts.MaxDepth, opts.EdgeFilter)
+	if err != nil {
+		return "", err
+	}
+	if len(results) == 0 { // empty result set means the start node does not exist
+		return fmt.Sprintf("❌ Node not found: %s", opts.StartNodeID), nil
+	}
+
+	maxDepth := opts.MaxDepth
+	if maxDepth <= 0 { // fallback affects only the header text; the callee applies its own default
+		maxDepth = 2 // NOTE(review): tool schema in this patch documents "default 3" — confirm which value is intended
+	}
+
+	sections := []string{
+		fmt.Sprintf("Traversal from: %s (depth limit: %d)\n", results[0].Node.Label, maxDepth), // results[0] is the start node itself
+	}
+	for _, r := range results {
+		sections = append(sections, formatTraversalResult(r))
+	}
+	return strings.Join(sections, "\n"), nil
+}