From a630aa05575b59cf024ffdfadd6c771b5a70a9f8 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 00:11:53 +0100 Subject: [PATCH 01/19] Add diff command design spec and implementation plan --- docs/specs/2026-04-14-diff-command-design.md | 139 +++ .../plans/2026-04-15-diff-command.md | 928 ++++++++++++++++++ 2 files changed, 1067 insertions(+) create mode 100644 docs/specs/2026-04-14-diff-command-design.md create mode 100644 docs/superpowers/plans/2026-04-15-diff-command.md diff --git a/docs/specs/2026-04-14-diff-command-design.md b/docs/specs/2026-04-14-diff-command-design.md new file mode 100644 index 0000000..9993a63 --- /dev/null +++ b/docs/specs/2026-04-14-diff-command-design.md @@ -0,0 +1,139 @@ +# gpuaudit diff — Historical Scan Comparison + +**Issue:** #5 +**Date:** 2026-04-14 + +## Problem + +After running `gpuaudit scan --format json` periodically, there's no way to compare two reports and see what changed. When rescanning infrastructure (e.g. after filing a Karpenter waste escalation), you can't tell at a glance whether the situation improved. + +## Solution + +New `gpuaudit diff old.json new.json` subcommand that loads two scan result JSON files, matches instances by `instance_id`, and produces a cost-focused delta report. 
+ +## Data Model + +### DiffResult + +```go +// internal/diff/diff.go + +type DiffResult struct { + OldTimestamp time.Time + NewTimestamp time.Time + Added []models.GPUInstance // in new, not in old + Removed []models.GPUInstance // in old, not in new + Changed []InstanceDiff // in both, something changed + UnchangedCount int + CostSummary CostDelta +} + +type InstanceDiff struct { + InstanceID string + Old models.GPUInstance + New models.GPUInstance + CostDelta float64 // new.MonthlyCost - old.MonthlyCost + Changes []string // human-readable: "GPU allocated: 0 -> 2" +} + +type CostDelta struct { + OldTotalMonthlyCost float64 + NewTotalMonthlyCost float64 + CostChange float64 + OldTotalWaste float64 + NewTotalWaste float64 + WasteChange float64 + AddedCost float64 // cost from new instances + RemovedSavings float64 // cost removed with departing instances +} +``` + +### Instance Matching + +Instances are matched by `instance_id`. If an instance ID exists only in old, it's "removed". Only in new, it's "added". In both, it's compared for changes. + +No fuzzy matching by name or instance type — a replaced node is honestly reported as removed + added. + +### Change Detection + +An instance is "changed" if any of these fields differ between old and new: + +| Field | Format | +|---|---| +| `InstanceType` | `Instance type: g6e.16xlarge -> g6e.48xlarge` | +| `PricingModel` | `Pricing: on-demand -> reserved` | +| `MonthlyCost` | `Cost: $6,750 -> $4,200 (-$2,550/mo)` | +| `State` | `State: ready -> not-ready` | +| `GPUAllocated` | `GPU allocated: 0 -> 2` | +| `WasteSignals` | `Severity: critical -> (none)` or `Signal: idle -> low_utilization` | + +If none of these differ, the instance is counted as unchanged (not listed). 
+ +## Table Output Format + +``` + gpuaudit diff -- Apr 08 -> Apr 14 + + +----------------------------------------------------------+ + | Cost Delta | + +----------------------------------------------------------+ + | Monthly spend: $372,000 -> $251,000 (-$121,000) | + | Estimated waste: $189,000 -> $68,000 (-$121,000) | + | Instances: 116 -> 82 (-34 removed, +0 added) | + +----------------------------------------------------------+ + + REMOVED -- 34 instance(s), -$121,000/mo + + Instance Type Monthly + ------------------------------------ -------------------------- ---------- + ml-prod-iad/ip-10-22-249-9 g6e.48xlarge (8x L40S) $26,800 + ... + + ADDED -- 2 instance(s), +$5,000/mo + + Instance Type Monthly + ------------------------------------ -------------------------- ---------- + ... + + CHANGED -- 3 instance(s) + + Instance Change + ------------------------------------ ------------------------------------------ + ml-prod-iad/ip-10-1-2-3 GPU allocated: 0 -> 2 (was idle) + ml-prod-iad/ip-10-4-5-6 Pricing: on-demand -> reserved (-$2,400/mo) + + UNCHANGED -- 77 instance(s) +``` + +Cost summary box is the first thing rendered — the "did it get better" answer. Sections only appear if non-empty. + +## JSON Output Format + +Serialize `DiffResult` as JSON for programmatic consumption. Same structure as the Go types above. 
+ +## CLI Interface + +``` +gpuaudit diff <old.json> <new.json> [--format table|json] +``` + +- Two required positional arguments (file paths to JSON scan results) +- `--format` flag, default `table` +- Exit code 0 regardless of what changed (diff is informational, not pass/fail); failing to read or parse an input file is still reported as an error + +## Files + +### Create + +- `internal/diff/diff.go` — `Compare(old, new *models.ScanResult) *DiffResult` plus `computeCostDelta` and `diffInstance` helpers +- `internal/diff/diff_test.go` — tests: added, removed, changed (each field), unchanged, cost math, empty scans +- `internal/output/diff.go` — `FormatDiffTable(w, *DiffResult)` and `FormatDiffJSON(w, *DiffResult)` + +### Modify + +- `cmd/gpuaudit/main.go` — register `diff` subcommand with two positional args and `--format` flag + +## Testing + +- Unit tests in `diff_test.go` covering: add/remove/change detection, all 6 compared fields, cost delta math, edge cases (empty old, empty new, identical scans) +- Manual test: run two scans against same cluster, diff the output files diff --git a/docs/superpowers/plans/2026-04-15-diff-command.md b/docs/superpowers/plans/2026-04-15-diff-command.md new file mode 100644 index 0000000..6aeeef0 --- /dev/null +++ b/docs/superpowers/plans/2026-04-15-diff-command.md @@ -0,0 +1,928 @@ +# Diff Command Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add `gpuaudit diff old.json new.json` subcommand that compares two scan result JSON files and reports cost deltas, added/removed/changed instances. + +**Architecture:** New `internal/diff/` package contains the comparison logic (`Compare` function). New `internal/output/diff.go` handles table and JSON formatting. The `diff` subcommand in `cmd/gpuaudit/main.go` reads two JSON files, calls `Compare`, and formats the output. + +**Tech Stack:** Go standard library only — no new dependencies.
Uses existing `models.ScanResult` and `models.GPUInstance` types. + +--- + +### File Map + +| File | Action | Responsibility | +|---|---|---| +| `internal/diff/diff.go` | Create | `DiffResult`, `InstanceDiff`, `CostDelta` types + `Compare` function | +| `internal/diff/diff_test.go` | Create | Unit tests for comparison logic | +| `internal/output/diff.go` | Create | `FormatDiffTable` and `FormatDiffJSON` formatters | +| `cmd/gpuaudit/main.go` | Modify | Register `diff` subcommand | + +--- + +### Task 1: Core diff types and Compare function + +**Files:** +- Create: `internal/diff/diff.go` +- Create: `internal/diff/diff_test.go` + +- [ ] **Step 1: Write the test file with test for added instances** + +Create `internal/diff/diff_test.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package diff + +import ( + "testing" + "time" + + "github.com/gpuaudit/cli/internal/models" +) + +func scanResult(instances ...models.GPUInstance) *models.ScanResult { + return &models.ScanResult{ + Timestamp: time.Date(2026, 4, 8, 12, 0, 0, 0, time.UTC), + Instances: instances, + Summary: models.ScanSummary{ + TotalInstances: len(instances), + TotalMonthlyCost: sumMonthlyCost(instances), + TotalEstimatedWaste: sumWaste(instances), + }, + } +} + +func sumMonthlyCost(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.MonthlyCost + } + return total +} + +func sumWaste(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.EstimatedSavings + } + return total +} + +func inst(id string, monthlyCost float64) models.GPUInstance { + return models.GPUInstance{ + InstanceID: id, + InstanceType: "g6e.16xlarge", + GPUModel: "L40S", + GPUCount: 1, + MonthlyCost: monthlyCost, + HourlyCost: monthlyCost / 730, + State: "ready", + Source: models.SourceK8sNode, + PricingModel: "on-demand", + } +} + +func 
TestCompare_AddedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 1 { + t.Fatalf("expected 1 added, got %d", len(result.Added)) + } + if result.Added[0].InstanceID != "i-bbb" { + t.Errorf("expected added instance i-bbb, got %s", result.Added[0].InstanceID) + } + if result.CostSummary.AddedCost != 3000 { + t.Errorf("expected added cost 3000, got %.0f", result.CostSummary.AddedCost) + } +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/diff/ -run TestCompare_AddedInstances -v` +Expected: FAIL — `Compare` not defined. + +- [ ] **Step 3: Write the diff package with types and Compare function** + +Create `internal/diff/diff.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package diff compares two scan results and reports what changed. +package diff + +import ( + "fmt" + + "github.com/gpuaudit/cli/internal/models" +) + +// DiffResult holds the comparison between two scan results. +type DiffResult struct { + OldTimestamp string `json:"old_timestamp"` + NewTimestamp string `json:"new_timestamp"` + Added []models.GPUInstance `json:"added,omitempty"` + Removed []models.GPUInstance `json:"removed,omitempty"` + Changed []InstanceDiff `json:"changed,omitempty"` + UnchangedCount int `json:"unchanged_count"` + CostSummary CostDelta `json:"cost_summary"` +} + +// InstanceDiff describes what changed for a single instance between scans. +type InstanceDiff struct { + InstanceID string `json:"instance_id"` + Old models.GPUInstance `json:"old"` + New models.GPUInstance `json:"new"` + CostDelta float64 `json:"cost_delta"` + Changes []string `json:"changes"` +} + +// CostDelta summarizes the financial impact of changes between scans. 
+type CostDelta struct {
+	OldTotalMonthlyCost float64 `json:"old_total_monthly_cost"` // old scan's Summary.TotalMonthlyCost
+	NewTotalMonthlyCost float64 `json:"new_total_monthly_cost"` // new scan's Summary.TotalMonthlyCost
+	CostChange          float64 `json:"cost_change"`            // new total minus old total
+	OldTotalWaste       float64 `json:"old_total_waste"`        // old scan's Summary.TotalEstimatedWaste
+	NewTotalWaste       float64 `json:"new_total_waste"`        // new scan's Summary.TotalEstimatedWaste
+	WasteChange         float64 `json:"waste_change"`           // new waste minus old waste
+	AddedCost           float64 `json:"added_cost"`             // summed MonthlyCost of added instances
+	RemovedSavings      float64 `json:"removed_savings"`        // summed MonthlyCost of removed instances
+}
+
+// Compare computes the diff between two scan results, matching instances by ID.
+//
+// Instances present only in old are reported as Removed, those present only in
+// new as Added, and those present in both as Changed when diffInstance finds a
+// difference (otherwise they only increment UnchangedCount). Removed and
+// Changed follow the order of old.Instances and Added follows the order of
+// new.Instances, so the same two inputs always produce the same report.
+// Both arguments must be non-nil.
+func Compare(old, new *models.ScanResult) *DiffResult {
+	oldMap, oldIDs := indexByID(old.Instances)
+	newMap, newIDs := indexByID(new.Instances)
+
+	result := &DiffResult{
+		// The layout ends in a literal "UTC", so normalize the zone first;
+		// otherwise a timestamp carrying another offset would be mislabeled.
+		OldTimestamp: old.Timestamp.UTC().Format("2006-01-02 15:04 UTC"),
+		NewTimestamp: new.Timestamp.UTC().Format("2006-01-02 15:04 UTC"),
+	}
+
+	// Find removed and changed. Walk the IDs in scan order instead of ranging
+	// over the map: map iteration order is unspecified and would reshuffle the
+	// report on every run.
+	for _, id := range oldIDs {
+		oldInst := oldMap[id]
+		newInst, exists := newMap[id]
+		if !exists {
+			result.Removed = append(result.Removed, oldInst)
+			continue
+		}
+		changes := diffInstance(oldInst, newInst)
+		if len(changes) == 0 {
+			result.UnchangedCount++
+			continue
+		}
+		result.Changed = append(result.Changed, InstanceDiff{
+			InstanceID: id,
+			Old:        oldInst,
+			New:        newInst,
+			CostDelta:  newInst.MonthlyCost - oldInst.MonthlyCost,
+			Changes:    changes,
+		})
+	}
+
+	// Find added
+	for _, id := range newIDs {
+		if _, exists := oldMap[id]; !exists {
+			result.Added = append(result.Added, newMap[id])
+		}
+	}
+
+	// Cost summary
+	result.CostSummary = computeCostDelta(old, new, result)
+
+	return result
+}
+
+// indexByID maps each instance by its InstanceID and also returns the IDs in
+// first-seen order. When an ID occurs more than once the last occurrence wins,
+// exactly as a plain map-assignment loop would behave.
+func indexByID(instances []models.GPUInstance) (map[string]models.GPUInstance, []string) {
+	byID := make(map[string]models.GPUInstance, len(instances))
+	ids := make([]string, 0, len(instances))
+	for _, inst := range instances {
+		if _, dup := byID[inst.InstanceID]; !dup {
+			ids = append(ids, inst.InstanceID)
+		}
+		byID[inst.InstanceID] = inst
+	}
+	return byID, ids
+}
+
+// diffInstance returns one human-readable line per compared field that differs
+// between old and new, in a fixed order: instance type, pricing model, monthly
+// cost, state, GPU allocation, highest waste severity. A nil result means the
+// instance counts as unchanged.
+func diffInstance(old, new models.GPUInstance) []string {
+	var changes []string
+
+	if old.InstanceType != new.InstanceType {
+		changes = append(changes, fmt.Sprintf("Instance type: %s → %s", old.InstanceType, new.InstanceType))
+	}
+	if old.PricingModel != new.PricingModel {
+		changes = append(changes, fmt.Sprintf("Pricing: %s → %s", old.PricingModel, new.PricingModel))
+	}
+	if old.MonthlyCost != new.MonthlyCost {
+		delta := new.MonthlyCost - old.MonthlyCost
+		// Print the sign ahead of the dollar sign and format the magnitude, so
+		// a decrease reads "-$2550" rather than "$-2550".
+		sign, magnitude := "+", delta
+		if delta < 0 {
+			sign, magnitude = "-", -delta
+		}
+		changes = append(changes, fmt.Sprintf("Cost: $%.0f → $%.0f (%s$%.0f/mo)", old.MonthlyCost, new.MonthlyCost, sign, magnitude))
+	}
+	if old.State != new.State {
+		changes = append(changes, fmt.Sprintf("State: %s → %s", old.State, new.State))
+	}
+	if old.GPUAllocated != new.GPUAllocated {
+		changes = append(changes, fmt.Sprintf("GPU allocated: %d → %d", old.GPUAllocated, new.GPUAllocated))
+	}
+	oldSev, newSev := maxSeverityStr(old.WasteSignals), maxSeverityStr(new.WasteSignals)
+	if oldSev != newSev {
+		if oldSev == "" {
+			oldSev = "(none)"
+		}
+		if newSev == "" {
+			newSev = "(none)"
+		}
+		changes = append(changes, fmt.Sprintf("Severity: %s → %s", oldSev, newSev))
+	}
+
+	return changes
+}
+
+// maxSeverityStr returns the highest severity among signals, ranking critical
+// above warning above info, or "" when none of those severities is present.
+func maxSeverityStr(signals []models.WasteSignal) string {
+	highest := models.Severity("")
+	for _, s := range signals {
+		switch s.Severity {
+		case models.SeverityCritical:
+			// Nothing outranks critical, so stop at the first one.
+			return string(models.SeverityCritical)
+		case models.SeverityWarning:
+			highest = models.SeverityWarning
+		case models.SeverityInfo:
+			// Info must never downgrade a warning seen earlier.
+			if highest == "" {
+				highest = models.SeverityInfo
+			}
+		}
+	}
+	return string(highest)
+}
+
+// computeCostDelta copies the totals from each scan's Summary and sums the
+// monthly cost of the instances that diff lists as added and removed.
+func computeCostDelta(old, new *models.ScanResult, diff *DiffResult) CostDelta {
+	cd := CostDelta{
+		OldTotalMonthlyCost: old.Summary.TotalMonthlyCost,
+		NewTotalMonthlyCost: new.Summary.TotalMonthlyCost,
+		CostChange:          new.Summary.TotalMonthlyCost - old.Summary.TotalMonthlyCost,
+		OldTotalWaste:       old.Summary.TotalEstimatedWaste,
+		NewTotalWaste:       new.Summary.TotalEstimatedWaste,
+		WasteChange:         new.Summary.TotalEstimatedWaste - old.Summary.TotalEstimatedWaste,
+	}
+
+	for _, inst := range diff.Added {
+		cd.AddedCost += inst.MonthlyCost
+	}
+	for _, inst := range diff.Removed {
+
cd.RemovedSavings += inst.MonthlyCost + } + + return cd +} +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/diff/ -run TestCompare_AddedInstances -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/diff/diff.go internal/diff/diff_test.go +git commit -m "Add diff package with Compare function and added-instance test" +``` + +--- + +### Task 2: Tests for removed, changed, unchanged, and cost math + +**Files:** +- Modify: `internal/diff/diff_test.go` + +- [ ] **Step 1: Add test for removed instances** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_RemovedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750)) + + result := Compare(old, new) + + if len(result.Removed) != 1 { + t.Fatalf("expected 1 removed, got %d", len(result.Removed)) + } + if result.Removed[0].InstanceID != "i-bbb" { + t.Errorf("expected removed instance i-bbb, got %s", result.Removed[0].InstanceID) + } + if result.CostSummary.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", result.CostSummary.RemovedSavings) + } +} +``` + +- [ ] **Step 2: Add test for changed instances (cost change)** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_CostChange(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 4200)) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + if result.Changed[0].CostDelta != -2550 { + t.Errorf("expected cost delta -2550, got %.0f", result.Changed[0].CostDelta) + } + found := false + for _, c := range result.Changed[0].Changes { + if c == "Cost: $6750 → $4200 (-$2550/mo)" { + found = true + } + } + if !found { + t.Errorf("expected cost change string, got %v", 
result.Changed[0].Changes) + } +} +``` + +- [ ] **Step 3: Add test for changed instances (instance type, pricing model, state, GPU allocated, severity)** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_AllFieldChanges(t *testing.T) { + oldInst := inst("i-aaa", 6750) + oldInst.InstanceType = "g6e.16xlarge" + oldInst.PricingModel = "on-demand" + oldInst.State = "ready" + oldInst.GPUAllocated = 0 + oldInst.WasteSignals = []models.WasteSignal{{Severity: models.SeverityCritical}} + + newInst := inst("i-aaa", 4200) + newInst.InstanceType = "g6e.12xlarge" + newInst.PricingModel = "reserved" + newInst.State = "not-ready" + newInst.GPUAllocated = 2 + newInst.WasteSignals = nil + + old := scanResult(oldInst) + new := scanResult(newInst) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + + changes := result.Changed[0].Changes + expected := []string{ + "Instance type: g6e.16xlarge → g6e.12xlarge", + "Pricing: on-demand → reserved", + "Cost: $6750 → $4200 (-$2550/mo)", + "State: ready → not-ready", + "GPU allocated: 0 → 2", + "Severity: critical → (none)", + } + if len(changes) != len(expected) { + t.Fatalf("expected %d changes, got %d: %v", len(expected), len(changes), changes) + } + for i, exp := range expected { + if changes[i] != exp { + t.Errorf("change[%d]: expected %q, got %q", i, exp, changes[i]) + } + } +} +``` + +- [ ] **Step 4: Add test for unchanged instances** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_UnchangedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 0 { + t.Errorf("expected 0 added, got %d", len(result.Added)) + } + if len(result.Removed) != 0 { + t.Errorf("expected 0 removed, got %d", len(result.Removed)) + } + if len(result.Changed) != 0 { + t.Errorf("expected 0 
changed, got %d", len(result.Changed)) + } + if result.UnchangedCount != 2 { + t.Errorf("expected 2 unchanged, got %d", result.UnchangedCount) + } +} +``` + +- [ ] **Step 5: Add test for cost summary math** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_CostSummary(t *testing.T) { + oldA := inst("i-aaa", 6750) + oldA.EstimatedSavings = 6750 + oldB := inst("i-bbb", 3000) + + newA := inst("i-aaa", 6750) + newA.EstimatedSavings = 6750 + newC := inst("i-ccc", 2000) + + old := scanResult(oldA, oldB) + new := scanResult(newA, newC) + + result := Compare(old, new) + + cs := result.CostSummary + if cs.OldTotalMonthlyCost != 9750 { + t.Errorf("expected old total 9750, got %.0f", cs.OldTotalMonthlyCost) + } + if cs.NewTotalMonthlyCost != 8750 { + t.Errorf("expected new total 8750, got %.0f", cs.NewTotalMonthlyCost) + } + if cs.CostChange != -1000 { + t.Errorf("expected cost change -1000, got %.0f", cs.CostChange) + } + if cs.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", cs.RemovedSavings) + } + if cs.AddedCost != 2000 { + t.Errorf("expected added cost 2000, got %.0f", cs.AddedCost) + } +} +``` + +- [ ] **Step 6: Add test for empty scans** + +Append to `internal/diff/diff_test.go`: + +```go +func TestCompare_EmptyScans(t *testing.T) { + old := scanResult() + new := scanResult() + + result := Compare(old, new) + + if len(result.Added) != 0 || len(result.Removed) != 0 || len(result.Changed) != 0 { + t.Errorf("expected no changes for empty scans") + } + if result.UnchangedCount != 0 { + t.Errorf("expected 0 unchanged, got %d", result.UnchangedCount) + } +} +``` + +- [ ] **Step 7: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/diff/ -v` +Expected: All 7 tests PASS (the Task 1 test plus the six added in this task).
+ +- [ ] **Step 8: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/diff/diff_test.go +git commit -m "Add diff comparison tests: removed, changed, unchanged, cost math, empty" +``` + +--- + +### Task 3: Diff output formatters + +**Files:** +- Create: `internal/output/diff.go` + +- [ ] **Step 1: Create the diff output formatters** + +Create `internal/output/diff.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package output + +import ( + "encoding/json" + "fmt" + "io" + "sort" + "strings" + + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" +) + +// FormatDiffTable writes a human-readable diff report. +func FormatDiffTable(w io.Writer, d *diff.DiffResult) { + fmt.Fprintf(w, "\n gpuaudit diff — %s → %s\n\n", d.OldTimestamp, d.NewTimestamp) + + cs := d.CostSummary + + oldCount := len(d.Removed) + len(d.Changed) + d.UnchangedCount + newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount + + // Cost summary box + fmt.Fprintf(w, " ┌──────────────────────────────────────────────────────────┐\n") + fmt.Fprintf(w, " │ Cost Delta │\n") + fmt.Fprintf(w, " ├──────────────────────────────────────────────────────────┤\n") + fmt.Fprintf(w, " │ Monthly spend: $%s → $%s (%s)%s│\n", + fmtCost(cs.OldTotalMonthlyCost), fmtCost(cs.NewTotalMonthlyCost), + fmtDelta(cs.CostChange), pad(cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, cs.CostChange)) + fmt.Fprintf(w, " │ Estimated waste: $%s → $%s (%s)%s│\n", + fmtCost(cs.OldTotalWaste), fmtCost(cs.NewTotalWaste), + fmtDelta(cs.WasteChange), pad(cs.OldTotalWaste, cs.NewTotalWaste, cs.WasteChange)) + fmt.Fprintf(w, " │ Instances: %d → %d (-%d removed, +%d added)%s│\n", + oldCount, newCount, len(d.Removed), len(d.Added), + padInstances(oldCount, newCount, len(d.Removed), len(d.Added))) + fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n") + + // Removed + if len(d.Removed) 
> 0 { + sortByCost(d.Removed) + fmt.Fprintf(w, "\n REMOVED — %d instance(s), -$%.0f/mo\n\n", len(d.Removed), cs.RemovedSavings) + printDiffInstanceTable(w, d.Removed) + } + + // Added + if len(d.Added) > 0 { + sortByCost(d.Added) + fmt.Fprintf(w, "\n ADDED — %d instance(s), +$%.0f/mo\n\n", len(d.Added), cs.AddedCost) + printDiffInstanceTable(w, d.Added) + } + + // Changed + if len(d.Changed) > 0 { + fmt.Fprintf(w, "\n CHANGED — %d instance(s)\n\n", len(d.Changed)) + fmt.Fprintf(w, " %-36s %s\n", "Instance", "Change") + fmt.Fprintf(w, " %s %s\n", strings.Repeat("─", 36), strings.Repeat("─", 50)) + for _, c := range d.Changed { + name := c.New.Name + if name == "" { + name = c.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + for i, change := range c.Changes { + if i == 0 { + fmt.Fprintf(w, " %-36s %s\n", name, change) + } else { + fmt.Fprintf(w, " %-36s %s\n", "", change) + } + } + } + fmt.Fprintln(w) + } + + // Unchanged + if d.UnchangedCount > 0 { + fmt.Fprintf(w, " UNCHANGED — %d instance(s)\n\n", d.UnchangedCount) + } +} + +func printDiffInstanceTable(w io.Writer, instances []models.GPUInstance) { + fmt.Fprintf(w, " %-36s %-26s %10s\n", "Instance", "Type", "Monthly") + fmt.Fprintf(w, " %s %s %s\n", + strings.Repeat("─", 36), strings.Repeat("─", 26), strings.Repeat("─", 10)) + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." 
+ } + fmt.Fprintf(w, " %-36s %-26s $%9.0f\n", name, typeDesc, inst.MonthlyCost) + } +} + +func sortByCost(instances []models.GPUInstance) { + sort.Slice(instances, func(i, j int) bool { + return instances[i].MonthlyCost > instances[j].MonthlyCost + }) +} + +func fmtCost(v float64) string { + return fmt.Sprintf("%.0f", v) +} + +func fmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +// pad and padInstances return enough spaces to right-fill the summary box line to the closing │. +// These are best-effort; exact alignment depends on number widths. +func pad(old, new, delta float64) string { + content := fmt.Sprintf(" │ Monthly spend: $%s → $%s (%s)", + fmtCost(old), fmtCost(new), fmtDelta(delta)) + if len(content) >= 59 { + return "" + } + return strings.Repeat(" ", 59-len(content)) +} + +func padInstances(oldCount, newCount, removed, added int) string { + content := fmt.Sprintf(" │ Instances: %d → %d (-%d removed, +%d added)", + oldCount, newCount, removed, added) + if len(content) >= 59 { + return "" + } + return strings.Repeat(" ", 59-len(content)) +} + +// FormatDiffJSON writes the diff result as pretty-printed JSON. +func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(d) +} +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: Success (no errors). + +- [ ] **Step 3: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/diff.go +git commit -m "Add diff table and JSON output formatters" +``` + +--- + +### Task 4: Wire up the diff subcommand in main.go + +**Files:** +- Modify: `cmd/gpuaudit/main.go:7-21` (imports) +- Modify: `cmd/gpuaudit/main.go:77-80` (init, register command) + +- [ ] **Step 1: Add the diff subcommand to main.go** + +Add import for the diff package. 
In the imports block at the top, add: + +```go +"github.com/gpuaudit/cli/internal/diff" +``` + +Add these variables after the scan flag variables (after line 53, before `var scanCmd`): + +```go +// --- diff command --- + +var diffFormat string + +var diffCmd = &cobra.Command{ + Use: "diff <old.json> <new.json>", + Short: "Compare two scan results and show what changed", + Args: cobra.ExactArgs(2), + RunE: runDiff, +} +``` + +Register the command in the first `init()` function, alongside the other `rootCmd.AddCommand` calls (line 78): + +```go +rootCmd.AddCommand(diffCmd) +``` + +Add the flag registration in a new `init()` block or the existing one: + +```go +func init() { + diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") +} +``` + +Add the `runDiff` function after `runScan`: + +```go +func runDiff(cmd *cobra.Command, args []string) error { + old, err := loadScanResult(args[0]) + if err != nil { + return fmt.Errorf("loading old scan: %w", err) + } + new, err := loadScanResult(args[1]) + if err != nil { + return fmt.Errorf("loading new scan: %w", err) + } + + result := diff.Compare(old, new) + + switch strings.ToLower(diffFormat) { + case "json": + return output.FormatDiffJSON(os.Stdout, result) + default: + output.FormatDiffTable(os.Stdout, result) + } + + return nil +} + +func loadScanResult(path string) (*models.ScanResult, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var result models.ScanResult + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + return &result, nil +} +``` + +Note: `encoding/json` is already imported in main.go (used by iam-policy command). + +- [ ] **Step 2: Verify it compiles** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: Success. + +- [ ] **Step 3: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...
-v` +Expected: All tests pass (including existing tests and new diff tests). + +- [ ] **Step 4: Run go vet** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go vet ./...` +Expected: Clean. + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add diff subcommand to compare two scan results + +Closes #5" +``` + +--- + +### Task 5: Manual smoke test + +- [ ] **Step 1: Create two test JSON files and run diff** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit + +cat > /tmp/old-scan.json << 'EOF' +{ + "timestamp": "2026-04-08T12:00:00Z", + "account_id": "123456789", + "regions": ["us-east-1"], + "scan_duration": "5s", + "instances": [ + { + "instance_id": "i-aaa", + "source": "k8s-node", + "region": "us-east-1", + "name": "ml-prod/node-1", + "instance_type": "g6e.16xlarge", + "gpu_model": "L40S", + "gpu_count": 1, + "state": "ready", + "launch_time": "2026-03-01T00:00:00Z", + "uptime_hours": 912, + "pricing_model": "on-demand", + "hourly_cost": 9.25, + "monthly_cost": 6750, + "gpu_allocated": 0, + "estimated_savings": 6750, + "waste_signals": [{"type": "idle", "severity": "critical", "confidence": 0.9, "evidence": "No GPU pods"}], + "recommendations": [{"action": "terminate", "description": "Remove idle node", "current_monthly_cost": 6750, "monthly_savings": 6750, "savings_percent": 100, "risk": "low"}] + }, + { + "instance_id": "i-bbb", + "source": "k8s-node", + "region": "us-east-1", + "name": "ml-prod/node-2", + "instance_type": "g6e.16xlarge", + "gpu_model": "L40S", + "gpu_count": 1, + "state": "ready", + "launch_time": "2026-03-01T00:00:00Z", + "uptime_hours": 912, + "pricing_model": "on-demand", + "hourly_cost": 9.25, + "monthly_cost": 6750, + "gpu_allocated": 1, + "estimated_savings": 0 + } + ], + "summary": { + "total_instances": 2, + "total_monthly_cost": 13500, + "total_estimated_waste": 6750, + "waste_percent": 50, + "critical_count": 1, + "warning_count": 0, + "info_count": 0, 
+ "healthy_count": 1 + } +} +EOF + +cat > /tmp/new-scan.json << 'EOF' +{ + "timestamp": "2026-04-14T12:00:00Z", + "account_id": "123456789", + "regions": ["us-east-1"], + "scan_duration": "4s", + "instances": [ + { + "instance_id": "i-bbb", + "source": "k8s-node", + "region": "us-east-1", + "name": "ml-prod/node-2", + "instance_type": "g6e.16xlarge", + "gpu_model": "L40S", + "gpu_count": 1, + "state": "ready", + "launch_time": "2026-03-01T00:00:00Z", + "uptime_hours": 1056, + "pricing_model": "on-demand", + "hourly_cost": 9.25, + "monthly_cost": 6750, + "gpu_allocated": 1, + "estimated_savings": 0 + }, + { + "instance_id": "i-ccc", + "source": "k8s-node", + "region": "us-east-1", + "name": "ml-prod/node-3", + "instance_type": "g6.2xlarge", + "gpu_model": "L4", + "gpu_count": 1, + "state": "ready", + "launch_time": "2026-04-10T00:00:00Z", + "uptime_hours": 96, + "pricing_model": "on-demand", + "hourly_cost": 1.23, + "monthly_cost": 898, + "gpu_allocated": 1, + "estimated_savings": 0 + } + ], + "summary": { + "total_instances": 2, + "total_monthly_cost": 7648, + "total_estimated_waste": 0, + "waste_percent": 0, + "critical_count": 0, + "warning_count": 0, + "info_count": 0, + "healthy_count": 2 + } +} +EOF + +go run ./cmd/gpuaudit diff /tmp/old-scan.json /tmp/new-scan.json +``` + +Expected: Table output showing i-aaa removed (-$6,750/mo), i-ccc added (+$898/mo), i-bbb unchanged. Cost summary showing $13,500 → $7,648 (-$5,852). + +- [ ] **Step 2: Test JSON output** + +```bash +go run ./cmd/gpuaudit diff /tmp/old-scan.json /tmp/new-scan.json --format json +``` + +Expected: JSON output with `added`, `removed`, `changed`, `unchanged_count`, and `cost_summary` fields. 
+ +- [ ] **Step 3: Clean up test files** + +```bash +rm /tmp/old-scan.json /tmp/new-scan.json +``` From f963144d8ce87f2daa3a7cc636f65a76e64f0a1e Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:03:26 +0100 Subject: [PATCH 02/19] Add diff package with Compare function and tests Compares two scan results by instance ID. Detects added, removed, and changed instances across 6 fields (instance type, pricing model, cost, state, GPU allocation, waste severity). Computes cost deltas. --- internal/diff/diff.go | 171 +++++++++++++++++++++++++++++ internal/diff/diff_test.go | 219 +++++++++++++++++++++++++++++++++++++ 2 files changed, 390 insertions(+) create mode 100644 internal/diff/diff.go create mode 100644 internal/diff/diff_test.go diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000..7d74430 --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,171 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package diff compares two scan results and reports what changed. +package diff + +import ( + "fmt" + + "github.com/gpuaudit/cli/internal/models" +) + +// DiffResult holds the comparison between two scan results. +type DiffResult struct { + OldTimestamp string `json:"old_timestamp"` + NewTimestamp string `json:"new_timestamp"` + Added []models.GPUInstance `json:"added,omitempty"` + Removed []models.GPUInstance `json:"removed,omitempty"` + Changed []InstanceDiff `json:"changed,omitempty"` + UnchangedCount int `json:"unchanged_count"` + CostSummary CostDelta `json:"cost_summary"` +} + +// InstanceDiff describes what changed for a single instance between scans. +type InstanceDiff struct { + InstanceID string `json:"instance_id"` + Old models.GPUInstance `json:"old"` + New models.GPUInstance `json:"new"` + CostDelta float64 `json:"cost_delta"` + Changes []string `json:"changes"` +} + +// CostDelta summarizes the financial impact of changes between scans. 
+type CostDelta struct { + OldTotalMonthlyCost float64 `json:"old_total_monthly_cost"` + NewTotalMonthlyCost float64 `json:"new_total_monthly_cost"` + CostChange float64 `json:"cost_change"` + OldTotalWaste float64 `json:"old_total_waste"` + NewTotalWaste float64 `json:"new_total_waste"` + WasteChange float64 `json:"waste_change"` + AddedCost float64 `json:"added_cost"` + RemovedSavings float64 `json:"removed_savings"` +} + +// Compare computes the diff between two scan results, matching instances by ID. +func Compare(old, new *models.ScanResult) *DiffResult { + oldMap := make(map[string]models.GPUInstance, len(old.Instances)) + for _, inst := range old.Instances { + oldMap[inst.InstanceID] = inst + } + + newMap := make(map[string]models.GPUInstance, len(new.Instances)) + for _, inst := range new.Instances { + newMap[inst.InstanceID] = inst + } + + result := &DiffResult{ + OldTimestamp: old.Timestamp.Format("2006-01-02 15:04 UTC"), + NewTimestamp: new.Timestamp.Format("2006-01-02 15:04 UTC"), + } + + // Find removed and changed + for id, oldInst := range oldMap { + newInst, exists := newMap[id] + if !exists { + result.Removed = append(result.Removed, oldInst) + continue + } + changes := diffInstance(oldInst, newInst) + if len(changes) > 0 { + result.Changed = append(result.Changed, InstanceDiff{ + InstanceID: id, + Old: oldInst, + New: newInst, + CostDelta: newInst.MonthlyCost - oldInst.MonthlyCost, + Changes: changes, + }) + } else { + result.UnchangedCount++ + } + } + + // Find added + for id, newInst := range newMap { + if _, exists := oldMap[id]; !exists { + result.Added = append(result.Added, newInst) + } + } + + // Cost summary + result.CostSummary = computeCostDelta(old, new, result) + + return result +} + +func diffInstance(old, new models.GPUInstance) []string { + var changes []string + + if old.InstanceType != new.InstanceType { + changes = append(changes, fmt.Sprintf("Instance type: %s → %s", old.InstanceType, new.InstanceType)) + } + if old.PricingModel 
!= new.PricingModel { + changes = append(changes, fmt.Sprintf("Pricing: %s → %s", old.PricingModel, new.PricingModel)) + } + if old.MonthlyCost != new.MonthlyCost { + delta := new.MonthlyCost - old.MonthlyCost + changes = append(changes, fmt.Sprintf("Cost: $%.0f → $%.0f (%s/mo)", old.MonthlyCost, new.MonthlyCost, fmtDelta(delta))) + } + if old.State != new.State { + changes = append(changes, fmt.Sprintf("State: %s → %s", old.State, new.State)) + } + if old.GPUAllocated != new.GPUAllocated { + changes = append(changes, fmt.Sprintf("GPU allocated: %d → %d", old.GPUAllocated, new.GPUAllocated)) + } + if maxSeverityStr(old.WasteSignals) != maxSeverityStr(new.WasteSignals) { + oldSev := maxSeverityStr(old.WasteSignals) + newSev := maxSeverityStr(new.WasteSignals) + if oldSev == "" { + oldSev = "(none)" + } + if newSev == "" { + newSev = "(none)" + } + changes = append(changes, fmt.Sprintf("Severity: %s → %s", oldSev, newSev)) + } + + return changes +} + +func maxSeverityStr(signals []models.WasteSignal) string { + max := models.Severity("") + for _, s := range signals { + if s.Severity == models.SeverityCritical { + return string(models.SeverityCritical) + } + if s.Severity == models.SeverityWarning { + max = models.SeverityWarning + } + if s.Severity == models.SeverityInfo && max == "" { + max = models.SeverityInfo + } + } + return string(max) +} + +func fmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +func computeCostDelta(old, new *models.ScanResult, diff *DiffResult) CostDelta { + cd := CostDelta{ + OldTotalMonthlyCost: old.Summary.TotalMonthlyCost, + NewTotalMonthlyCost: new.Summary.TotalMonthlyCost, + CostChange: new.Summary.TotalMonthlyCost - old.Summary.TotalMonthlyCost, + OldTotalWaste: old.Summary.TotalEstimatedWaste, + NewTotalWaste: new.Summary.TotalEstimatedWaste, + WasteChange: new.Summary.TotalEstimatedWaste - old.Summary.TotalEstimatedWaste, + } + + for _, inst := range diff.Added { 
+ cd.AddedCost += inst.MonthlyCost + } + for _, inst := range diff.Removed { + cd.RemovedSavings += inst.MonthlyCost + } + + return cd +} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go new file mode 100644 index 0000000..35d4f1f --- /dev/null +++ b/internal/diff/diff_test.go @@ -0,0 +1,219 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package diff + +import ( + "testing" + "time" + + "github.com/gpuaudit/cli/internal/models" +) + +func scanResult(instances ...models.GPUInstance) *models.ScanResult { + return &models.ScanResult{ + Timestamp: time.Date(2026, 4, 8, 12, 0, 0, 0, time.UTC), + Instances: instances, + Summary: models.ScanSummary{ + TotalInstances: len(instances), + TotalMonthlyCost: sumMonthlyCost(instances), + TotalEstimatedWaste: sumWaste(instances), + }, + } +} + +func sumMonthlyCost(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.MonthlyCost + } + return total +} + +func sumWaste(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.EstimatedSavings + } + return total +} + +func inst(id string, monthlyCost float64) models.GPUInstance { + return models.GPUInstance{ + InstanceID: id, + InstanceType: "g6e.16xlarge", + GPUModel: "L40S", + GPUCount: 1, + MonthlyCost: monthlyCost, + HourlyCost: monthlyCost / 730, + State: "ready", + Source: models.SourceK8sNode, + PricingModel: "on-demand", + } +} + +func TestCompare_AddedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 1 { + t.Fatalf("expected 1 added, got %d", len(result.Added)) + } + if result.Added[0].InstanceID != "i-bbb" { + t.Errorf("expected added instance i-bbb, got %s", result.Added[0].InstanceID) + } + if result.CostSummary.AddedCost != 3000 { + 
t.Errorf("expected added cost 3000, got %.0f", result.CostSummary.AddedCost) + } +} + +func TestCompare_RemovedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750)) + + result := Compare(old, new) + + if len(result.Removed) != 1 { + t.Fatalf("expected 1 removed, got %d", len(result.Removed)) + } + if result.Removed[0].InstanceID != "i-bbb" { + t.Errorf("expected removed instance i-bbb, got %s", result.Removed[0].InstanceID) + } + if result.CostSummary.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", result.CostSummary.RemovedSavings) + } +} + +func TestCompare_CostChange(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 4200)) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + if result.Changed[0].CostDelta != -2550 { + t.Errorf("expected cost delta -2550, got %.0f", result.Changed[0].CostDelta) + } + found := false + for _, c := range result.Changed[0].Changes { + if c == "Cost: $6750 → $4200 (-$2550/mo)" { + found = true + } + } + if !found { + t.Errorf("expected cost change string, got %v", result.Changed[0].Changes) + } +} + +func TestCompare_AllFieldChanges(t *testing.T) { + oldInst := inst("i-aaa", 6750) + oldInst.InstanceType = "g6e.16xlarge" + oldInst.PricingModel = "on-demand" + oldInst.State = "ready" + oldInst.GPUAllocated = 0 + oldInst.WasteSignals = []models.WasteSignal{{Severity: models.SeverityCritical}} + + newInst := inst("i-aaa", 4200) + newInst.InstanceType = "g6e.12xlarge" + newInst.PricingModel = "reserved" + newInst.State = "not-ready" + newInst.GPUAllocated = 2 + newInst.WasteSignals = nil + + old := scanResult(oldInst) + new := scanResult(newInst) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + + changes := result.Changed[0].Changes 
+ expected := []string{ + "Instance type: g6e.16xlarge → g6e.12xlarge", + "Pricing: on-demand → reserved", + "Cost: $6750 → $4200 (-$2550/mo)", + "State: ready → not-ready", + "GPU allocated: 0 → 2", + "Severity: critical → (none)", + } + if len(changes) != len(expected) { + t.Fatalf("expected %d changes, got %d: %v", len(expected), len(changes), changes) + } + for i, exp := range expected { + if changes[i] != exp { + t.Errorf("change[%d]: expected %q, got %q", i, exp, changes[i]) + } + } +} + +func TestCompare_UnchangedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 0 { + t.Errorf("expected 0 added, got %d", len(result.Added)) + } + if len(result.Removed) != 0 { + t.Errorf("expected 0 removed, got %d", len(result.Removed)) + } + if len(result.Changed) != 0 { + t.Errorf("expected 0 changed, got %d", len(result.Changed)) + } + if result.UnchangedCount != 2 { + t.Errorf("expected 2 unchanged, got %d", result.UnchangedCount) + } +} + +func TestCompare_CostSummary(t *testing.T) { + oldA := inst("i-aaa", 6750) + oldA.EstimatedSavings = 6750 + oldB := inst("i-bbb", 3000) + + newA := inst("i-aaa", 6750) + newA.EstimatedSavings = 6750 + newC := inst("i-ccc", 2000) + + old := scanResult(oldA, oldB) + new := scanResult(newA, newC) + + result := Compare(old, new) + + cs := result.CostSummary + if cs.OldTotalMonthlyCost != 9750 { + t.Errorf("expected old total 9750, got %.0f", cs.OldTotalMonthlyCost) + } + if cs.NewTotalMonthlyCost != 8750 { + t.Errorf("expected new total 8750, got %.0f", cs.NewTotalMonthlyCost) + } + if cs.CostChange != -1000 { + t.Errorf("expected cost change -1000, got %.0f", cs.CostChange) + } + if cs.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", cs.RemovedSavings) + } + if cs.AddedCost != 2000 { + t.Errorf("expected added cost 2000, got %.0f", cs.AddedCost) + } +} 
+ +func TestCompare_EmptyScans(t *testing.T) { + old := scanResult() + new := scanResult() + + result := Compare(old, new) + + if len(result.Added) != 0 || len(result.Removed) != 0 || len(result.Changed) != 0 { + t.Errorf("expected no changes for empty scans") + } + if result.UnchangedCount != 0 { + t.Errorf("expected 0 unchanged, got %d", result.UnchangedCount) + } +} From cc633186d3953bf2ab6935b67c2295d20ea8fb1e Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:04:33 +0100 Subject: [PATCH 03/19] Add diff table and JSON output formatters --- internal/output/diff.go | 145 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 internal/output/diff.go diff --git a/internal/output/diff.go b/internal/output/diff.go new file mode 100644 index 0000000..2bd5753 --- /dev/null +++ b/internal/output/diff.go @@ -0,0 +1,145 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package output + +import ( + "encoding/json" + "fmt" + "io" + "sort" + "strings" + + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" +) + +// FormatDiffTable writes a human-readable diff report. 
+func FormatDiffTable(w io.Writer, d *diff.DiffResult) { + fmt.Fprintf(w, "\n gpuaudit diff — %s → %s\n\n", d.OldTimestamp, d.NewTimestamp) + + cs := d.CostSummary + + oldCount := len(d.Removed) + len(d.Changed) + d.UnchangedCount + newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount + + // Cost summary box + fmt.Fprintf(w, " ┌──────────────────────────────────────────────────────────┐\n") + fmt.Fprintf(w, " │ Cost Delta │\n") + fmt.Fprintf(w, " ├──────────────────────────────────────────────────────────┤\n") + fmt.Fprintf(w, " │ Monthly spend: $%-9.0f → $%-9.0f (%s)%s│\n", + cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, + diffFmtDelta(cs.CostChange), diffPad(cs.CostChange)) + fmt.Fprintf(w, " │ Estimated waste: $%-9.0f → $%-9.0f (%s)%s│\n", + cs.OldTotalWaste, cs.NewTotalWaste, + diffFmtDelta(cs.WasteChange), diffPad(cs.WasteChange)) + fmt.Fprintf(w, " │ Instances: %-3d → %-3d (-%d removed, +%d added)%s│\n", + oldCount, newCount, len(d.Removed), len(d.Added), + diffPadInstances(oldCount, newCount, len(d.Removed), len(d.Added))) + fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n") + + // Removed + if len(d.Removed) > 0 { + sortInstancesByCost(d.Removed) + fmt.Fprintf(w, "\n REMOVED — %d instance(s), -$%.0f/mo\n\n", len(d.Removed), cs.RemovedSavings) + printDiffInstanceTable(w, d.Removed) + } + + // Added + if len(d.Added) > 0 { + sortInstancesByCost(d.Added) + fmt.Fprintf(w, "\n ADDED — %d instance(s), +$%.0f/mo\n\n", len(d.Added), cs.AddedCost) + printDiffInstanceTable(w, d.Added) + } + + // Changed + if len(d.Changed) > 0 { + fmt.Fprintf(w, "\n CHANGED — %d instance(s)\n\n", len(d.Changed)) + fmt.Fprintf(w, " %-36s %s\n", "Instance", "Change") + fmt.Fprintf(w, " %s %s\n", strings.Repeat("─", 36), strings.Repeat("─", 50)) + for _, c := range d.Changed { + name := c.New.Name + if name == "" { + name = c.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." 
+ } + for i, change := range c.Changes { + if i == 0 { + fmt.Fprintf(w, " %-36s %s\n", name, change) + } else { + fmt.Fprintf(w, " %-36s %s\n", "", change) + } + } + } + fmt.Fprintln(w) + } + + // Unchanged + if d.UnchangedCount > 0 { + fmt.Fprintf(w, " UNCHANGED — %d instance(s)\n\n", d.UnchangedCount) + } +} + +func printDiffInstanceTable(w io.Writer, instances []models.GPUInstance) { + fmt.Fprintf(w, " %-36s %-26s %10s\n", "Instance", "Type", "Monthly") + fmt.Fprintf(w, " %s %s %s\n", + strings.Repeat("─", 36), strings.Repeat("─", 26), strings.Repeat("─", 10)) + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." + } + fmt.Fprintf(w, " %-36s %-26s $%9.0f\n", name, typeDesc, inst.MonthlyCost) + } +} + +func sortInstancesByCost(instances []models.GPUInstance) { + sort.Slice(instances, func(i, j int) bool { + return instances[i].MonthlyCost > instances[j].MonthlyCost + }) +} + +func diffFmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +// diffPad returns spaces to align the summary box closing border. 
+func diffPad(delta float64) string { + s := diffFmtDelta(delta) + // The content before the delta is ~44 chars, delta is variable, need to fill to col 59 + used := 44 + len(s) + 2 // +2 for parens + target := 59 + if used >= target { + return "" + } + return strings.Repeat(" ", target-used) +} + +func diffPadInstances(oldCount, newCount, removed, added int) string { + content := fmt.Sprintf(" │ Instances: %-3d → %-3d (-%d removed, +%d added)", + oldCount, newCount, removed, added) + if len(content) >= 59 { + return "" + } + return strings.Repeat(" ", 59-len(content)) +} + +// FormatDiffJSON writes the diff result as pretty-printed JSON. +func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(d) +} From 68abdfafa9c468cc9e3ca51d79393494808fd49c Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:05:37 +0100 Subject: [PATCH 04/19] Add diff subcommand to compare two scan results gpuaudit diff old.json new.json [--format table|json] Closes #5 --- cmd/gpuaudit/main.go | 55 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index ce8d61e..217a100 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -13,12 +13,13 @@ import ( "github.com/spf13/cobra" - "github.com/gpuaudit/cli/internal/models" "github.com/gpuaudit/cli/internal/analysis" - awsprovider "github.com/gpuaudit/cli/internal/providers/aws" - k8sprovider "github.com/gpuaudit/cli/internal/providers/k8s" + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" "github.com/gpuaudit/cli/internal/output" "github.com/gpuaudit/cli/internal/pricing" + awsprovider "github.com/gpuaudit/cli/internal/providers/aws" + k8sprovider "github.com/gpuaudit/cli/internal/providers/k8s" ) var version = "dev" @@ -53,6 +54,17 @@ var ( scanMinUptimeDays int ) +// --- diff command --- + +var diffFormat 
string + +var diffCmd = &cobra.Command{ + Use: "diff ", + Short: "Compare two scan results and show what changed", + Args: cobra.ExactArgs(2), + RunE: runDiff, +} + var scanCmd = &cobra.Command{ Use: "scan", Short: "Scan AWS account for GPU waste", @@ -74,7 +86,10 @@ func init() { scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") + diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") + rootCmd.AddCommand(scanCmd) + rootCmd.AddCommand(diffCmd) rootCmd.AddCommand(pricingCmd) rootCmd.AddCommand(iamPolicyCmd) rootCmd.AddCommand(versionCmd) @@ -144,6 +159,40 @@ func runScan(cmd *cobra.Command, args []string) error { return nil } +func runDiff(cmd *cobra.Command, args []string) error { + old, err := loadScanResult(args[0]) + if err != nil { + return fmt.Errorf("loading old scan: %w", err) + } + new, err := loadScanResult(args[1]) + if err != nil { + return fmt.Errorf("loading new scan: %w", err) + } + + result := diff.Compare(old, new) + + switch strings.ToLower(diffFormat) { + case "json": + return output.FormatDiffJSON(os.Stdout, result) + default: + output.FormatDiffTable(os.Stdout, result) + } + + return nil +} + +func loadScanResult(path string) (*models.ScanResult, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var result models.ScanResult + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + return &result, nil +} + // --- pricing command --- var pricingGPU string From de3487f6dec5786b32f7610dad54f5779972df15 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 09:15:45 +0100 Subject: [PATCH 05/19] Fix box alignment in diff table output --- internal/output/diff.go | 55 ++++++++++++++++------------------------- 1 file 
changed, 21 insertions(+), 34 deletions(-) diff --git a/internal/output/diff.go b/internal/output/diff.go index 2bd5753..db2f7c9 100644 --- a/internal/output/diff.go +++ b/internal/output/diff.go @@ -24,19 +24,18 @@ func FormatDiffTable(w io.Writer, d *diff.DiffResult) { newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount // Cost summary box - fmt.Fprintf(w, " ┌──────────────────────────────────────────────────────────┐\n") - fmt.Fprintf(w, " │ Cost Delta │\n") - fmt.Fprintf(w, " ├──────────────────────────────────────────────────────────┤\n") - fmt.Fprintf(w, " │ Monthly spend: $%-9.0f → $%-9.0f (%s)%s│\n", - cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, - diffFmtDelta(cs.CostChange), diffPad(cs.CostChange)) - fmt.Fprintf(w, " │ Estimated waste: $%-9.0f → $%-9.0f (%s)%s│\n", - cs.OldTotalWaste, cs.NewTotalWaste, - diffFmtDelta(cs.WasteChange), diffPad(cs.WasteChange)) - fmt.Fprintf(w, " │ Instances: %-3d → %-3d (-%d removed, +%d added)%s│\n", - oldCount, newCount, len(d.Removed), len(d.Added), - diffPadInstances(oldCount, newCount, len(d.Removed), len(d.Added))) - fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n") + boxWidth := 58 // inner width between │ markers + boxLine := strings.Repeat("─", boxWidth) + fmt.Fprintf(w, " ┌%s┐\n", boxLine) + writeBoxLine(w, "Cost Delta", boxWidth) + fmt.Fprintf(w, " ├%s┤\n", boxLine) + writeBoxLine(w, fmt.Sprintf("Monthly spend: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, diffFmtDelta(cs.CostChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Estimated waste: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalWaste, cs.NewTotalWaste, diffFmtDelta(cs.WasteChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Instances: %d → %d (-%d removed, +%d added)", + oldCount, newCount, len(d.Removed), len(d.Added)), boxWidth) + fmt.Fprintf(w, " └%s┘\n", boxLine) // Removed if len(d.Removed) > 0 { @@ -109,6 +108,15 @@ func sortInstancesByCost(instances []models.GPUInstance) { }) } +func 
writeBoxLine(w io.Writer, content string, width int) { + // Pad content to fill the box width (with 2-char margin on each side) + inner := width - 4 // 2 spaces on each side + if len(content) > inner { + content = content[:inner] + } + fmt.Fprintf(w, " │ %-*s │\n", inner, content) +} + func diffFmtDelta(v float64) string { if v >= 0 { return fmt.Sprintf("+$%.0f", v) @@ -116,27 +124,6 @@ func diffFmtDelta(v float64) string { return fmt.Sprintf("-$%.0f", -v) } -// diffPad returns spaces to align the summary box closing border. -func diffPad(delta float64) string { - s := diffFmtDelta(delta) - // The content before the delta is ~44 chars, delta is variable, need to fill to col 59 - used := 44 + len(s) + 2 // +2 for parens - target := 59 - if used >= target { - return "" - } - return strings.Repeat(" ", target-used) -} - -func diffPadInstances(oldCount, newCount, removed, added int) string { - content := fmt.Sprintf(" │ Instances: %-3d → %-3d (-%d removed, +%d added)", - oldCount, newCount, removed, added) - if len(content) >= 59 { - return "" - } - return strings.Repeat(" ", 59-len(content)) -} - // FormatDiffJSON writes the diff result as pretty-printed JSON. func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { enc := json.NewEncoder(w) From 7f5cfb3c9effc8f91789608188e91191d5b9b8d6 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 12:33:54 +0100 Subject: [PATCH 06/19] Fix misleading idle duration in K8s GPU node recommendations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recommendation said "No GPU pods scheduled for X days" but X was the node's total uptime, not the idle duration. We don't know when the node became idle — only that it currently has zero GPU pods. Changed wording to "Node up X days with 0 GPU pods scheduled." 
--- internal/analysis/rules.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index f91bcbe..a782583 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -317,11 +317,11 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { Type: "idle", Severity: models.SeverityCritical, Confidence: 0.9, - Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs for %.0f+ hours.", inst.GPUCount, inst.UptimeHours), + Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs. Node up for %d days.", inst.GPUCount, int(inst.UptimeHours/24)), }) inst.Recommendations = append(inst.Recommendations, models.Recommendation{ Action: models.ActionTerminate, - Description: fmt.Sprintf("No GPU pods scheduled on this node for %d days. Remove from node pool or scale down.", int(inst.UptimeHours/24)), + Description: fmt.Sprintf("Node up %d days with 0 GPU pods scheduled. Remove from node pool or scale down.", int(inst.UptimeHours/24)), CurrentMonthlyCost: inst.MonthlyCost, MonthlySavings: inst.MonthlyCost, SavingsPercent: 100, From 39a49262dfb3b877ec57ff3f07826ba18d7ecb59 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 14:36:26 +0100 Subject: [PATCH 07/19] Update README with K8s scanning, diff command, and current output format --- README.md | 120 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 86 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index f3c05dc..8738521 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,39 @@ # gpuaudit -Scan your AWS account for GPU waste and get actionable recommendations to cut your cloud spend. +Scan your cloud for GPU waste and get actionable recommendations to cut your spend. 
``` -$ gpuaudit scan --profile ml-prod +$ gpuaudit scan --skip-eks - GPU Fleet Summary - Total GPU instances: 14 - Total monthly GPU spend: $47,832 - Estimated monthly waste: $18,240 (38%) + Found 103 GPU nodes across 111 nodes in ml-prod-iad - CRITICAL (3 instances, $8,940/mo potential savings) + gpuaudit — GPU Cost Audit for AWS + Account: 123456789012 | Regions: us-east-1 | Duration: 4.2s - i-0a1b2c3d4e g5.12xlarge (4x A10G) $4,380/mo Idle — no activity for 18 days → terminate - i-9f8e7d6c5b p4d.24xlarge (8x A100) $23,652/mo Idle — <1% CPU for 6 days → terminate - sagemaker:asr ml.g6.48xlarge (8x L40S) $9,490/mo GPU util avg 8% → downsize to ml.g5.xlarge + ┌──────────────────────────────────────────────────────────┐ + │ GPU Fleet Summary │ + ├──────────────────────────────────────────────────────────┤ + │ Total GPU instances: 103 │ + │ Total monthly GPU spend: $365155 │ + │ Estimated monthly waste: $23408 ( 6%) │ + └──────────────────────────────────────────────────────────┘ + + CRITICAL — 4 instance(s), $21728/mo potential savings + + Instance Type Monthly Signal Recommendation + ──────────────────────────────────── ────────────────────────── ──────── ──────────────── ────────────────────────────────────────────── + ml-prod-iad/ip-10-15-255-248 g6e.16xlarge (1× L40S) $ 6752 idle Node up 13 days with 0 GPU pods scheduled. + ml-prod-iad/ip-10-22-250-15 g6e.16xlarge (1× L40S) $ 6752 idle Node up 1 days with 0 GPU pods scheduled. + ... 
``` +## What it scans + +- **EC2** — GPU instances (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) with CloudWatch metrics +- **SageMaker** — Endpoints with GPU utilization and invocation metrics +- **EKS** — Managed GPU node groups via the AWS EKS API +- **Kubernetes** — GPU nodes and pod allocation via the Kubernetes API (Karpenter, self-managed, any CNI) + ## What it detects - **Idle GPU instances** — running but doing nothing (low CPU + near-zero network for 24+ hours) @@ -25,6 +42,7 @@ $ gpuaudit scan --profile ml-prod - **Stale instances** — non-production instances running 90+ days - **SageMaker low utilization** — endpoints with <10% GPU utilization - **SageMaker oversized** — endpoints using <30% GPU memory on multi-GPU instances +- **K8s unallocated GPUs** — nodes with GPU capacity but no pods requesting GPUs ## Install @@ -36,7 +54,7 @@ Or build from source: ```bash git clone https://github.com/gpuaudit/cli.git -cd gpuaudit +cd cli go build -o gpuaudit ./cmd/gpuaudit ``` @@ -49,22 +67,57 @@ gpuaudit scan # Specific profile and region gpuaudit scan --profile production --region us-east-1 +# Kubernetes cluster scan (uses KUBECONFIG or ~/.kube/config) +gpuaudit scan --skip-eks + +# Specific kubeconfig and context +gpuaudit scan --kubeconfig ~/.kube/config --kube-context ml-prod-iad + # JSON output for automation -gpuaudit scan --format json --output report.json +gpuaudit scan --format json -o report.json -# Markdown for docs/PRs -gpuaudit scan --format markdown +# Compare two scans to see what changed +gpuaudit diff old-report.json new-report.json # Slack Block Kit payload (pipe to webhook) -gpuaudit scan --format slack --output - | curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK - -# Skip CloudWatch metrics (faster, less accurate) -gpuaudit scan --skip-metrics +gpuaudit scan --format slack -o - | \ + curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK -# Skip SageMaker scanning +# Skip specific scanners +gpuaudit scan 
--skip-metrics # faster, less accurate gpuaudit scan --skip-sagemaker +gpuaudit scan --skip-eks # skip AWS EKS API (use --skip-k8s for Kubernetes API) +gpuaudit scan --skip-k8s ``` +## Comparing scans + +Save scan results as JSON, then diff them later: + +```bash +gpuaudit scan --format json -o scan-apr-08.json +# ... time passes, changes happen ... +gpuaudit scan --format json -o scan-apr-15.json +gpuaudit diff scan-apr-08.json scan-apr-15.json +``` + +``` + gpuaudit diff — 2026-04-08 12:00 UTC → 2026-04-15 12:00 UTC + + ┌──────────────────────────────────────────────────────────┐ + │ Cost Delta │ + ├──────────────────────────────────────────────────────────┤ + │ Monthly spend: $372000 → $365155 (-$6845) │ + │ Estimated waste: $189000 → $23408 (-$165592) │ + │ Instances: 116 → 103 (-13 removed, +0 added) │ + └──────────────────────────────────────────────────────────┘ + + REMOVED — 13 instance(s), -$6845/mo + ... +``` + +Matches instances by ID. Reports added, removed, and changed instances with per-field diffs (instance type, pricing model, cost, state, GPU allocation, waste severity). + ## IAM permissions gpuaudit is read-only. It never modifies your infrastructure. Generate the minimal IAM policy: @@ -73,7 +126,7 @@ gpuaudit is read-only. It never modifies your infrastructure. Generate the minim gpuaudit iam-policy ``` -This outputs a JSON policy requiring only `Describe*`, `List*`, `Get*` permissions for EC2, SageMaker, CloudWatch, Cost Explorer, and Pricing APIs. +For Kubernetes scanning, gpuaudit needs `get`/`list` on `nodes` and `pods` cluster-wide. 
## GPU pricing reference @@ -83,8 +136,7 @@ gpuaudit pricing # Filter by GPU model gpuaudit pricing --gpu H100 -gpuaudit pricing --gpu A10G -gpuaudit pricing --gpu T4 +gpuaudit pricing --gpu L4 ``` ## Output formats @@ -92,18 +144,17 @@ gpuaudit pricing --gpu T4 | Format | Flag | Use case | |---|---|---| | Table | `--format table` (default) | Terminal viewing | -| JSON | `--format json` | Automation, CI/CD pipelines | +| JSON | `--format json` | Automation, CI/CD, `gpuaudit diff` | | Markdown | `--format markdown` | PRs, wikis, docs | | Slack | `--format slack` | Slack webhook integration | ## How it works -1. **Discovery** — Scans EC2 and SageMaker across multiple regions for GPU instance families (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) +1. **Discovery** — Scans EC2, SageMaker, EKS node groups, and Kubernetes API across multiple regions for GPU resources 2. **Metrics** — Collects 7-day CloudWatch metrics: CPU, network I/O for EC2; GPU utilization, GPU memory, invocations for SageMaker -3. **Analysis** — Applies 6 waste detection rules with severity levels (critical/warning) -4. **Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings - -Regions scanned by default: us-east-1, us-east-2, us-west-2, eu-west-1, eu-west-2, eu-central-1, ap-southeast-1, ap-northeast-1, ap-south-1. +3. **K8s allocation** — Lists pods requesting `nvidia.com/gpu` resources and maps them to nodes +4. **Analysis** — Applies 7 waste detection rules with severity levels (critical/warning/info) +5. 
**Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings ## Project structure @@ -113,21 +164,22 @@ gpuaudit/ ├── internal/ │ ├── models/ Core data types (GPUInstance, WasteSignal, Recommendation) │ ├── pricing/ Bundled GPU pricing database (40+ instance types) -│ ├── analysis/ Waste detection rules engine -│ ├── output/ Formatters (table, JSON, markdown, Slack) -│ └── providers/aws/ EC2, SageMaker, CloudWatch, scanner orchestrator +│ ├── analysis/ Waste detection rules engine (7 rules) +│ ├── diff/ Scan comparison logic +│ ├── output/ Formatters (table, JSON, markdown, Slack, diff) +│ └── providers/ +│ ├── aws/ EC2, SageMaker, EKS, CloudWatch, Cost Explorer +│ └── k8s/ Kubernetes API GPU node/pod discovery └── LICENSE Apache 2.0 ``` ## Roadmap -- [ ] AWS Cost Explorer integration (actual vs projected spend) -- [ ] EKS GPU pod discovery +- [ ] DCGM GPU metrics via Kubernetes (actual GPU utilization, not just allocation) - [ ] SageMaker training job analysis - [ ] Multi-account (AWS Organizations) scanning - [ ] GCP + Azure support - [ ] GitHub Action for scheduled scans -- [ ] Historical scan comparison (`gpuaudit diff`) ## License From 60cf64467d31f7261df166af1810a291413deb4a Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:03:27 +0100 Subject: [PATCH 08/19] Add multi-target scanning design spec Covers CLI flags (--targets, --role, --org), architecture for parallel cross-account scanning via STS AssumeRole, output changes with per-target sub-summaries, and IAM role setup docs (Terraform + CloudFormation). 
--- ...2026-04-18-multi-target-scanning-design.md | 374 ++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 docs/specs/2026-04-18-multi-target-scanning-design.md diff --git a/docs/specs/2026-04-18-multi-target-scanning-design.md b/docs/specs/2026-04-18-multi-target-scanning-design.md new file mode 100644 index 0000000..9c2fd34 --- /dev/null +++ b/docs/specs/2026-04-18-multi-target-scanning-design.md @@ -0,0 +1,374 @@ +# Multi-Target Scanning + +**Date:** April 18, 2026 +**Status:** Draft + +--- + +## Summary + +Add the ability to scan multiple AWS accounts (and eventually GCP projects / Azure subscriptions) in a single `gpuaudit scan` invocation. Uses STS AssumeRole to obtain credentials for each target, scans them all in parallel, and merges results into a single flat output with per-target sub-summaries. + +Zero breaking changes — existing single-account behavior is the default. + +--- + +## CLI Interface + +### New flags on `gpuaudit scan` + +| Flag | Type | Description | +|------|------|-------------| +| `--targets` | `[]string` | Comma-separated list of account IDs to scan | +| `--role` | `string` | IAM role name to assume in each target (required with `--targets` or `--org`) | +| `--org` | `bool` | Auto-discover all accounts from AWS Organizations | +| `--external-id` | `string` | STS external ID for cross-account role assumption (optional) | +| `--skip-self` | `bool` | Exclude the caller's own account from the scan | + +### Constraints + +- `--targets` and `--org` are mutually exclusive. +- `--role` is required when `--targets` or `--org` is set. +- No `--targets` or `--org` means scan the caller's account only (current behavior, no changes). +- The caller's own account is included by default unless `--skip-self` is set. 
+ +### Examples + +```bash +# Current behavior (unchanged) +gpuaudit scan + +# Scan 3 specific accounts +gpuaudit scan --targets 111111111111,222222222222,333333333333 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Org scan, exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID for extra security +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Flag naming rationale + +Flags use provider-neutral names (`--targets` not `--accounts`, `--role` not `--assume-role`) so that when GCP and Azure support lands, the same flags work: targets are project IDs or subscription IDs, role is a service account or principal name. No renaming, no backward-compatibility concerns. + +--- + +## Architecture + +### New file: `internal/providers/aws/multiaccount.go` + +Contains: + +- `Target` struct: `{AccountID string, Config aws.Config}` +- `ResolveTargets(ctx, cfg, opts) ([]Target, []TargetError)`: + - No `--targets`/`--org`: returns caller's account with existing config. + - `--targets`: calls `sts:AssumeRole` for each account ID, returns credentials. Failed assumptions are collected as `TargetError`, not fatal. + - `--org`: calls `organizations:ListAccounts`, filters to active accounts, then assumes role in each. + - Caller's own account is included (with original config, no AssumeRole needed) unless `--skip-self`. +- `TargetError` struct: `{AccountID string, Err error}` + +### Changes to `ScanOptions` + +```go +type ScanOptions struct { + // ... existing fields ... 
+ Targets []string // account IDs to scan + Role string // role name to assume + ExternalID string // STS external ID + OrgScan bool // auto-discover from Organizations + SkipSelf bool // exclude caller's account +} +``` + +### Changes to `Scan()` + +Current flow: +``` +load config → get account ID → scan regions in parallel → merge → analyze → output +``` + +New flow: +``` +load config → ResolveTargets() → for each target (parallel): + for each region (parallel): + scanRegion(ctx, target.Config, target.AccountID, region, opts) +→ merge all instances into flat list +→ filter, analyze, enrich (unchanged) +→ BuildSummary (global + per-target sub-summaries) +→ output +``` + +All targets are scanned in parallel. Within each target, all regions are scanned in parallel (same as today). + +### Error handling: best-effort + +- `ResolveTargets` returns both successful targets and a list of `TargetError`s. +- Scan continues for all resolvable targets. +- Per-region errors within a target are handled as today (warn and continue). +- Target-level errors are surfaced in the output (see Output section). +- Exit code: 0 = success, non-zero if all targets failed. + +### Unchanged components + +- Analysis rules — operate per-instance, already provider-agnostic. +- Diff command — matches by `instance_id`, globally unique across accounts. +- `GPUInstance` model — already has `AccountID` field. +- Pricing database — account-independent. 
+ +--- + +## Model Changes + +### `ScanResult` + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` // caller's account (kept for backward compat) + Targets []string `json:"targets,omitempty"` // NEW: all scanned target IDs + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` // NEW: per-target breakdown + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` // NEW: failed targets +} + +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +New fields use `omitempty` — single-account scans produce identical JSON to today. + +--- + +## Output Changes + +### Table + +When multiple targets are present, two additions: + +1. **"By Target" summary table** after the global summary: + +``` + By Target + ┌──────────────┬───────────┬───────────┬───────────┬───────┐ + │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │ + ├──────────────┼───────────┼───────────┼───────────┼───────┤ + │ 111111111111 │ 31 │ $142,000 │ $38,000 │ 27% │ + │ 222222222222 │ 12 │ $35,400 │ $4,200 │ 12% │ + └──────────────┴───────────┴───────────┴───────────┴───────┘ +``` + +2. **"Target" column** in instance detail tables. + +Single-target scans look identical to today. + +### JSON + +New `targets`, `target_summaries`, and `target_errors` fields as shown in the model above. Omitted when empty. 
+ +### Markdown + +Per-target summary section added when multiple targets present. + +### Slack + +Per-target summary block added when multiple targets present. + +### Errors + +When targets fail, a warnings section appears in all formats: + +``` + Warnings + ✗ 444444444444 — AssumeRole failed: AccessDenied + ✗ 555555555555 — role "gpuaudit-reader" not found in account +``` + +--- + +## IAM Policy Updates + +### `gpuaudit iam-policy` additions + +Add two new statements to the generated policy: + +```json +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader" +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*" +} +``` + +These are printed as a separate "Multi-Account Permissions" section in the `iam-policy` output, with a comment explaining they're only needed for `--targets` or `--org` scanning. Always included in the output — users can ignore them if they only scan a single account. + +--- + +## Cross-Account Role Setup + +### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +variable "external_id" { + description = "External ID for AssumeRole (optional but recommended)" + type = string + default = "" +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + Condition = var.external_id != "" ? 
{ + StringEquals = { "sts:ExternalId" = var.external_id } + } : {} + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Sid = "EC2ReadOnly" + Effect = "Allow" + Action = ["ec2:DescribeInstances", "ec2:DescribeInstanceTypes", "ec2:DescribeRegions"] + Resource = "*" + }, + { + Sid = "SageMakerReadOnly" + Effect = "Allow" + Action = ["sagemaker:ListEndpoints", "sagemaker:DescribeEndpoint", "sagemaker:DescribeEndpointConfig"] + Resource = "*" + }, + { + Sid = "EKSReadOnly" + Effect = "Allow" + Action = ["eks:ListClusters", "eks:ListNodegroups", "eks:DescribeNodegroup"] + Resource = "*" + }, + { + Sid = "CloudWatchReadOnly" + Effect = "Allow" + Action = ["cloudwatch:GetMetricData", "cloudwatch:GetMetricStatistics", "cloudwatch:ListMetrics"] + Resource = "*" + }, + { + Sid = "CostExplorerReadOnly" + Effect = "Allow" + Action = ["ce:GetCostAndUsage", "ce:GetReservationUtilization", "ce:GetSavingsPlansUtilization"] + Resource = "*" + }, + { + Sid = "PricingReadOnly" + Effect = "Allow" + Action = ["pricing:GetProducts"] + Resource = "*" + } + ] + }) +} +``` + +### CloudFormation (for StackSet deployment across all accounts) + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Description: gpuaudit cross-account reader role + +Parameters: + ManagementAccountId: + Type: String + Description: Account ID where gpuaudit runs + ExternalId: + Type: String + Description: External ID for AssumeRole + Default: "" + +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - 
ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` + +Recommended deployment: use CloudFormation StackSets to deploy the role to all member accounts from the management account. + +--- + +## Testing + +- **Unit tests for `ResolveTargets`**: mock STS and Organizations clients, verify correct target list for each mode (explicit, org, skip-self, mixed failures). +- **Unit tests for `BuildSummary`**: verify per-target summaries compute correctly with instances from multiple accounts. +- **Unit tests for output formatters**: verify "By Target" table and Target column appear only when multiple targets present. +- **Integration test pattern**: test the full `Scan` flow with mocked AWS clients for 2-3 accounts, verify merged output. 
From 0330be76ccaf21dee8b8c7914371aaadb8355f7e Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:11:03 +0100 Subject: [PATCH 09/19] Add multi-target scanning implementation plan --- .../plans/2026-04-18-multi-target-scanning.md | 1537 +++++++++++++++++ 1 file changed, 1537 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-18-multi-target-scanning.md diff --git a/docs/superpowers/plans/2026-04-18-multi-target-scanning.md b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md new file mode 100644 index 0000000..ebde5e3 --- /dev/null +++ b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md @@ -0,0 +1,1537 @@ +# Multi-Target Scanning Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Enable gpuaudit to scan multiple AWS accounts in a single invocation via STS AssumeRole, with optional Organizations auto-discovery. + +**Architecture:** New `multiaccount.go` handles target resolution (explicit list or Organizations API) and credential assumption. The existing `Scan()` function is refactored to accept multiple targets and scan them all in parallel. Output formatters gain per-target summary sections when multiple targets are present. All new fields use `omitempty` so single-account scans produce identical output to today. 
+ +**Tech Stack:** Go 1.24, AWS SDK v2 (STS, Organizations), cobra CLI, standard library testing + +--- + +## File Map + +| File | Action | Responsibility | +|------|--------|---------------| +| `internal/providers/aws/multiaccount.go` | Create | `Target` struct, `ResolveTargets()`, `TargetError` type, STS AssumeRole + Organizations list | +| `internal/providers/aws/multiaccount_test.go` | Create | Tests for `ResolveTargets()` with mock STS/Org clients | +| `internal/models/models.go` | Modify | Add `TargetSummary`, `TargetErrorInfo` types; add new fields to `ScanResult` | +| `internal/providers/aws/scanner.go` | Modify | Refactor `Scan()` to use `ResolveTargets()` and scan all targets in parallel | +| `cmd/gpuaudit/main.go` | Modify | Add `--targets`, `--role`, `--org`, `--external-id`, `--skip-self` flags; wire into `ScanOptions` | +| `internal/providers/aws/summary.go` | Create | Extract `BuildSummary` from scanner.go, add `BuildTargetSummaries()` | +| `internal/providers/aws/summary_test.go` | Create | Tests for per-target summary computation | +| `internal/output/table.go` | Modify | Add "By Target" summary table and "Target" column when multiple targets | +| `internal/output/markdown.go` | Modify | Add per-target summary section when multiple targets | +| `internal/output/slack.go` | Modify | Add per-target summary block when multiple targets | +| `go.mod` | Modify | Add `organizations` SDK dependency | + +--- + +### Task 1: Add model types for multi-target results + +**Files:** +- Modify: `internal/models/models.go` + +- [ ] **Step 1: Add `TargetSummary` and `TargetErrorInfo` types and new `ScanResult` fields** + +Add to `internal/models/models.go` after the `ScanSummary` struct: + +```go +// TargetSummary provides per-target aggregate statistics. 
+type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +Add three new fields to `ScanResult`: + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` +} +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success (new types are additive, omitempty means no output change) + +- [ ] **Step 3: Run existing tests to confirm nothing broke** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/models/models.go +git commit -m "Add TargetSummary and TargetErrorInfo model types for multi-target scanning" +``` + +--- + +### Task 2: Extract `BuildSummary` and add `BuildTargetSummaries` + +**Files:** +- Create: `internal/providers/aws/summary.go` +- Create: `internal/providers/aws/summary_test.go` +- Modify: `internal/providers/aws/scanner.go` (remove `BuildSummary` — it moves to summary.go) + +- [ ] **Step 1: Write the failing test for `BuildTargetSummaries`** + +Create `internal/providers/aws/summary_test.go`: + 
+```go +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + // Find each target + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func 
TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} +``` + +- [ ] **Step 2: Run the test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestBuildTargetSummaries -v` +Expected: FAIL (function not defined) + +- [ ] **Step 3: Create `summary.go` with `BuildSummary` (moved from scanner.go) and `BuildTargetSummaries`** + +Create `internal/providers/aws/summary.go`: + +```go +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + maxSeverity := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSeverity = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { + maxSeverity = models.SeverityWarning + } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { + maxSeverity = models.SeverityInfo + } + } + + switch maxSeverity { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. 
+func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + maxSev := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSev = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { + maxSev = models.SeverityWarning + } + } + switch maxSev { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} +``` + +- [ ] **Step 4: Remove `BuildSummary` from `scanner.go` (keep `matchesExcludeTags`)** + +In `internal/providers/aws/scanner.go`, delete the `BuildSummary` function (lines 235-272) and keep `matchesExcludeTags`. The `BuildSummary` is now in `summary.go`. No import changes needed since both files are in the same package. + +- [ ] **Step 5: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...
-v` +Expected: all pass, including the new `TestBuildTargetSummaries_*` tests + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/summary.go internal/providers/aws/summary_test.go internal/providers/aws/scanner.go +git commit -m "Extract BuildSummary to summary.go and add BuildTargetSummaries" +``` + +--- + +### Task 3: Implement `ResolveTargets` with STS AssumeRole + +**Files:** +- Create: `internal/providers/aws/multiaccount.go` +- Create: `internal/providers/aws/multiaccount_test.go` +- Modify: `go.mod` (add organizations dependency) + +- [ ] **Step 1: Add the Organizations SDK dependency** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go get github.com/aws/aws-sdk-go-v2/service/organizations` + +- [ ] **Step 2: Write failing tests for `ResolveTargets`** + +Create `internal/providers/aws/multiaccount_test.go`: + +```go +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +type mockSTSClient struct { + identity *sts.GetCallerIdentityOutput + roles map[string]*sts.AssumeRoleOutput // keyed by account ID + failAccts map[string]error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return m.identity, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from ARN: arn:aws:iam::<account-id>:role/<role-name> + arn := aws.ToString(params.RoleArn) + // Simple extraction: find the account ID between the 4th and 5th colons + acct := "" + colons := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + rest := arn[i+1:] + for j, r := range rest
{ + if r == ':' { + acct = rest[:j] + break + } + } + break + } + } + } + if err, ok := m.failAccts[acct]; ok { + return nil, err + } + if out, ok := m.roles[acct]; ok { + return out, nil + } + return nil, fmt.Errorf("no role for account %s", acct) +} + +type mockOrgClient struct { + accounts []orgtypes.Account +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) { + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +func TestResolveTargets_NoTargets_ReturnsSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, ScanOptions{}) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target, got %d", len(targets)) + } + if targets[0].AccountID != "999999999999" { + t.Errorf("expected account 999999999999, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + "222222222222": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK2"), SecretAccessKey: aws.String("SK2"), SessionToken: aws.String("ST2"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + // 2 explicit + self = 3 + 
if len(targets) != 3 { + t.Fatalf("expected 3 targets (2 explicit + self), got %d", len(targets)) + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (skip self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + failAccts: map[string]error{ + "222222222222": fmt.Errorf("AccessDenied"), + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "222222222222" { + t.Errorf("expected error for 222222222222, got %s", errs[0].AccountID) + } + if len(targets) != 1 { + t.Fatalf("expected 1 successful target, got %d", len(targets)) + } +} + +func 
TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("999999999999"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + }, + } + + opts := ScanOptions{ + OrgScan: true, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, orgClient, opts) + + // 999 (self, no assume) + 111 (assumed) = 2 targets; 333 is suspended so skipped + // Note: 999 is self so not assumed; 111 is assumed successfully + if len(targets) != 2 { + t.Fatalf("expected 2 targets (self + 1 active non-self), got %d", len(targets)) + } + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } +} +``` + +- [ ] **Step 3: Run test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: FAIL (function and types not defined) + +- [ ] **Step 4: Implement `multiaccount.go`** + +Create `internal/providers/aws/multiaccount.go`: + +```go +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// Target represents a resolved scan target with its credentials. 
+type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. +type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. +type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount := aws.ToString(identity.Account) + + // No multi-target flags: return self only + if len(opts.Targets) == 0 && !opts.OrgScan { + return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + // Determine account IDs to scan + var accountIDs []string + if opts.OrgScan { + discovered, err := discoverOrgAccounts(ctx, orgClient) + if err != nil { + return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", err)}} + } + accountIDs = discovered + } else { + accountIDs = opts.Targets + } + + var targets []Target + var targetErrors []TargetError + + // Include self unless skipped + if !opts.SkipSelf { + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + } + + // Assume role in 
each non-self account + for _, acctID := range accountIDs { + if acctID == selfAccount { + continue // already included as self (or skipped) + } + + cfg, err := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if err != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: err}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return targets, targetErrors +} + +func discoverOrgAccounts(ctx context.Context, client OrgClient) ([]string, error) { + var accounts []string + var nextToken *string + + for { + out, err := client.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accounts = append(accounts, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accounts, nil +} + +func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: &roleArn, + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = &externalID + } + + out, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleArn, err) + } + + creds := out.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} +``` + +- [ ] **Step 5: Fix the test import — add `ststypes` import** + +The tests reference `ststypes.Credentials`. Add this import to `multiaccount_test.go`: + +```go +import ( + // ... existing imports ... 
+ ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) +``` + +- [ ] **Step 6: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: all pass + +- [ ] **Step 7: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 8: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/multiaccount.go internal/providers/aws/multiaccount_test.go go.mod go.sum +git commit -m "Add ResolveTargets with STS AssumeRole and Organizations discovery" +``` + +--- + +### Task 4: Refactor `Scan()` for multi-target parallel scanning + +**Files:** +- Modify: `internal/providers/aws/scanner.go` + +- [ ] **Step 1: Add multi-target fields to `ScanOptions`** + +In `internal/providers/aws/scanner.go`, add to the `ScanOptions` struct: + +```go +type ScanOptions struct { + Profile string + Regions []string + MetricWindow MetricWindow + SkipMetrics bool + SkipSageMaker bool + SkipEKS bool + SkipCosts bool + ExcludeTags map[string]string + MinUptimeDays int + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool +} +``` + +- [ ] **Step 2: Refactor `Scan()` to use `ResolveTargets` and scan all targets in parallel** + +Replace the `Scan` function in `scanner.go` with: + +```go +func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { + start := time.Now() + + // Load AWS config + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if opts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(opts.Profile)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) 
+ if err != nil { + return nil, fmt.Errorf("loading AWS config: %w", err) + } + + // Resolve targets + stsClient := sts.NewFromConfig(cfg) + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved") + } + + // Report target errors + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: target %s: %v\n", te.AccountID, te.Err) + } + + fmt.Fprintf(os.Stderr, " Scanning %d target(s)...\n", len(targets)) + + // Determine regions to scan + regions := opts.Regions + if len(regions) == 0 { + regions, err = getGPURegions(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("listing regions: %w", err) + } + } + + fmt.Fprintf(os.Stderr, " Scanning %d regions per target for GPU instances...\n", len(regions)) + + // Scan all targets in parallel + type targetResult struct { + accountID string + instances []models.GPUInstance + regions []string + } + + resultsCh := make(chan targetResult, len(targets)) + var wg sync.WaitGroup + + for _, target := range targets { + wg.Add(1) + go func(t Target) { + defer wg.Done() + instances, scannedRegions := scanTarget(ctx, t, regions, opts) + resultsCh <- targetResult{ + accountID: t.AccountID, + instances: instances, + regions: scannedRegions, + } + }(target) + } + + go func() { + wg.Wait() + close(resultsCh) + }() + + var allInstances []models.GPUInstance + regionSet := make(map[string]bool) + callerAccount := "" + if len(targets) > 0 { + callerAccount = targets[0].AccountID + } + + for res := range resultsCh { + allInstances = append(allInstances, res.instances...) 
+ for _, r := range res.regions { + regionSet[r] = true + } + } + + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + + // Filter by excluded tags + if len(opts.ExcludeTags) > 0 { + filtered := allInstances[:0] + excluded := 0 + for _, inst := range allInstances { + if matchesExcludeTags(inst.Tags, opts.ExcludeTags) { + excluded++ + continue + } + filtered = append(filtered, inst) + } + allInstances = filtered + if excluded > 0 { + fmt.Fprintf(os.Stderr, " Excluded %d instance(s) by tag filter.\n", excluded) + } + } + + // Run analysis + analysis.AnalyzeAll(allInstances) + + // Suppress signals below minimum uptime threshold + if opts.MinUptimeDays > 0 { + minHours := float64(opts.MinUptimeDays) * 24 + for i := range allInstances { + inst := &allInstances[i] + if inst.UptimeHours >= minHours { + continue + } + inst.WasteSignals = nil + inst.Recommendations = nil + inst.EstimatedSavings = 0 + } + } + + // Build summaries + summary := BuildSummary(allInstances) + + result := &models.ScanResult{ + Timestamp: start, + AccountID: callerAccount, + Regions: scannedRegions, + ScanDuration: time.Since(start).Round(time.Millisecond).String(), + Instances: allInstances, + Summary: summary, + } + + // Add multi-target metadata + if len(targets) > 1 || len(targetErrors) > 0 { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account. 
+func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: %s/%s: %v\n", target.AccountID, res.region, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (per-target, since CE is account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: %s cost enrichment: %v\n", target.AccountID, err) + } + } + + return allInstances, scannedRegions +} +``` + +- [ ] **Step 3: Add the organizations import to scanner.go** + +Add to the import block: + +```go +"github.com/aws/aws-sdk-go-v2/service/organizations" +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/scanner.go +git commit -m "Refactor Scan() for parallel multi-target scanning" +``` + +--- + +### Task 
5: Wire CLI flags into scan command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add flag variables and register flags** + +Add the new flag variables alongside the existing scan flags: + +```go +var ( + // ... existing flags ... + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool +) +``` + +In the `init()` function, add after the existing `scanCmd.Flags` calls: + +```go +scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") +scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") +scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") +scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") +scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") +scanCmd.MarkFlagsMutuallyExclusive("targets", "org") +``` + +- [ ] **Step 2: Wire flags into `ScanOptions` in `runScan`** + +In the `runScan` function, add the new fields to the opts construction: + +```go +opts := awsprovider.DefaultScanOptions() +opts.Profile = scanProfile +opts.Regions = scanRegions +opts.SkipMetrics = scanSkipMetrics +opts.SkipSageMaker = scanSkipSageMaker +opts.SkipEKS = scanSkipEKS +opts.SkipCosts = scanSkipCosts +opts.ExcludeTags = parseExcludeTags(scanExcludeTags) +opts.MinUptimeDays = scanMinUptimeDays +opts.Targets = scanTargets +opts.Role = scanRole +opts.ExternalID = scanExternalID +opts.OrgScan = scanOrg +opts.SkipSelf = scanSkipSelf +``` + +- [ ] **Step 3: Add validation — `--role` required with `--targets` or `--org`** + +Add at the top of `runScan`, before creating opts: + +```go +if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") +} +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd 
/Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Verify CLI help shows new flags** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: new flags visible in help text + +- [ ] **Step 6: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 7: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add --targets, --role, --org, --external-id, --skip-self flags to scan command" +``` + +--- + +### Task 6: Update table formatter for multi-target output + +**Files:** +- Modify: `internal/output/table.go` + +- [ ] **Step 1: Add "By Target" summary table to `FormatTable`** + +In `internal/output/table.go`, add a new function and call it from `FormatTable` after the summary box: + +```go +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + // Target errors + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} +``` + +In `FormatTable`, add the call after the summary box and before the "No GPU instances" check: + +```go +// ... after the summary box closing line ... 
+ +printTargetSummary(w, result) + +if s.TotalInstances == 0 { +``` + +- [ ] **Step 2: Add "Target" column to `printInstanceTable` when multi-target** + +Modify `printInstanceTable` to take a `multiTarget` flag. The formatter cannot tell from the instance slice alone whether the scan was multi-target, so derive the flag once from `result.TargetSummaries` in `FormatTable` and pass it to every call: + +Change the call sites in `FormatTable` from: +```go +printInstanceTable(w, critical) +``` +to: +```go +multiTarget := len(result.TargetSummaries) > 1 +printInstanceTable(w, critical, multiTarget) +``` + +Update `printInstanceTable`: + +```go +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } + + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..."
+ } + + signal := "" + if len(inst.WasteSignals) > 0 { + signal = inst.WasteSignals[0].Type + } + + rec := "" + if len(inst.Recommendations) > 0 { + rec = inst.Recommendations[0].Description + } + + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } + } + fmt.Fprintln(w) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/table.go +git commit -m "Add per-target summary table and target column to table formatter" +``` + +--- + +### Task 7: Update markdown and Slack formatters for multi-target output + +**Files:** +- Modify: `internal/output/markdown.go` +- Modify: `internal/output/slack.go` + +- [ ] **Step 1: Add per-target section to markdown formatter** + +In `internal/output/markdown.go`, add after the Summary table (after the `s.HealthyCount` line and before the "No GPU instances" check): + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) +} + +if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) +} +``` + +Also add a "Target" 
column to the Findings table when multi-target. Change the table header and row formatting: + +```go +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") +} else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") +} + +for _, inst := range result.Instances { + // ... existing name/signal/rec/savings formatting ... + + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } +} +``` + +- [ ] **Step 2: Add per-target block to Slack formatter** + +In `internal/output/slack.go`, in `FormatSlack`, add after the summary block and divider: + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) +} + +// Target errors +if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines = append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && 
go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/markdown.go internal/output/slack.go +git commit -m "Add per-target summaries to markdown and Slack formatters" +``` + +--- + +### Task 8: Update `iam-policy` command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add cross-account and Organizations statements to `iam-policy` output** + +In `cmd/gpuaudit/main.go`, in the `iamPolicyCmd` Run function, add two new statements to the policy `Statement` slice: + +```go +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", +}, +``` + +Print a note to stderr before encoding (stdout must remain pure JSON, because the README redirects it into `gpuaudit-policy.json`): + +```go +fmt.Fprintln(os.Stderr, "note: the last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 3: Verify output looks correct** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit iam-policy` +Expected: JSON policy with the two new statements appended + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add cross-account and Organizations permissions to iam-policy output" +``` + +--- + +### Task 9: Update README with multi-target documentation + +**Files:** +- Modify: `README.md` + +- [ ] **Step 1: Add multi-account scanning section to README** + +Add a new section after the existing usage documentation: + +```markdown +## Multi-Account Scanning + +Scan multiple AWS accounts in a single
invocation using STS AssumeRole. + +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. + +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. 
+ +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` +``` + +- [ ] **Step 2: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add README.md +git commit -m "Add multi-account scanning docs to README" +``` + +--- + +### Task 10: End-to-end verification + +**Files:** None (verification only) + +- [ ] **Step 1: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./... 
-v` +Expected: all pass + +- [ ] **Step 2: Run go vet** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go vet ./...` +Expected: no issues + +- [ ] **Step 3: Verify CLI help** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: all new flags visible (--targets, --role, --org, --external-id, --skip-self) + +- [ ] **Step 4: Verify mutual exclusivity** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 --org --role test 2>&1` +Expected: error about mutually exclusive flags + +- [ ] **Step 5: Verify --role validation** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 2>&1` +Expected: error "--role is required when using --targets or --org" + +- [ ] **Step 6: Verify single-account scan still works (no regression)** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --skip-metrics --skip-sagemaker --skip-eks --skip-k8s --skip-costs 2>&1` +Expected: runs normally, output unchanged from before this feature From bf6ab49daafe81a9621f0d70176e59112a779fb0 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:13:34 +0100 Subject: [PATCH 10/19] Add TargetSummary and TargetErrorInfo model types for multi-target scanning --- internal/models/models.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index 0fd6557..a5b9835 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -117,12 +117,15 @@ type Recommendation struct { // ScanResult holds the complete output of a gpuaudit scan.
type ScanResult struct { - Timestamp time.Time `json:"timestamp"` - AccountID string `json:"account_id"` - Regions []string `json:"regions"` - ScanDuration string `json:"scan_duration"` - Instances []GPUInstance `json:"instances"` - Summary ScanSummary `json:"summary"` + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` } // ScanSummary provides aggregate statistics for a scan. @@ -137,5 +140,22 @@ type ScanSummary struct { HealthyCount int `json:"healthy_count"` } +// TargetSummary provides per-target aggregate statistics. +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} + // Ptr is a convenience helper for creating pointer values in tests and literals. 
func Ptr[T any](v T) *T { return &v } From 817eaac48d36075ec868ea1fd56fdbe914a775a9 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:15:35 +0100 Subject: [PATCH 11/19] Extract BuildSummary to summary.go and add BuildTargetSummaries --- internal/providers/aws/scanner.go | 40 ----------- internal/providers/aws/summary.go | 96 ++++++++++++++++++++++++++ internal/providers/aws/summary_test.go | 89 ++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 40 deletions(-) create mode 100644 internal/providers/aws/summary.go create mode 100644 internal/providers/aws/summary_test.go diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index d8d5921..e28cbab 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -231,46 +231,6 @@ func getGPURegions(ctx context.Context, cfg aws.Config) ([]string, error) { }, nil } -// BuildSummary computes aggregate statistics for a set of GPU instances. -func BuildSummary(instances []models.GPUInstance) models.ScanSummary { - s := models.ScanSummary{ - TotalInstances: len(instances), - } - - for _, inst := range instances { - s.TotalMonthlyCost += inst.MonthlyCost - s.TotalEstimatedWaste += inst.EstimatedSavings - - maxSeverity := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSeverity = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { - maxSeverity = models.SeverityWarning - } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { - maxSeverity = models.SeverityInfo - } - } - - switch maxSeverity { - case models.SeverityCritical: - s.CriticalCount++ - case models.SeverityWarning: - s.WarningCount++ - case models.SeverityInfo: - s.InfoCount++ - default: - s.HealthyCount++ - } - } - - if s.TotalMonthlyCost > 0 { - s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 - } - - return s -} - func 
matchesExcludeTags(instanceTags map[string]string, excludes map[string]string) bool { for k, v := range excludes { if instanceTags[k] == v { diff --git a/internal/providers/aws/summary.go b/internal/providers/aws/summary.go new file mode 100644 index 0000000..bae351a --- /dev/null +++ b/internal/providers/aws/summary.go @@ -0,0 +1,96 @@ +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + maxSeverity := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSeverity = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { + maxSeverity = models.SeverityWarning + } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { + maxSeverity = models.SeverityInfo + } + } + + switch maxSeverity { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. 
+func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + maxSev := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSev = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { + maxSev = models.SeverityWarning + } + } + switch maxSev { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} diff --git a/internal/providers/aws/summary_test.go b/internal/providers/aws/summary_test.go new file mode 100644 index 0000000..b429e39 --- /dev/null +++ b/internal/providers/aws/summary_test.go @@ -0,0 +1,89 @@ +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + 
MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} From 1f21c288f860d5ea1f5bcb20d63981ae8275c941 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:20:17 +0100 Subject: [PATCH 12/19] Implement ResolveTargets with STS AssumeRole for multi-account scanning Add ResolveTargets function that resolves scan targets based on --targets, --org, --role, and 
--skip-self options. Self account uses original credentials (no AssumeRole), failed assumptions are collected as TargetError rather than being fatal. Add STSClient and OrgClient interfaces, Target and TargetError types, multi-target fields to ScanOptions, and organizations SDK dependency. Includes 6 tests covering: self-only, explicit targets, skip-self, partial failure, org discovery with suspended account filtering, and self-in-targets deduplication. --- go.mod | 11 +- go.sum | 18 +- internal/providers/aws/multiaccount.go | 165 +++++++++++ internal/providers/aws/multiaccount_test.go | 298 ++++++++++++++++++++ internal/providers/aws/scanner.go | 7 + 5 files changed, 486 insertions(+), 13 deletions(-) create mode 100644 internal/providers/aws/multiaccount.go create mode 100644 internal/providers/aws/multiaccount_test.go diff --git a/go.mod b/go.mod index b86d582..96c6700 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,14 @@ module github.com/gpuaudit/cli go 1.24.0 require ( - github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2 v1.41.6 github.com/aws/aws-sdk-go-v2/config v1.32.14 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 github.com/aws/aws-sdk-go-v2/service/costexplorer v1.63.6 github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 github.com/aws/aws-sdk-go-v2/service/eks v1.82.0 + github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 github.com/spf13/cobra v1.10.2 @@ -18,17 +20,16 @@ require ( ) require ( - github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect + 
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect - github.com/aws/smithy-go v1.24.2 // indirect + github.com/aws/smithy-go v1.25.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index c4d6139..0c37a90 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,15 @@ -github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= -github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg= +github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ= github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= -github.com/aws/aws-sdk-go-v2/internal/configsources 
v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 h1:ud2A364lLBkhGAC7oYw/1xg9BF4acwJC+qdLykxy83o= @@ -24,6 +24,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhL github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 h1:2TDersSNowBwSRTrnD0LxLilpr6Dr5coXwVsWO7f2rw= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2/go.mod h1:UMm4MKZDJMbuJZF5QOJBsVRMLeKiEXAgCXFpocWPDFo= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 h1:5jLvLVu20tlFgVOsX+ns4jNVzoUWP36AQc5sAvNJSMI= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0/go.mod h1:zsRrjJIfG9a9b3VRU+uPa3dX5fqgI+zKMXD4tbIlbdA= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 
h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= @@ -34,8 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6f github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= -github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= -github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U= +github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/providers/aws/multiaccount.go b/internal/providers/aws/multiaccount.go new file mode 100644 index 0000000..fd8a99c --- /dev/null +++ b/internal/providers/aws/multiaccount.go @@ -0,0 +1,165 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// Target represents a resolved scan target with its credentials. +type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. 
+type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. +type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +// +// Behaviour: +// - No --targets/--org: returns self only (uses baseCfg, no AssumeRole) +// - --targets + --role: AssumeRole for each, self included by default +// - --org + --role: ListAccounts, filter Active, AssumeRole for non-self accounts +// - --skip-self: exclude caller's account +// - Self account is never AssumeRole'd — uses original credentials +// - Failed AssumeRole calls are collected as TargetError, not fatal +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { + // Identify the caller's own account. + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount := aws.ToString(identity.Account) + + // Determine the list of account IDs to scan. 
+ var accountIDs []string + + switch { + case opts.OrgScan: + activeAccounts, listErr := listActiveOrgAccounts(ctx, orgClient) + if listErr != nil { + return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} + } + accountIDs = activeAccounts + case len(opts.Targets) > 0: + // Always include self unless it is already in the list or --skip-self is set. + seen := make(map[string]bool) + for _, id := range opts.Targets { + if !seen[id] { + accountIDs = append(accountIDs, id) + seen[id] = true + } + } + if !seen[selfAccount] && !opts.SkipSelf { + // Prepend self so it appears first. + accountIDs = append([]string{selfAccount}, accountIDs...) + } + default: + // No multi-target flags — scan self only. + return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + // Resolve credentials for each account. + var targets []Target + var targetErrors []TargetError + + for _, acctID := range accountIDs { + if opts.SkipSelf && acctID == selfAccount { + continue + } + + if acctID == selfAccount { + // Self: use original credentials, no AssumeRole. + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + continue + } + + // AssumeRole into the target account. + cfg, assumeErr := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if assumeErr != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: assumeErr}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return targets, targetErrors +} + +// assumeRole assumes a role in the given account and returns an aws.Config +// with the temporary credentials. 
+func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleARN := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleARN), + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = aws.String(externalID) + } + + result, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleARN, err) + } + + creds := result.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} + +// listActiveOrgAccounts returns the account IDs of all active accounts in the organization. +func listActiveOrgAccounts(ctx context.Context, orgClient OrgClient) ([]string, error) { + var accountIDs []string + var nextToken *string + + for { + out, err := orgClient.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accountIDs = append(accountIDs, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accountIDs, nil +} diff --git a/internal/providers/aws/multiaccount_test.go b/internal/providers/aws/multiaccount_test.go new file mode 100644 index 0000000..2d40cce --- /dev/null +++ b/internal/providers/aws/multiaccount_test.go @@ -0,0 +1,298 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// --- Mock STS client --- + +type mockSTSClient struct { + callerAccount string + assumeResults map[string]*sts.AssumeRoleOutput // accountID -> output + assumeErrors map[string]error // accountID -> error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Account: aws.String(m.callerAccount), + }, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from the role ARN: arn:aws:iam:::role/ + arn := aws.ToString(params.RoleArn) + // Simple parse: find the account between the 4th and 5th colons + accountID := parseAccountFromARN(arn) + + if err, ok := m.assumeErrors[accountID]; ok { + return nil, err + } + if out, ok := m.assumeResults[accountID]; ok { + return out, nil + } + return nil, fmt.Errorf("no mock configured for account %s", accountID) +} + +func parseAccountFromARN(arn string) string { + // arn:aws:iam::123456789012:role/name + colons := 0 + start := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + start = i + 1 + } + if colons == 5 { + return arn[start:i] + } + } + } + return "" +} + +// --- Mock Org client --- + +type mockOrgClient struct { + accounts []orgtypes.Account + err error +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, 
error) { + if m.err != nil { + return nil, m.err + } + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +// Helper to build a successful AssumeRole result with dummy credentials. +func assumeRoleOK(accountID string) *sts.AssumeRoleOutput { + exp := time.Now().Add(1 * time.Hour) + return &sts.AssumeRoleOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AKID-" + accountID), + SecretAccessKey: aws.String("SECRET-" + accountID), + SessionToken: aws.String("TOKEN-" + accountID), + Expiration: &exp, + }, + } +} + +func TestResolveTargets_NoTargets_ReturnsSelfOnly(t *testing.T) { + stsClient := &mockSTSClient{callerAccount: "111111111111"} + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{} + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected account 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets_ReturnsSelfPlusAssumed(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "333333333333": assumeRoleOK("333333333333"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + // Self + 2 explicit targets = 3 + if len(targets) != 3 { + t.Fatalf("expected 3 targets, got %d", len(targets)) + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == 
"111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not found in targets") + } + + // Verify assumed targets + for _, acct := range []string{"222222222222", "333333333333"} { + found := false + for _, tgt := range targets { + if tgt.AccountID == acct { + found = true + break + } + } + if !found { + t.Errorf("account %s not found in targets", acct) + } + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222"}, + Role: "AuditRole", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (no self), got %d", len(targets)) + } + if targets[0].AccountID != "222222222222" { + t.Errorf("expected account 222222222222, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + assumeErrors: map[string]error{ + "333333333333": fmt.Errorf("access denied"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + // Self + 222 succeeded, 333 failed + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "333333333333" { + t.Errorf("expected 
error for 333333333333, got %s", errs[0].AccountID) + } +} + +func TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "444444444444": assumeRoleOK("444444444444"), + }, + } + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("222222222222"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + {Id: aws.String("444444444444"), Status: orgtypes.AccountStatusActive}, + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + OrgScan: true, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Active accounts: 111 (self), 222, 444. Suspended 333 is filtered. + if len(targets) != 3 { + t.Fatalf("expected 3 targets (self + 2 active non-self), got %d", len(targets)) + } + + // Verify suspended account is excluded + for _, tgt := range targets { + if tgt.AccountID == "333333333333" { + t.Error("suspended account 333333333333 should be excluded") + } + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == "111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not found in targets") + } +} + +func TestResolveTargets_SelfInExplicitTargets_NotAssumed(t *testing.T) { + // If the caller's own account appears in --targets, it should use baseCfg (no AssumeRole). 
+ stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + // No AssumeRole result for self — it should not be called + assumeErrors: map[string]error{ + "111111111111": fmt.Errorf("should not assume role for self"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Self (from targets list, no duplicate) + 222 + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } +} diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index e28cbab..34f716c 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -34,6 +34,13 @@ type ScanOptions struct { SkipCosts bool ExcludeTags map[string]string MinUptimeDays int + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool } // DefaultScanOptions returns sensible defaults. 
From 6bb43eae59fa7fdcd51a6ce4163a07bcabc95494 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:23:00 +0100 Subject: [PATCH 13/19] Refactor Scan() for parallel multi-target scanning --- internal/providers/aws/scanner.go | 151 +++++++++++++++++++++++------- 1 file changed, 117 insertions(+), 34 deletions(-) diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index 34f716c..b9a0986 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -16,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/costexplorer" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/eks" + "github.com/aws/aws-sdk-go-v2/service/organizations" "github.com/aws/aws-sdk-go-v2/service/sagemaker" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -50,7 +51,7 @@ func DefaultScanOptions() ScanOptions { } } -// Scan performs a full GPU audit of the AWS account. +// Scan performs a full GPU audit across one or more AWS accounts. 
func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { start := time.Now() @@ -65,13 +66,26 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { return nil, fmt.Errorf("loading AWS config: %w", err) } - // Get account ID + // Resolve targets (accounts to scan) stsClient := sts.NewFromConfig(cfg) - identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("getting caller identity: %w", err) + + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + + // Print target errors to stderr and check for fatal failure + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: failed to resolve target %s: %v\n", te.AccountID, te.Err) } - accountID := aws.ToString(identity.Account) + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved (errors: %d)", len(targetErrors)) + } + + // Determine the caller account from the first target + callerAccount := targets[0].AccountID // Determine regions to scan regions := opts.Regions @@ -82,46 +96,55 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - fmt.Fprintf(os.Stderr," Scanning %d regions for GPU instances...\n", len(regions)) + if len(targets) > 1 { + fmt.Fprintf(os.Stderr, " Scanning %d accounts across %d regions for GPU instances...\n", len(targets), len(regions)) + } else { + fmt.Fprintf(os.Stderr, " Scanning %d regions for GPU instances...\n", len(regions)) + } - // Scan all regions concurrently - type regionResult struct { - region string + // Scan all targets in parallel + type targetResult struct { instances []models.GPUInstance + regions []string err error } - results := make(chan regionResult, len(regions)) + targetResults := make(chan targetResult, len(targets)) var wg sync.WaitGroup - for _, region := range 
regions { + for _, t := range targets { wg.Add(1) - go func(r string) { + go func(target Target) { defer wg.Done() - instances, err := scanRegion(ctx, cfg, accountID, r, opts) - results <- regionResult{region: r, instances: instances, err: err} - }(region) + instances, scannedRegions, scanErr := scanTarget(ctx, target, regions, opts) + targetResults <- targetResult{instances: instances, regions: scannedRegions, err: scanErr} + }(t) } go func() { wg.Wait() - close(results) + close(targetResults) }() var allInstances []models.GPUInstance - var scannedRegions []string + regionSet := make(map[string]bool) - for res := range results { + for res := range targetResults { if res.err != nil { - fmt.Fprintf(os.Stderr," warning: error scanning %s: %v\n", res.region, res.err) + fmt.Fprintf(os.Stderr, " warning: target scan error: %v\n", res.err) continue } - if len(res.instances) > 0 { - allInstances = append(allInstances, res.instances...) - scannedRegions = append(scannedRegions, res.region) + allInstances = append(allInstances, res.instances...) 
+ for _, r := range res.regions { + regionSet[r] = true } } + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + // Filter by excluded tags if len(opts.ExcludeTags) > 0 { filtered := allInstances[:0] @@ -139,14 +162,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - // Enrich with Cost Explorer data (account-level, not per-region) - if !opts.SkipCosts && len(allInstances) > 0 { - ceClient := costexplorer.NewFromConfig(cfg) - if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { - fmt.Fprintf(os.Stderr," warning: could not enrich cost data: %v\n", err) - } - } - // Run analysis analysis.AnalyzeAll(allInstances) @@ -167,14 +182,82 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { // Build summary summary := BuildSummary(allInstances) - return &models.ScanResult{ + result := &models.ScanResult{ Timestamp: start, - AccountID: accountID, + AccountID: callerAccount, Regions: scannedRegions, ScanDuration: time.Since(start).Round(time.Millisecond).String(), Instances: allInstances, Summary: summary, - }, nil + } + + // Populate multi-target metadata when multiple targets are involved + isMultiTarget := len(targets) > 1 || len(targetErrors) > 0 + if isMultiTarget { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account, including +// Cost Explorer enrichment (which is account-scoped). 
+func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string, error) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: error scanning %s in account %s: %v\n", res.region, target.AccountID, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: could not enrich cost data for account %s: %v\n", target.AccountID, err) + } + } + + return allInstances, scannedRegions, nil } func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, opts ScanOptions) ([]models.GPUInstance, error) { From ce75ab16b97a1a850e04314f77d3e6b95e0180bd Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:24:43 +0100 Subject: [PATCH 14/19] Add --targets, --role, --org, --external-id, --skip-self flags to scan command --- cmd/gpuaudit/main.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 217a100..bef0da1 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -52,6 
+52,11 @@ var ( scanKubeContext string scanExcludeTags []string scanMinUptimeDays int + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool ) // --- diff command --- @@ -85,6 +90,12 @@ func init() { scanCmd.Flags().StringVar(&scanKubeContext, "kube-context", "", "Kubernetes context to use (default: current context)") scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") + scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") + scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") + scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") + scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") + scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") + scanCmd.MarkFlagsMutuallyExclusive("targets", "org") diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") @@ -98,6 +109,10 @@ func init() { func runScan(cmd *cobra.Command, args []string) error { ctx := context.Background() + if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") + } + opts := awsprovider.DefaultScanOptions() opts.Profile = scanProfile opts.Regions = scanRegions @@ -107,6 +122,11 @@ func runScan(cmd *cobra.Command, args []string) error { opts.SkipCosts = scanSkipCosts opts.ExcludeTags = parseExcludeTags(scanExcludeTags) opts.MinUptimeDays = scanMinUptimeDays + opts.Targets = scanTargets + opts.Role = scanRole + opts.ExternalID = scanExternalID + opts.OrgScan = scanOrg + opts.SkipSelf = scanSkipSelf result, err := 
awsprovider.Scan(ctx, opts) if err != nil { From 1906f9f7fa423afa0e189e3cb042da4829efa70a Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:26:29 +0100 Subject: [PATCH 15/19] Add per-target summary table and target column to table formatter --- internal/output/table.go | 76 ++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/internal/output/table.go b/internal/output/table.go index 3f73232..1e60464 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -34,6 +34,8 @@ func FormatTable(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, " │ Estimated monthly waste: $%-10.0f (%4.0f%%) │\n", s.TotalEstimatedWaste, s.WastePercent) fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n\n") + printTargetSummary(w, result) + if s.TotalInstances == 0 { fmt.Fprintf(w, " No GPU instances found.\n\n") return @@ -42,14 +44,16 @@ func FormatTable(w io.Writer, result *models.ScanResult) { // Group instances by severity critical, warning, healthy := groupBySeverity(result.Instances) + multiTarget := len(result.TargetSummaries) > 1 + if len(critical) > 0 { fmt.Fprintf(w, " CRITICAL — %d instance(s), $%.0f/mo potential savings\n\n", len(critical), sumSavings(critical)) - printInstanceTable(w, critical) + printInstanceTable(w, critical, multiTarget) } if len(warning) > 0 { fmt.Fprintf(w, " WARNING — %d instance(s), $%.0f/mo potential savings\n\n", len(warning), sumSavings(warning)) - printInstanceTable(w, warning) + printInstanceTable(w, warning, multiTarget) } if len(healthy) > 0 { @@ -57,17 +61,54 @@ func FormatTable(w io.Writer, result *models.ScanResult) { } } -func printInstanceTable(w io.Writer, instances []models.GPUInstance) { - // Header - fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", - "Instance", "Type", "Monthly", "Signal", "Recommendation") - fmt.Fprintf(w, " %s %s %s %s %s\n", - strings.Repeat("─", 36), - strings.Repeat("─", 26), - strings.Repeat("─", 
10), - strings.Repeat("─", 16), - strings.Repeat("─", 50), - ) +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} + +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } for _, inst := range instances { name := inst.Name @@ -94,8 +135,13 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance) { rec = inst.Recommendations[0].Description } - fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", - name, typeDesc, inst.MonthlyCost, signal, rec) + if 
multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } } fmt.Fprintln(w) } From 53289803d53cfd5c4ada274ed004a57f585a7eaa Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:27:59 +0100 Subject: [PATCH 16/19] Add per-target summaries to markdown and Slack formatters --- internal/output/markdown.go | 43 ++++++++++++++++++++++++++++++++----- internal/output/slack.go | 21 ++++++++++++++++++ 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/internal/output/markdown.go b/internal/output/markdown.go index 13290bb..58995c5 100644 --- a/internal/output/markdown.go +++ b/internal/output/markdown.go @@ -31,14 +31,41 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, "| Warning | %d |\n", s.WarningCount) fmt.Fprintf(w, "| Healthy | %d |\n\n", s.HealthyCount) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) + } + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } + if s.TotalInstances == 0 { fmt.Fprintf(w, "No GPU instances found.\n") return } fmt.Fprintf(w, "## Findings\n\n") - fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") - fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + multiTarget := len(result.TargetSummaries) > 1 + if multiTarget { + 
fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") + } else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + } for _, inst := range result.Instances { name := inst.Name @@ -61,8 +88,14 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { savings = fmt.Sprintf("$%.0f/mo", inst.EstimatedSavings) } - fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", - name, inst.InstanceType, inst.GPUCount, inst.GPUModel, - inst.MonthlyCost, signal, savings, rec) + if multiTarget { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } } } diff --git a/internal/output/slack.go b/internal/output/slack.go index 530afe7..f8fc334 100644 --- a/internal/output/slack.go +++ b/internal/output/slack.go @@ -34,6 +34,27 @@ func FormatSlack(w io.Writer, result *models.ScanResult) error { blocks = append(blocks, map[string]any{"type": "divider"}) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) + } + + // Target errors + if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines 
= append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + } + // Critical findings critical, warning, _ := groupBySeverity(result.Instances) From 2e54bd0c5ae7f642f07f166a216c13029cbfd079 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:29:04 +0100 Subject: [PATCH 17/19] Add cross-account and Organizations permissions to iam-policy output --- cmd/gpuaudit/main.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index bef0da1..cff00e5 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -342,8 +342,21 @@ var iamPolicyCmd = &cobra.Command{ }, "Resource": "*", }, + { + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", + }, + { + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", + }, }, } + fmt.Fprintln(os.Stdout, "// The last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") enc.Encode(policy) From f8248233d4ef1c47c00f8947eb5f5fe2d1b64875 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:30:06 +0100 Subject: [PATCH 18/19] Add multi-account scanning docs to README --- README.md | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/README.md b/README.md index 8738521..3386201 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,106 @@ gpuaudit diff scan-apr-08.json scan-apr-15.json Matches instances by ID. Reports added, removed, and changed instances with per-field diffs (instance type, pricing model, cost, state, GPU allocation, waste severity). +## Multi-Account Scanning + +Scan multiple AWS accounts in a single invocation using STS AssumeRole. 
+ +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. + +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. 
+ +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` + ## IAM permissions gpuaudit is read-only. It never modifies your infrastructure. Generate the minimal IAM policy: From ebd0806414c5a200f20f7895d5adf29676048c5f Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:41:41 +0100 Subject: [PATCH 19/19] Fix callerAccount bug, deduplicate severity logic, clean up dead code ResolveTargets now returns selfAccount separately so Scan() always gets the correct caller identity regardless of --skip-self. Extracted models.MaxSeverity to replace three copies of severity classification. Removed dead error return from scanTarget. Added missing copyright headers. 
--- internal/models/models.go | 17 ++++++++++++++ internal/output/table.go | 19 +-------------- internal/providers/aws/multiaccount.go | 16 +++++-------- internal/providers/aws/multiaccount_test.go | 18 +++++++++----- internal/providers/aws/scanner.go | 18 ++++---------- internal/providers/aws/summary.go | 26 ++++----------------- internal/providers/aws/summary_test.go | 3 +++ 7 files changed, 49 insertions(+), 68 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index a5b9835..ed9238a 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -157,5 +157,22 @@ type TargetErrorInfo struct { Error string `json:"error"` } +// MaxSeverity returns the highest severity among the given waste signals. +func MaxSeverity(signals []WasteSignal) Severity { + max := Severity("") + for _, s := range signals { + if s.Severity == SeverityCritical { + return SeverityCritical + } + if s.Severity == SeverityWarning { + max = SeverityWarning + } + if s.Severity == SeverityInfo && max == "" { + max = SeverityInfo + } + } + return max +} + // Ptr is a convenience helper for creating pointer values in tests and literals. 
func Ptr[T any](v T) *T { return &v } diff --git a/internal/output/table.go b/internal/output/table.go index 1e60464..ece729c 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -148,8 +148,7 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy []models.GPUInstance) { for _, inst := range instances { - maxSev := maxSeverity(inst.WasteSignals) - switch maxSev { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: critical = append(critical, inst) case models.SeverityWarning: @@ -171,22 +170,6 @@ func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy return } -func maxSeverity(signals []models.WasteSignal) models.Severity { - max := models.Severity("") - for _, s := range signals { - if s.Severity == models.SeverityCritical { - return models.SeverityCritical - } - if s.Severity == models.SeverityWarning { - max = models.SeverityWarning - } - if s.Severity == models.SeverityInfo && max == "" { - max = models.SeverityInfo - } - } - return max -} - func sumSavings(instances []models.GPUInstance) float64 { total := 0.0 for _, inst := range instances { diff --git a/internal/providers/aws/multiaccount.go b/internal/providers/aws/multiaccount.go index fd8a99c..298c475 100644 --- a/internal/providers/aws/multiaccount.go +++ b/internal/providers/aws/multiaccount.go @@ -46,13 +46,13 @@ type OrgClient interface { // - --skip-self: exclude caller's account // - Self account is never AssumeRole'd — uses original credentials // - Failed AssumeRole calls are collected as TargetError, not fatal -func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) (selfAccount string, targets []Target, 
targetErrors []TargetError) { // Identify the caller's own account. identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { - return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + return "", nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} } - selfAccount := aws.ToString(identity.Account) + selfAccount = aws.ToString(identity.Account) // Determine the list of account IDs to scan. var accountIDs []string @@ -61,7 +61,7 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient case opts.OrgScan: activeAccounts, listErr := listActiveOrgAccounts(ctx, orgClient) if listErr != nil { - return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} + return selfAccount, nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} } accountIDs = activeAccounts case len(opts.Targets) > 0: @@ -79,13 +79,9 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient } default: // No multi-target flags — scan self only. - return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + return selfAccount, []Target{{AccountID: selfAccount, Config: baseCfg}}, nil } - // Resolve credentials for each account. 
- var targets []Target - var targetErrors []TargetError - for _, acctID := range accountIDs { if opts.SkipSelf && acctID == selfAccount { continue @@ -106,7 +102,7 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient targets = append(targets, Target{AccountID: acctID, Config: cfg}) } - return targets, targetErrors + return selfAccount, targets, targetErrors } // assumeRole assumes a role in the given account and returns an aws.Config diff --git a/internal/providers/aws/multiaccount_test.go b/internal/providers/aws/multiaccount_test.go index 2d40cce..bc2ba11 100644 --- a/internal/providers/aws/multiaccount_test.go +++ b/internal/providers/aws/multiaccount_test.go @@ -95,11 +95,14 @@ func TestResolveTargets_NoTargets_ReturnsSelfOnly(t *testing.T) { baseCfg := aws.Config{Region: "us-east-1"} opts := ScanOptions{} - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } if len(targets) != 1 { t.Fatalf("expected 1 target (self), got %d", len(targets)) } @@ -122,7 +125,7 @@ func TestResolveTargets_ExplicitTargets_ReturnsSelfPlusAssumed(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d", len(errs)) @@ -173,11 +176,14 @@ func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { SkipSelf: true, } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, 
got %d", len(errs)) } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } if len(targets) != 1 { t.Fatalf("expected 1 target (no self), got %d", len(targets)) } @@ -202,7 +208,7 @@ func TestResolveTargets_PartialFailure(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) // Self + 222 succeeded, 333 failed if len(targets) != 2 { @@ -238,7 +244,7 @@ func TestResolveTargets_OrgDiscovery(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) @@ -286,7 +292,7 @@ func TestResolveTargets_SelfInExplicitTargets_NotAssumed(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index b9a0986..a678867 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -74,7 +74,7 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { orgClient = organizations.NewFromConfig(cfg) } - targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + callerAccount, targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) // Print target errors to stderr and check for fatal failure for _, te := range targetErrors { @@ -84,9 +84,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { return nil, 
fmt.Errorf("no scannable targets resolved (errors: %d)", len(targetErrors)) } - // Determine the caller account from the first target - callerAccount := targets[0].AccountID - // Determine regions to scan regions := opts.Regions if len(regions) == 0 { @@ -106,7 +103,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { type targetResult struct { instances []models.GPUInstance regions []string - err error } targetResults := make(chan targetResult, len(targets)) @@ -116,8 +112,8 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { wg.Add(1) go func(target Target) { defer wg.Done() - instances, scannedRegions, scanErr := scanTarget(ctx, target, regions, opts) - targetResults <- targetResult{instances: instances, regions: scannedRegions, err: scanErr} + instances, scannedRegions := scanTarget(ctx, target, regions, opts) + targetResults <- targetResult{instances: instances, regions: scannedRegions} }(t) } @@ -130,10 +126,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { regionSet := make(map[string]bool) for res := range targetResults { - if res.err != nil { - fmt.Fprintf(os.Stderr, " warning: target scan error: %v\n", res.err) - continue - } allInstances = append(allInstances, res.instances...) for _, r := range res.regions { regionSet[r] = true @@ -211,7 +203,7 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { // scanTarget scans all regions for a single target account, including // Cost Explorer enrichment (which is account-scoped). 
-func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string, error) { +func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { type regionResult struct { region string instances []models.GPUInstance @@ -257,7 +249,7 @@ func scanTarget(ctx context.Context, target Target, regions []string, opts ScanO } } - return allInstances, scannedRegions, nil + return allInstances, scannedRegions } func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, opts ScanOptions) ([]models.GPUInstance, error) { diff --git a/internal/providers/aws/summary.go b/internal/providers/aws/summary.go index bae351a..5c6f715 100644 --- a/internal/providers/aws/summary.go +++ b/internal/providers/aws/summary.go @@ -1,3 +1,6 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + package aws import ( @@ -16,18 +19,7 @@ func BuildSummary(instances []models.GPUInstance) models.ScanSummary { s.TotalMonthlyCost += inst.MonthlyCost s.TotalEstimatedWaste += inst.EstimatedSavings - maxSeverity := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSeverity = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { - maxSeverity = models.SeverityWarning - } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { - maxSeverity = models.SeverityInfo - } - } - - switch maxSeverity { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: s.CriticalCount++ case models.SeverityWarning: @@ -67,15 +59,7 @@ func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary ts.TotalMonthlyCost += inst.MonthlyCost ts.TotalEstimatedWaste += inst.EstimatedSavings - maxSev := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity 
== models.SeverityCritical { - maxSev = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { - maxSev = models.SeverityWarning - } - } - switch maxSev { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: ts.CriticalCount++ case models.SeverityWarning: diff --git a/internal/providers/aws/summary_test.go b/internal/providers/aws/summary_test.go index b429e39..24702ec 100644 --- a/internal/providers/aws/summary_test.go +++ b/internal/providers/aws/summary_test.go @@ -1,3 +1,6 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + package aws import (