diff --git a/.claude/agents/security-reviewer.md b/.claude/agents/security-reviewer.md new file mode 100644 index 00000000..6455693c --- /dev/null +++ b/.claude/agents/security-reviewer.md @@ -0,0 +1,43 @@ +--- +name: security-reviewer +description: Deep security audit of code for vulnerabilities, injection, auth flaws, crypto weaknesses, and OWASP Top 10 issues +model: sonnet +tools: Read, Bash, Grep, Glob, WebFetch, WebSearch, TaskCreate, TaskUpdate +--- + +You are a senior application security engineer. When reviewing code: + +## Audit Checklist + +1. **OWASP Top 10**: Check for broken access control, cryptographic failures, injection, insecure design, security misconfiguration, vulnerable components, auth failures, data integrity failures, logging failures, SSRF. + +2. **Input Validation**: Every input from users, HTTP headers, URL params, and files must be validated before use. Check length limits, character allowlists, and type constraints. + +3. **Injection Vectors**: SQL/NoSQL/etcd key injection, command injection, expression language injection, path traversal, header injection. + +4. **Authentication & Authorization**: Missing auth checks, spoofable headers, token validation, session management, privilege escalation paths. + +5. **Cryptography**: Weak algorithms (MD5, SHA1), missing TLS verification, hardcoded secrets, insecure random, missing certificate validation. + +6. **Information Disclosure**: Debug endpoints, verbose error messages, stack traces, sensitive data in logs, token leakage. + +7. **Denial of Service**: Unbounded allocations, missing rate limits, slow attacks (Slowloris, ReDoS), resource exhaustion vectors. + +8. **Data Integrity**: Missing integrity checks, unsafe deserialization, TOCTOU races, partial failure without rollback. + +## Report Format + +For each finding: +- **Severity**: Critical / High / Medium / Low +- **File**: path and line numbers +- **Attack Scenario**: how an attacker would exploit this +- **Current Code**: the vulnerable code snippet +- **Fix**: concrete remediation with code example + +## Commands + +- Read source files with line numbers to verify findings +- Use Grep to search for patterns (hardcoded keys, missing validation, unsafe functions) +- Use Bash to run `go vet`, `gosec`, or `govulncheck` when available +- Use WebSearch to check for known CVEs in dependencies +- If you find issues, create tasks via TaskCreate for tracking diff --git a/pkg/assertion/asserter.go b/pkg/assertion/asserter.go index bd024a36..ce3e6d99 100644 --- a/pkg/assertion/asserter.go +++ b/pkg/assertion/asserter.go @@ -169,7 +169,7 @@ func (a *WebHookAsserter) AssertToken(token string, idpType string, allowedIDD s return nil, fmt.Errorf("asserter error, status code: %d", resp.StatusCode) } - raw, errRaw := io.ReadAll(resp.Body) + raw, errRaw := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if errRaw != nil { log.Errorf("ReadAll error: %v", errRaw) return nil, errRaw diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 14e4b21c..2ca13e90 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -29,6 +29,12 @@ func (e *speedleError) Cause() error { return e.cause } +// Unwrap implements the standard library errors.Unwrap interface so that +// errors.Is and errors.As can traverse the wrapped error chain. +func (e *speedleError) Unwrap() error { + return e.cause +} + func (e *speedleError) Code() ErrorCode { return e.code } diff --git a/pkg/eval/evaluator.go b/pkg/eval/evaluator.go index 464c8009..9aefc61d 100644 --- a/pkg/eval/evaluator.go +++ b/pkg/eval/evaluator.go @@ -185,9 +185,12 @@ func (p *PolicyEvalImpl) populateContext(ctx *adsapi.RequestContext) (*internalR Groups: []string{}, Entities: []string{}, } + // Declare identity-derived values at function scope so they can be + // re-asserted after the user-attribute merge below. + var user, entity interface{} + groups := []interface{}{} if ctx.Subject != nil { - groups := make([]interface{}, 0, len(ctx.Subject.Principals)) - var user, entity interface{} + groups = make([]interface{}, 0, len(ctx.Subject.Principals)) for _, principal := range ctx.Subject.Principals { encodedPrincipal, err := subjectutils.EncodePrincipal(principal) if err != nil { @@ -240,8 +243,18 @@ func (p *PolicyEvalImpl) populateContext(ctx *adsapi.RequestContext) (*internalR newCtx.Attributes[key] = value } - // Re-assert built-in attributes after user attributes are merged - // to prevent user-supplied attributes from overriding built-in values. + // Re-assert ALL built-in attributes after user attributes are merged + // to prevent user-supplied attributes from overriding identity-derived + // or system-derived values (e.g. a caller injecting request_user="admin" + // to bypass a policy condition). This covers request_user, request_groups, + // request_entity (identity-derived) as well as resource/action/time. + if user != nil { + newCtx.Attributes[adsapi.BuiltIn_Attr_RequestUser] = user + } + newCtx.Attributes[adsapi.BuiltIn_Attr_RequestGroups] = groups + if entity != nil { + newCtx.Attributes[adsapi.BuiltIn_Attr_RequestEntity] = entity + } newCtx.Attributes[adsapi.BuiltIn_Attr_RequestResource] = ctx.Resource newCtx.Attributes[adsapi.BuiltIn_Attr_RequestAction] = ctx.Action newCtx.Attributes[adsapi.BuiltIn_Attr_RequestTime] = now.Unix() diff --git a/pkg/eval/function/builtin_functions.go b/pkg/eval/function/builtin_functions.go index 42ac6a52..363d3aa7 100644 --- a/pkg/eval/function/builtin_functions.go +++ b/pkg/eval/function/builtin_functions.go @@ -105,8 +105,14 @@ func IsSubSet(args ...interface{}) (interface{}, error) { // element, the separatorStage places it as a bare value instead of // wrapping it in a slice. Detect and repair that case. s1, s2 = args[0], args[1] - if reflect.TypeOf(s1).Kind() != reflect.Slice && - reflect.TypeOf(s2).Kind() == reflect.Slice { + // Guard against nil arguments (e.g. a referenced attribute that + // resolved to JSON null). reflect.TypeOf(nil) returns nil, and + // calling .Kind() on a nil Type panics. + t1, t2 := reflect.TypeOf(s1), reflect.TypeOf(s2) + if t1 == nil || t2 == nil { + return false, nil + } + if t1.Kind() != reflect.Slice && t2.Kind() == reflect.Slice { s1 = []interface{}{args[0]} } } else { @@ -120,8 +126,9 @@ func IsSubSet(args ...interface{}) (interface{}, error) { } // Both s1 and s2 must be slices. - if reflect.TypeOf(s1).Kind() != reflect.Slice || - reflect.TypeOf(s2).Kind() != reflect.Slice { + t1, t2 := reflect.TypeOf(s1), reflect.TypeOf(s2) + if t1 == nil || t2 == nil || + t1.Kind() != reflect.Slice || t2.Kind() != reflect.Slice { return nil, err } diff --git a/pkg/store/etcd/discover.go b/pkg/store/etcd/discover.go index 95ef973a..d5f1e86a 100644 --- a/pkg/store/etcd/discover.go +++ b/pkg/store/etcd/discover.go @@ -31,6 +31,9 @@ func (s *Store) SaveDiscoverRequest(request *ads.RequestContext) error { } func (s *Store) PutRequest(request *ads.RequestContext) (int64, error) { + if err := validateServiceName(request.ServiceName); err != nil { + return 0, err + } // Redact token before persisting to prevent credential leakage in discover storage. if request.Subject != nil && request.Subject.Token != "" { origToken := request.Subject.Token @@ -118,6 +121,11 @@ func (s *Store) GetLastDiscoverRequest(serviceName string) (*ads.RequestContext, } func (s *Store) GetDiscoverRequestsSinceRevision(serviceName string, revision int64) ([]*ads.RequestContext, int64, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return nil, 0, err + } + } getOpts := []clientv3.OpOption{clientv3.WithMinCreateRev(revision + 1), clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByCreateRevision, clientv3.SortAscend)} keyPrefix4Search := DiscoverPrefix if len(serviceName) > 0 { @@ -171,6 +179,11 @@ func (s *Store) GetRequests(keyPrefix string, pageSize int64) ([]*ads.RequestCon } func (s *Store) GetDiscoverRequests(serviceName string) ([]*ads.RequestContext, int64, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return nil, 0, err + } + } if len(serviceName) == 0 { return s.GetRequests(DiscoverPrefix, DefaultPageSize) } else { @@ -180,6 +193,11 @@ func (s *Store) GetDiscoverRequests(serviceName string) ([]*ads.RequestContext, } func (s *Store) ResetDiscoverRequests(serviceName string) error { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return err + } + } var err error if len(serviceName) == 0 { _, err = s.client.Delete(context.Background(), DiscoverPrefix, clientv3.WithPrefix()) diff --git a/pkg/store/etcd/etcdStore.go b/pkg/store/etcd/etcdStore.go index b1827760..c9117ff0 100644 --- a/pkg/store/etcd/etcdStore.go +++ b/pkg/store/etcd/etcdStore.go @@ -492,6 +492,18 @@ func validateServiceName(serviceName string) error { return nil } +// validateID validates a policy or role-policy ID to prevent etcd key +// injection (e.g. "../../functions/victim" traversing the key hierarchy). +func validateID(id string) error { + if id == "" { + return errors.New(errors.InvalidRequest, "id cannot be empty") + } + if strings.Contains(id, KeySeparator) { + return errors.Errorf(errors.InvalidRequest, "id %q cannot contain %q", id, KeySeparator) + } + return nil +} + func (s *Store) CreateFunction(function *pms.Function) (*pms.Function, error) { if err := utils.ValidateFunc(function); err != nil { return nil, err @@ -799,6 +811,11 @@ func computeWatchBackoff(consecutiveFails int) time.Duration { // For policy manager func (s *Store) ListAllPolicies(serviceName string, filter string) ([]*pms.Policy, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return nil, err + } + } f := parseFilter(filter) policyKeyPrefix := s.KeyPrefix + ServicesKey + KeySeparator + serviceName + KeySeparator + PoliciesKey @@ -828,6 +845,11 @@ func (s *Store) ListAllPolicies(serviceName string, filter string) ([]*pms.Polic } func (s *Store) GetPolicyCount(serviceName string) (int64, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return 0, err + } + } var policyCount int64 = 0 if len(serviceName) > 0 { // Get the policy count in the specified service @@ -867,6 +889,12 @@ func (s *Store) getPolicyCountImpl(serviceName string) (int64, error) { } func (s *Store) GetPolicy(serviceName string, id string) (*pms.Policy, error) { + if err := validateServiceName(serviceName); err != nil { + return nil, err + } + if err := validateID(id); err != nil { + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() policyKey := s.KeyPrefix + ServicesKey + KeySeparator + serviceName + KeySeparator + PoliciesKey + KeySeparator + id @@ -887,6 +915,9 @@ func (s *Store) GetPolicy(serviceName string, id string) (*pms.Policy, error) { // TODO: to be implemented func (s *Store) GetRolePolicies(serviceName string, startID string, amount int) (policies []*pms.RolePolicy, nextID string, err error) { + if err := validateServiceName(serviceName); err != nil { + return nil, "", err + } if amount <= 0 { return nil, "", errors.Errorf(errors.InvalidRequest, "invalid amount %d", amount) } @@ -922,6 +953,9 @@ func (s *Store) GetRolePolicies(serviceName string, startID string, amount int) } func (s *Store) DeletePolicy(serviceName string, id string) error { + if err := validateID(id); err != nil { + return err + } if err := validateServiceName(serviceName); err != nil { return err } @@ -945,6 +979,9 @@ func (s *Store) DeletePolicy(serviceName string, id string) error { } func (s *Store) DeletePolicies(serviceName string) error { + if err := validateServiceName(serviceName); err != nil { + return err + } ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() _, err := s.client.KV.Txn(ctx).Then( @@ -1038,6 +1075,11 @@ func (s *Store) CreatePolicy(serviceName string, policy *pms.Policy) (*pms.Polic // For role policy manager func (s *Store) ListAllRolePolicies(serviceName string, filter string) ([]*pms.RolePolicy, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return nil, err + } + } f := parseFilter(filter) rolePolicyKeyPrefix := s.KeyPrefix + ServicesKey + KeySeparator + serviceName + KeySeparator + RolePoliciesKey responses, err := s.prefixGet(rolePolicyKeyPrefix) @@ -1065,6 +1107,11 @@ func (s *Store) ListAllRolePolicies(serviceName string, filter string) ([]*pms.R } func (s *Store) GetRolePolicyCount(serviceName string) (int64, error) { + if serviceName != "" { + if err := validateServiceName(serviceName); err != nil { + return 0, err + } + } var rolePolicyCount int64 = 0 if len(serviceName) > 0 { // Get the rolePolicy count in the specified service @@ -1105,6 +1152,9 @@ func (s *Store) getRolePolicyCountImpl(serviceName string) (int64, error) { // TODO: to be implemented func (s *Store) GetPolicies(serviceName string, startID string, amount int) (policies []*pms.Policy, nextID string, err error) { + if err := validateServiceName(serviceName); err != nil { + return nil, "", err + } if amount <= 0 { return nil, "", errors.Errorf(errors.InvalidRequest, "invalid input amount %d", amount) } @@ -1139,6 +1189,12 @@ func (s *Store) GetPolicies(serviceName string, startID string, amount int) (pol } func (s *Store) GetRolePolicy(serviceName string, id string) (*pms.RolePolicy, error) { + if err := validateServiceName(serviceName); err != nil { + return nil, err + } + if err := validateID(id); err != nil { + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() rolePolicyKey := s.KeyPrefix + ServicesKey + KeySeparator + serviceName + KeySeparator + RolePoliciesKey + KeySeparator + id @@ -1158,6 +1214,9 @@ func (s *Store) GetRolePolicy(serviceName string, id string) (*pms.RolePolicy, e } func (s *Store) DeleteRolePolicy(serviceName string, id string) error { + if err := validateID(id); err != nil { + return err + } if err := validateServiceName(serviceName); err != nil { return err } @@ -1181,6 +1240,9 @@ func (s *Store) DeleteRolePolicy(serviceName string, id string) error { } func (s *Store) DeleteRolePolicies(serviceName string) error { + if err := validateServiceName(serviceName); err != nil { + return err + } ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() _, err := s.client.KV.Txn(ctx).Then( diff --git a/pkg/store/mongodb/mongodbStore.go b/pkg/store/mongodb/mongodbStore.go index f4cf0fb8..0f8eb6f1 100644 --- a/pkg/store/mongodb/mongodbStore.go +++ b/pkg/store/mongodb/mongodbStore.go @@ -300,10 +300,17 @@ func (s *Store) Watch() (pms.StorageChangeChannel, error) { hasNext := changeStream.Next(nextCtx) nextCancel() if !hasNext { + // Distinguish a deadline expiry (normal idle — retry) from a + // genuine stream error or stop signal (exit). changeStream.Err() + // is nil on a clean context deadline; it is non-nil on a real + // error. Returning on idle would kill the watch after 30s of + // no policy changes, breaking incremental cache updates. if err := changeStream.Err(); err != nil { - log.Error(err) + log.Errorf("MongoDB change stream error: %v", err) + return } - return + // Deadline expired with no events — loop and wait again. + continue } // A new event variable should be declared for each event. var event bson.M diff --git a/pkg/svcs/adsrest/discover.go b/pkg/svcs/adsrest/discover.go index 326a236a..abe00207 100644 --- a/pkg/svcs/adsrest/discover.go +++ b/pkg/svcs/adsrest/discover.go @@ -34,18 +34,18 @@ func (e *RESTService) Discover(w http.ResponseWriter, r *http.Request) { } result, reason, err := e.Evaluator.Discover(*context) + if err != nil { + httputils.HandleError(w, err) + logging.WriteSimpleFailedAuditLog("Discovery", context, err.Error()) + return + } response := IsAllowedResponse{ Allowed: result, Reason: int32(reason), } // Audit log - if err != nil { - response.ErrorMessage = err.Error() - logging.WriteSimpleFailedAuditLog("Discovery", context, err.Error()) - } else { - logging.WriteSimpleSucceededAuditLog("Discovery", context, nil) - } + logging.WriteSimpleSucceededAuditLog("Discovery", context, nil) httputils.SendOKResponse(w, &response) } diff --git a/pkg/svcs/adsrest/routers.go b/pkg/svcs/adsrest/routers.go index 46552233..ce6fc75c 100644 --- a/pkg/svcs/adsrest/routers.go +++ b/pkg/svcs/adsrest/routers.go @@ -79,7 +79,7 @@ func NewRouter(evaluator eval.InternalEvaluator) (*mux.Router, error) { router := mux.NewRouter().StrictSlash(true) for _, route := range *routes { var handler http.Handler - handler = route.HandlerFunc + handler = svcs.HTTPPanicRecoveryMiddleware(route.HandlerFunc) router. Methods(route.Method). diff --git a/pkg/svcs/panic_interceptor.go b/pkg/svcs/panic_interceptor.go index 9737a97b..0b23d844 100644 --- a/pkg/svcs/panic_interceptor.go +++ b/pkg/svcs/panic_interceptor.go @@ -5,6 +5,7 @@ package svcs import ( "context" + "net/http" "runtime/debug" log "github.com/sirupsen/logrus" @@ -27,3 +28,18 @@ func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor { return handler(ctx, req) } } + +// HTTPPanicRecoveryMiddleware returns an HTTP middleware that recovers from +// panics in HTTP handlers. It logs the panic and stack trace, then returns +// HTTP 500 instead of crashing the server process. +func HTTPPanicRecoveryMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Errorf("Panic recovered in HTTP handler %s %s: %v\n%s", r.Method, r.URL.Path, err, debug.Stack()) + http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/svcs/pmsrest/routers.go b/pkg/svcs/pmsrest/routers.go index aa49cbaa..8691b0f3 100644 --- a/pkg/svcs/pmsrest/routers.go +++ b/pkg/svcs/pmsrest/routers.go @@ -246,7 +246,7 @@ func NewRouter(ps pms.PolicyStoreManager) (*mux.Router, error) { for _, route := range *routes { var handler http.Handler - handler = route.HandlerFunc + handler = svcs.HTTPPanicRecoveryMiddleware(route.HandlerFunc) router. Methods(route.Method). Path(route.Pattern).