From b306faeaf8a185ffa50290ec902ac87848df6191 Mon Sep 17 00:00:00 2001 From: "Jakub A. W" Date: Wed, 1 Jul 2026 00:49:37 +0200 Subject: [PATCH 1/2] refactor: remove dead guardrails wrapper, cut redundancies, fix registry trim Quality pass focused on removing dead/duplicated code and one routing bug, while keeping behavior and test coverage intact. fix(registry): provider types were stored verbatim at registration while every reader except ProviderByType trimmed on lookup, so a configured type with surrounding whitespace (YAML/env) made type-based routing silently return nil. Normalize the type at registration; add a regression test. refactor(guardrails): delete the dead GuardedProvider wrapper (provider.go, 518 lines) and the unused RequestPatcher/BatchPreparer in executor.go. The live request path applies guardrails through WorkflowRequestPatcher / WorkflowBatchPreparer (wired in app.go); the wrapper was only reachable from its own tests. Live clone helpers moved to clone.go. Coverage of the shared rewrite functions (processGuardedChat/Responses/BatchRequest) is preserved via a trimmed test-only harness; the 6 pure-delegation tests were dropped and the server tests retargeted onto the live Workflow* patchers. refactor: assorted redundancy cleanups - providers/router.go: collapse 3 near-identical Native*ProviderTypes loops into one providerTypesSupporting helper. - core/errors.go: add NewEmptyProviderResponseError, replacing 8 duplicated constructions in native_response_service.go and inference_execute.go. - replace 4 hand-rolled map[string]any clone helpers with stdlib maps.Clone. - anthropic: drop an unreachable len(tokens)==0 branch. Net -478 lines (293 insertions, 771 deletions). go build/vet/test and gofmt all clean. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/core/errors.go | 5 + internal/gateway/inference_execute.go | 2 +- internal/guardrails/clone.go | 88 +++ internal/guardrails/executor.go | 52 -- internal/guardrails/provider.go | 518 ------------------ internal/guardrails/provider_harness_test.go | 125 +++++ internal/guardrails/provider_test.go | 126 +---- .../guardrails/responses_message_apply.go | 10 +- .../anthropic/request_translation.go | 4 - internal/providers/registry.go | 1 + internal/providers/registry_test.go | 22 + internal/providers/responses_adapter.go | 7 +- internal/providers/router.go | 60 +- internal/responsecache/stream_cache.go | 10 +- internal/server/handlers_test.go | 13 +- internal/server/native_response_service.go | 14 +- internal/usage/extractor.go | 7 +- 17 files changed, 293 insertions(+), 771 deletions(-) create mode 100644 internal/guardrails/clone.go delete mode 100644 internal/guardrails/provider.go create mode 100644 internal/guardrails/provider_harness_test.go diff --git a/internal/core/errors.go b/internal/core/errors.go index b1097796..bceccab9 100644 --- a/internal/core/errors.go +++ b/internal/core/errors.go @@ -131,6 +131,11 @@ func NewProviderError(provider string, statusCode int, message string, err error } } +// NewEmptyProviderResponseError reports that a provider returned no response body (502). 
+func NewEmptyProviderResponseError(provider string) *GatewayError { + return NewProviderError(provider, http.StatusBadGateway, "provider returned empty response", nil) +} + // NewRateLimitError creates a new rate limit error (429) func NewRateLimitError(provider string, message string) *GatewayError { return &GatewayError{ diff --git a/internal/gateway/inference_execute.go b/internal/gateway/inference_execute.go index f1d6307d..2deae250 100644 --- a/internal/gateway/inference_execute.go +++ b/internal/gateway/inference_execute.go @@ -461,7 +461,7 @@ func (o *InferenceOrchestrator) streamResponsesProviderCall(ctx context.Context, } func emptyProviderResponseError(providerType string) *core.GatewayError { - return core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil) + return core.NewEmptyProviderResponseError(providerType) } func emptyProviderStreamError(providerType string) *core.GatewayError { diff --git a/internal/guardrails/clone.go b/internal/guardrails/clone.go new file mode 100644 index 00000000..7f1afe5e --- /dev/null +++ b/internal/guardrails/clone.go @@ -0,0 +1,88 @@ +package guardrails + +import "gomodel/internal/core" + +// cloneToolCalls deep-copies tool calls so guardrail rewrites never mutate the +// caller's original message slice. 
+func cloneToolCalls(toolCalls []core.ToolCall) []core.ToolCall { + if len(toolCalls) == 0 { + return nil + } + cloned := make([]core.ToolCall, len(toolCalls)) + for i, toolCall := range toolCalls { + cloned[i] = core.ToolCall{ + ID: toolCall.ID, + Type: toolCall.Type, + Function: core.FunctionCall{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + ExtraFields: core.CloneUnknownJSONFields(toolCall.Function.ExtraFields), + }, + ExtraFields: core.CloneUnknownJSONFields(toolCall.ExtraFields), + } + } + return cloned +} + +func cloneChatMessageEnvelope(message core.Message) core.Message { + return core.Message{ + Role: message.Role, + ToolCallID: message.ToolCallID, + ContentNull: message.ContentNull, + Content: cloneMessageContent(message.Content), + ToolCalls: cloneToolCalls(message.ToolCalls), + ExtraFields: core.CloneUnknownJSONFields(message.ExtraFields), + } +} + +func cloneMessageContent(content any) any { + switch value := content.(type) { + case nil: + return nil + case string: + return value + case []core.ContentPart: + return cloneContentParts(value) + default: + parts, ok := core.NormalizeContentParts(content) + if !ok { + return value + } + return cloneContentParts(parts) + } +} + +func cloneContentParts(parts []core.ContentPart) []core.ContentPart { + if len(parts) == 0 { + return nil + } + cloned := make([]core.ContentPart, len(parts)) + for i, part := range parts { + cloned[i] = cloneContentPart(part) + } + return cloned +} + +func cloneContentPart(part core.ContentPart) core.ContentPart { + cloned := core.ContentPart{ + Type: part.Type, + Text: part.Text, + ExtraFields: core.CloneUnknownJSONFields(part.ExtraFields), + } + if part.ImageURL != nil { + cloned.ImageURL = &core.ImageURLContent{ + URL: part.ImageURL.URL, + Detail: part.ImageURL.Detail, + MediaType: part.ImageURL.MediaType, + ExtraFields: core.CloneUnknownJSONFields(part.ImageURL.ExtraFields), + } + } + if part.InputAudio != nil { + cloned.InputAudio = 
&core.InputAudioContent{ + Data: part.InputAudio.Data, + Format: part.InputAudio.Format, + ExtraFields: core.CloneUnknownJSONFields(part.InputAudio.ExtraFields), + } + } + return cloned +} diff --git a/internal/guardrails/executor.go b/internal/guardrails/executor.go index f956ad28..c2f2b3ac 100644 --- a/internal/guardrails/executor.go +++ b/internal/guardrails/executor.go @@ -8,58 +8,6 @@ import ( "gomodel/internal/core" ) -// RequestPatcher applies guardrails to translated requests without owning -// provider execution. -type RequestPatcher struct { - pipeline *Pipeline -} - -// NewRequestPatcher creates an explicit translated-request patcher. -func NewRequestPatcher(pipeline *Pipeline) *RequestPatcher { - return &RequestPatcher{pipeline: pipeline} -} - -// PatchChatRequest applies guardrails to a translated chat request. -func (p *RequestPatcher) PatchChatRequest(ctx context.Context, req *core.ChatRequest) (*core.ChatRequest, error) { - return processGuardedChat(ctx, p.pipeline, req) -} - -// PatchResponsesRequest applies guardrails to a translated responses request. -func (p *RequestPatcher) PatchResponsesRequest(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesRequest, error) { - return processGuardedResponses(ctx, p.pipeline, req) -} - -// BatchPreparer applies guardrails to native batch subrequests before provider -// submission. -type BatchPreparer struct { - provider core.RoutableProvider - pipeline *Pipeline -} - -// NewBatchPreparer creates an explicit native-batch preparer. -func NewBatchPreparer(provider core.RoutableProvider, pipeline *Pipeline) *BatchPreparer { - return &BatchPreparer{ - provider: provider, - pipeline: pipeline, - } -} - -// PrepareBatchRequest applies guardrails to batch subrequests without -// submitting the batch to the wrapped provider. 
-func (p *BatchPreparer) PrepareBatchRequest(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchRewriteResult, error) { - return processGuardedBatchRequest(ctx, providerType, req, p.pipeline, p.batchFileTransport()) -} - -func (p *BatchPreparer) batchFileTransport() core.BatchFileTransport { - if p == nil || p.provider == nil { - return nil - } - if files, ok := p.provider.(core.NativeFileRoutableProvider); ok { - return files - } - return nil -} - func processGuardedBatchRequest( ctx context.Context, providerType string, diff --git a/internal/guardrails/provider.go b/internal/guardrails/provider.go deleted file mode 100644 index 41bd84c5..00000000 --- a/internal/guardrails/provider.go +++ /dev/null @@ -1,518 +0,0 @@ -package guardrails - -import ( - "context" - "io" - - "gomodel/internal/batchrewrite" - "gomodel/internal/core" -) - -// GuardedProvider wraps a RoutableProvider and applies the guardrails pipeline -// before routing requests to providers. It implements core.RoutableProvider. -// -// Adapters convert between concrete request types and the normalized []Message -// DTO that guardrails operate on. This decouples guardrails from API-specific types. -type GuardedProvider struct { - inner core.RoutableProvider - pipeline *Pipeline - options Options -} - -// Options controls optional behavior of GuardedProvider. -type Options struct { - EnableForBatchProcessing bool - // DisableTranslatedRequestProcessing lets an explicit server-side executor own - // translated-route patching while this wrapper still handles batch rewriting. - DisableTranslatedRequestProcessing bool -} - -// NewGuardedProvider creates a RoutableProvider that applies guardrails -// before delegating to the inner provider. 
-func NewGuardedProvider(inner core.RoutableProvider, pipeline *Pipeline) *GuardedProvider { - return NewGuardedProviderWithOptions(inner, pipeline, Options{}) -} - -// NewGuardedProviderWithOptions creates a RoutableProvider with explicit options. -func NewGuardedProviderWithOptions(inner core.RoutableProvider, pipeline *Pipeline, options Options) *GuardedProvider { - return &GuardedProvider{ - inner: inner, - pipeline: pipeline, - options: options, - } -} - -// Supports delegates to the inner provider. -func (g *GuardedProvider) Supports(model string) bool { - return g.inner.Supports(model) -} - -// GetProviderType delegates to the inner provider. -func (g *GuardedProvider) GetProviderType(model string) string { - return g.inner.GetProviderType(model) -} - -// ModelCount delegates to the inner provider when it exposes registry size. -// It returns -1 when the wrapped provider does not expose model count. -func (g *GuardedProvider) ModelCount() int { - if counted, ok := g.inner.(interface{ ModelCount() int }); ok { - return counted.ModelCount() - } - return -1 -} - -// NativeFileProviderTypes delegates provider capability inventory to the inner -// provider when available. -func (g *GuardedProvider) NativeFileProviderTypes() []string { - if typed, ok := g.inner.(core.NativeFileProviderTypeLister); ok { - return typed.NativeFileProviderTypes() - } - return nil -} - -// NativeBatchProviderTypes delegates provider capability inventory to the inner -// provider when available. -func (g *GuardedProvider) NativeBatchProviderTypes() []string { - if typed, ok := g.inner.(core.NativeBatchProviderTypeLister); ok { - return typed.NativeBatchProviderTypes() - } - return nil -} - -// NativeResponseProviderTypes delegates provider capability inventory to the -// inner provider when available. 
-func (g *GuardedProvider) NativeResponseProviderTypes() []string { - if typed, ok := g.inner.(core.NativeResponseProviderTypeLister); ok { - return typed.NativeResponseProviderTypes() - } - return nil -} - -// ChatCompletion extracts messages, applies guardrails, then routes the request. -func (g *GuardedProvider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { - if g.options.DisableTranslatedRequestProcessing { - return g.inner.ChatCompletion(ctx, req) - } - modified, err := processGuardedChat(ctx, g.pipeline, req) - if err != nil { - return nil, err - } - return g.inner.ChatCompletion(ctx, modified) -} - -// StreamChatCompletion extracts messages, applies guardrails, then routes the streaming request. -func (g *GuardedProvider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) { - if g.options.DisableTranslatedRequestProcessing { - return g.inner.StreamChatCompletion(ctx, req) - } - modified, err := processGuardedChat(ctx, g.pipeline, req) - if err != nil { - return nil, err - } - return g.inner.StreamChatCompletion(ctx, modified) -} - -// ListModels delegates directly to the inner provider (no guardrails needed). -func (g *GuardedProvider) ListModels(ctx context.Context) (*core.ModelsResponse, error) { - return g.inner.ListModels(ctx) -} - -// Embeddings delegates directly to the inner provider (no guardrails needed for embeddings). -func (g *GuardedProvider) Embeddings(ctx context.Context, req *core.EmbeddingRequest) (*core.EmbeddingResponse, error) { - return g.inner.Embeddings(ctx, req) -} - -// Responses extracts messages, applies guardrails, then routes the request. 
-func (g *GuardedProvider) Responses(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesResponse, error) { - if g.options.DisableTranslatedRequestProcessing { - return g.inner.Responses(ctx, req) - } - modified, err := processGuardedResponses(ctx, g.pipeline, req) - if err != nil { - return nil, err - } - return g.inner.Responses(ctx, modified) -} - -// StreamResponses extracts messages, applies guardrails, then routes the streaming request. -func (g *GuardedProvider) StreamResponses(ctx context.Context, req *core.ResponsesRequest) (io.ReadCloser, error) { - if g.options.DisableTranslatedRequestProcessing { - return g.inner.StreamResponses(ctx, req) - } - modified, err := processGuardedResponses(ctx, g.pipeline, req) - if err != nil { - return nil, err - } - return g.inner.StreamResponses(ctx, modified) -} - -func (g *GuardedProvider) nativeBatchRouter() (core.NativeBatchRoutableProvider, error) { - bp, ok := g.inner.(core.NativeBatchRoutableProvider) - if !ok { - return nil, core.NewInvalidRequestError("batch routing is not supported by the current provider router", nil) - } - return bp, nil -} - -func (g *GuardedProvider) nativeBatchHintRouter() (core.NativeBatchHintRoutableProvider, error) { - hinted, ok := g.inner.(core.NativeBatchHintRoutableProvider) - if !ok { - return nil, core.NewInvalidRequestError("batch hint routing is not supported by the current provider router", nil) - } - return hinted, nil -} - -func (g *GuardedProvider) nativeFileRouter() (core.NativeFileRoutableProvider, error) { - fp, ok := g.inner.(core.NativeFileRoutableProvider) - if !ok { - return nil, core.NewInvalidRequestError("file routing is not supported by the current provider router", nil) - } - return fp, nil -} - -func (g *GuardedProvider) nativeResponseLifecycleRouter() (core.NativeResponseLifecycleRoutableProvider, error) { - responses, ok := g.inner.(core.NativeResponseLifecycleRoutableProvider) - if !ok { - return nil, core.NewInvalidRequestError("response 
lifecycle routing is not supported by the current provider router", nil) - } - return responses, nil -} - -func (g *GuardedProvider) nativeResponseUtilityRouter() (core.NativeResponseUtilityRoutableProvider, error) { - responses, ok := g.inner.(core.NativeResponseUtilityRoutableProvider) - if !ok { - return nil, core.NewInvalidRequestError("response utility routing is not supported by the current provider router", nil) - } - return responses, nil -} - -func (g *GuardedProvider) batchFileTransport() core.BatchFileTransport { - files, err := g.nativeFileRouter() - if err != nil { - return nil - } - return files -} - -func (g *GuardedProvider) passthroughRouter() (core.RoutablePassthrough, error) { - pp, ok := g.inner.(core.RoutablePassthrough) - if !ok { - return nil, core.NewInvalidRequestError("passthrough routing is not supported by the current provider router", nil) - } - return pp, nil -} - -// CreateBatch delegates native batch creation and optionally applies guardrails to inline items. 
-func (g *GuardedProvider) CreateBatch(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchResponse, error) { - bp, err := g.nativeBatchRouter() - if err != nil { - return nil, err - } - if !g.options.EnableForBatchProcessing { - return bp.CreateBatch(ctx, providerType, req) - } - - result, err := processGuardedBatchRequest(ctx, providerType, req, g.pipeline, g.batchFileTransport()) - if err != nil { - return nil, err - } - batchrewrite.RecordPreparation(ctx, req, result.Request) - resp, err := bp.CreateBatch(ctx, providerType, result.Request) - if err != nil { - batchrewrite.CleanupFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") - return nil, err - } - batchrewrite.CleanupSupersededFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") - return resp, nil -} - -// CreateBatchWithHints delegates hint-aware native batch creation while preserving -// guardrail batch processing when enabled. -func (g *GuardedProvider) CreateBatchWithHints(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchResponse, map[string]string, error) { - hinted, err := g.nativeBatchHintRouter() - if err != nil { - return nil, nil, err - } - if !g.options.EnableForBatchProcessing { - return hinted.CreateBatchWithHints(ctx, providerType, req) - } - - result, err := processGuardedBatchRequest(ctx, providerType, req, g.pipeline, g.batchFileTransport()) - if err != nil { - return nil, nil, err - } - batchrewrite.RecordPreparation(ctx, req, result.Request) - resp, hints, err := hinted.CreateBatchWithHints(ctx, providerType, result.Request) - if err != nil { - batchrewrite.CleanupFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") - return nil, nil, err - } - batchrewrite.CleanupSupersededFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") - return resp, 
batchrewrite.MergeEndpointHints(result.RequestEndpointHints, hints), nil -} - -// GetBatch delegates native batch retrieval. -func (g *GuardedProvider) GetBatch(ctx context.Context, providerType, id string) (*core.BatchResponse, error) { - bp, err := g.nativeBatchRouter() - if err != nil { - return nil, err - } - return bp.GetBatch(ctx, providerType, id) -} - -// ListBatches delegates native batch listing. -func (g *GuardedProvider) ListBatches(ctx context.Context, providerType string, limit int, after string) (*core.BatchListResponse, error) { - bp, err := g.nativeBatchRouter() - if err != nil { - return nil, err - } - return bp.ListBatches(ctx, providerType, limit, after) -} - -// CancelBatch delegates native batch cancellation. -func (g *GuardedProvider) CancelBatch(ctx context.Context, providerType, id string) (*core.BatchResponse, error) { - bp, err := g.nativeBatchRouter() - if err != nil { - return nil, err - } - return bp.CancelBatch(ctx, providerType, id) -} - -// GetBatchResults delegates native batch results retrieval. -func (g *GuardedProvider) GetBatchResults(ctx context.Context, providerType, id string) (*core.BatchResultsResponse, error) { - bp, err := g.nativeBatchRouter() - if err != nil { - return nil, err - } - return bp.GetBatchResults(ctx, providerType, id) -} - -// GetBatchResultsWithHints delegates hint-aware native batch results retrieval. -func (g *GuardedProvider) GetBatchResultsWithHints(ctx context.Context, providerType, id string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) { - hinted, err := g.nativeBatchHintRouter() - if err != nil { - return nil, err - } - return hinted.GetBatchResultsWithHints(ctx, providerType, id, endpointByCustomID) -} - -// ClearBatchResultHints delegates cleanup of transient provider-side result hints. 
-func (g *GuardedProvider) ClearBatchResultHints(providerType, batchID string) { - hinted, err := g.nativeBatchHintRouter() - if err != nil { - return - } - hinted.ClearBatchResultHints(providerType, batchID) -} - -// CreateFile delegates native file upload. -func (g *GuardedProvider) CreateFile(ctx context.Context, providerType string, req *core.FileCreateRequest) (*core.FileObject, error) { - fp, err := g.nativeFileRouter() - if err != nil { - return nil, err - } - return fp.CreateFile(ctx, providerType, req) -} - -// ListFiles delegates native file listing. -func (g *GuardedProvider) ListFiles(ctx context.Context, providerType, purpose string, limit int, after string) (*core.FileListResponse, error) { - fp, err := g.nativeFileRouter() - if err != nil { - return nil, err - } - return fp.ListFiles(ctx, providerType, purpose, limit, after) -} - -// GetFile delegates native file lookup. -func (g *GuardedProvider) GetFile(ctx context.Context, providerType, id string) (*core.FileObject, error) { - fp, err := g.nativeFileRouter() - if err != nil { - return nil, err - } - return fp.GetFile(ctx, providerType, id) -} - -// DeleteFile delegates native file deletion. -func (g *GuardedProvider) DeleteFile(ctx context.Context, providerType, id string) (*core.FileDeleteResponse, error) { - fp, err := g.nativeFileRouter() - if err != nil { - return nil, err - } - return fp.DeleteFile(ctx, providerType, id) -} - -// GetFileContent delegates native file content retrieval. -func (g *GuardedProvider) GetFileContent(ctx context.Context, providerType, id string) (*core.FileContentResponse, error) { - fp, err := g.nativeFileRouter() - if err != nil { - return nil, err - } - return fp.GetFileContent(ctx, providerType, id) -} - -// Passthrough delegates opaque provider-native requests without semantic guardrail processing. 
-func (g *GuardedProvider) Passthrough(ctx context.Context, providerType string, req *core.PassthroughRequest) (*core.PassthroughResponse, error) { - pp, err := g.passthroughRouter() - if err != nil { - return nil, err - } - return pp.Passthrough(ctx, providerType, req) -} - -// GetResponse delegates native response lookup. -func (g *GuardedProvider) GetResponse(ctx context.Context, providerType, id string, params core.ResponseRetrieveParams) (*core.ResponsesResponse, error) { - responses, err := g.nativeResponseLifecycleRouter() - if err != nil { - return nil, err - } - return responses.GetResponse(ctx, providerType, id, params) -} - -// ListResponseInputItems delegates native response input item listing. -func (g *GuardedProvider) ListResponseInputItems(ctx context.Context, providerType, id string, params core.ResponseInputItemsParams) (*core.ResponseInputItemListResponse, error) { - responses, err := g.nativeResponseLifecycleRouter() - if err != nil { - return nil, err - } - return responses.ListResponseInputItems(ctx, providerType, id, params) -} - -// CancelResponse delegates native response cancellation. -func (g *GuardedProvider) CancelResponse(ctx context.Context, providerType, id string) (*core.ResponsesResponse, error) { - responses, err := g.nativeResponseLifecycleRouter() - if err != nil { - return nil, err - } - return responses.CancelResponse(ctx, providerType, id) -} - -// DeleteResponse delegates native response deletion. -func (g *GuardedProvider) DeleteResponse(ctx context.Context, providerType, id string) (*core.ResponseDeleteResponse, error) { - responses, err := g.nativeResponseLifecycleRouter() - if err != nil { - return nil, err - } - return responses.DeleteResponse(ctx, providerType, id) -} - -// CountResponseInputTokens delegates native response token counting. 
-func (g *GuardedProvider) CountResponseInputTokens(ctx context.Context, providerType string, req *core.ResponsesRequest) (*core.ResponseInputTokensResponse, error) { - responses, err := g.nativeResponseUtilityRouter() - if err != nil { - return nil, err - } - return responses.CountResponseInputTokens(ctx, providerType, req) -} - -// CompactResponse delegates native response compaction. -func (g *GuardedProvider) CompactResponse(ctx context.Context, providerType string, req *core.ResponsesRequest) (*core.ResponseCompactResponse, error) { - responses, err := g.nativeResponseUtilityRouter() - if err != nil { - return nil, err - } - return responses.CompactResponse(ctx, providerType, req) -} - -// PatchChatRequest applies guardrails to a translated chat request without -// delegating to the wrapped provider. -func (g *GuardedProvider) PatchChatRequest(ctx context.Context, req *core.ChatRequest) (*core.ChatRequest, error) { - return processGuardedChat(ctx, g.pipeline, req) -} - -// PatchResponsesRequest applies guardrails to a translated responses request -// without delegating to the wrapped provider. -func (g *GuardedProvider) PatchResponsesRequest(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesRequest, error) { - return processGuardedResponses(ctx, g.pipeline, req) -} - -// PrepareBatchRequest applies guardrails to batch subrequests without -// submitting the native batch to the wrapped provider. 
-func (g *GuardedProvider) PrepareBatchRequest(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchRewriteResult, error) { - if !g.options.EnableForBatchProcessing { - return &core.BatchRewriteResult{Request: req}, nil - } - return processGuardedBatchRequest(ctx, providerType, req, g.pipeline, g.batchFileTransport()) -} - -func cloneToolCalls(toolCalls []core.ToolCall) []core.ToolCall { - if len(toolCalls) == 0 { - return nil - } - cloned := make([]core.ToolCall, len(toolCalls)) - for i, toolCall := range toolCalls { - cloned[i] = core.ToolCall{ - ID: toolCall.ID, - Type: toolCall.Type, - Function: core.FunctionCall{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, - ExtraFields: core.CloneUnknownJSONFields(toolCall.Function.ExtraFields), - }, - ExtraFields: core.CloneUnknownJSONFields(toolCall.ExtraFields), - } - } - return cloned -} - -func cloneChatMessageEnvelope(message core.Message) core.Message { - return core.Message{ - Role: message.Role, - ToolCallID: message.ToolCallID, - ContentNull: message.ContentNull, - Content: cloneMessageContent(message.Content), - ToolCalls: cloneToolCalls(message.ToolCalls), - ExtraFields: core.CloneUnknownJSONFields(message.ExtraFields), - } -} - -func cloneMessageContent(content any) any { - switch value := content.(type) { - case nil: - return nil - case string: - return value - case []core.ContentPart: - return cloneContentParts(value) - default: - parts, ok := core.NormalizeContentParts(content) - if !ok { - return value - } - return cloneContentParts(parts) - } -} - -func cloneContentParts(parts []core.ContentPart) []core.ContentPart { - if len(parts) == 0 { - return nil - } - cloned := make([]core.ContentPart, len(parts)) - for i, part := range parts { - cloned[i] = cloneContentPart(part) - } - return cloned -} - -func cloneContentPart(part core.ContentPart) core.ContentPart { - cloned := core.ContentPart{ - Type: part.Type, - Text: part.Text, - ExtraFields: 
core.CloneUnknownJSONFields(part.ExtraFields), - } - if part.ImageURL != nil { - cloned.ImageURL = &core.ImageURLContent{ - URL: part.ImageURL.URL, - Detail: part.ImageURL.Detail, - MediaType: part.ImageURL.MediaType, - ExtraFields: core.CloneUnknownJSONFields(part.ImageURL.ExtraFields), - } - } - if part.InputAudio != nil { - cloned.InputAudio = &core.InputAudioContent{ - Data: part.InputAudio.Data, - Format: part.InputAudio.Format, - ExtraFields: core.CloneUnknownJSONFields(part.InputAudio.ExtraFields), - } - } - return cloned -} diff --git a/internal/guardrails/provider_harness_test.go b/internal/guardrails/provider_harness_test.go new file mode 100644 index 00000000..0560c05b --- /dev/null +++ b/internal/guardrails/provider_harness_test.go @@ -0,0 +1,125 @@ +package guardrails + +import ( + "context" + "io" + + "gomodel/internal/batchrewrite" + "gomodel/internal/core" +) + +// GuardedProvider is a test harness that exercises the live guardrail +// request-rewrite functions (processGuardedChat / processGuardedResponses / +// processGuardedBatchRequest) end to end, the same way the production +// server-side patchers (WorkflowRequestPatcher / WorkflowBatchPreparer) do. +// +// Production wires the pipeline through those Workflow* patchers; this wrapper +// exists only so the shared rewrite logic can be tested against a real inner +// provider in one place. +type GuardedProvider struct { + inner core.RoutableProvider + pipeline *Pipeline + options Options +} + +// Options mirrors the batch-processing toggle exercised by the tests. 
+type Options struct { + EnableForBatchProcessing bool +} + +func NewGuardedProvider(inner core.RoutableProvider, pipeline *Pipeline) *GuardedProvider { + return NewGuardedProviderWithOptions(inner, pipeline, Options{}) +} + +func NewGuardedProviderWithOptions(inner core.RoutableProvider, pipeline *Pipeline, options Options) *GuardedProvider { + return &GuardedProvider{inner: inner, pipeline: pipeline, options: options} +} + +func (g *GuardedProvider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { + modified, err := processGuardedChat(ctx, g.pipeline, req) + if err != nil { + return nil, err + } + return g.inner.ChatCompletion(ctx, modified) +} + +func (g *GuardedProvider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) { + modified, err := processGuardedChat(ctx, g.pipeline, req) + if err != nil { + return nil, err + } + return g.inner.StreamChatCompletion(ctx, modified) +} + +func (g *GuardedProvider) Responses(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesResponse, error) { + modified, err := processGuardedResponses(ctx, g.pipeline, req) + if err != nil { + return nil, err + } + return g.inner.Responses(ctx, modified) +} + +func (g *GuardedProvider) StreamResponses(ctx context.Context, req *core.ResponsesRequest) (io.ReadCloser, error) { + modified, err := processGuardedResponses(ctx, g.pipeline, req) + if err != nil { + return nil, err + } + return g.inner.StreamResponses(ctx, modified) +} + +func (g *GuardedProvider) nativeBatchRouter() (core.NativeBatchRoutableProvider, error) { + bp, ok := g.inner.(core.NativeBatchRoutableProvider) + if !ok { + return nil, core.NewInvalidRequestError("batch routing is not supported by the current provider router", nil) + } + return bp, nil +} + +func (g *GuardedProvider) nativeFileRouter() (core.NativeFileRoutableProvider, error) { + fp, ok := g.inner.(core.NativeFileRoutableProvider) + if !ok { + return nil, 
core.NewInvalidRequestError("file routing is not supported by the current provider router", nil) + } + return fp, nil +} + +func (g *GuardedProvider) batchFileTransport() core.BatchFileTransport { + files, err := g.nativeFileRouter() + if err != nil { + return nil + } + return files +} + +// CreateBatch applies guardrails to inline batch items (when enabled) before +// delegating native batch creation, mirroring the production submit/cleanup +// orchestration so that path stays covered. +func (g *GuardedProvider) CreateBatch(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchResponse, error) { + bp, err := g.nativeBatchRouter() + if err != nil { + return nil, err + } + if !g.options.EnableForBatchProcessing { + return bp.CreateBatch(ctx, providerType, req) + } + + result, err := processGuardedBatchRequest(ctx, providerType, req, g.pipeline, g.batchFileTransport()) + if err != nil { + return nil, err + } + batchrewrite.RecordPreparation(ctx, req, result.Request) + resp, err := bp.CreateBatch(ctx, providerType, result.Request) + if err != nil { + batchrewrite.CleanupFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") + return nil, err + } + batchrewrite.CleanupSupersededFileFromRouter(ctx, g.nativeFileRouter, providerType, result.RewrittenInputFileID, "") + return resp, nil +} + +func (g *GuardedProvider) PrepareBatchRequest(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchRewriteResult, error) { + if !g.options.EnableForBatchProcessing { + return &core.BatchRewriteResult{Request: req}, nil + } + return processGuardedBatchRequest(ctx, providerType, req, g.pipeline, g.batchFileTransport()) +} diff --git a/internal/guardrails/provider_test.go b/internal/guardrails/provider_test.go index 9dcfdec2..39fcc3b6 100644 --- a/internal/guardrails/provider_test.go +++ b/internal/guardrails/provider_test.go @@ -1198,9 +1198,9 @@ func 
TestGuardedProvider_PrepareBatchRequest_DefaultNoBatchGuardrails(t *testing } } -func TestBatchPreparer_PrepareBatchRequest_NoPipelineReturnsOriginalRequest(t *testing.T) { +func TestGuardedProvider_PrepareBatchRequest_NoPipelineReturnsOriginalRequest(t *testing.T) { inner := &mockRoutableProvider{} - preparer := NewBatchPreparer(inner, nil) + guarded := NewGuardedProvider(inner, nil) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1213,7 +1213,7 @@ func TestBatchPreparer_PrepareBatchRequest_NoPipelineReturnsOriginalRequest(t *t }, } - result, err := preparer.PrepareBatchRequest(context.Background(), "mock", req) + result, err := guarded.PrepareBatchRequest(context.Background(), "mock", req) if err != nil { t.Fatal(err) } @@ -1900,126 +1900,6 @@ func TestGuardedProvider_DoesNotMutateOriginalRequest(t *testing.T) { } } -// --- Embeddings delegation tests --- - -func TestGuardedProvider_Embeddings_DelegatesDirectly(t *testing.T) { - inner := &mockRoutableProvider{} - pipeline := NewPipeline() - - g, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "should not affect embeddings") - pipeline.Add(g, 0) - - guarded := NewGuardedProvider(inner, pipeline) - - req := &core.EmbeddingRequest{Model: "text-embedding-3-small", Input: "hello"} - resp, err := guarded.Embeddings(context.Background(), req) - if err != nil { - t.Fatal(err) - } - if resp.Object != "list" { - t.Errorf("expected object 'list', got %q", resp.Object) - } - if resp.Provider != "mock" { - t.Errorf("expected provider 'mock', got %q", resp.Provider) - } -} - -// --- Delegation tests --- - -func TestGuardedProvider_ListModels_NoGuardrails(t *testing.T) { - inner := &mockRoutableProvider{} - pipeline := NewPipeline() - guarded := NewGuardedProvider(inner, pipeline) - - resp, err := guarded.ListModels(context.Background()) - if err != nil { - t.Fatal(err) - } - if resp.Object != "list" { - t.Errorf("expected 'list', got %q", resp.Object) - } -} - -func 
TestGuardedProvider_DelegatesSupports(t *testing.T) { - inner := &mockRoutableProvider{ - supportsFn: func(model string) bool { - return model == "gpt-4" - }, - } - pipeline := NewPipeline() - guarded := NewGuardedProvider(inner, pipeline) - - if !guarded.Supports("gpt-4") { - t.Error("expected Supports to return true for gpt-4") - } - if guarded.Supports("unknown") { - t.Error("expected Supports to return false for unknown") - } -} - -func TestGuardedProvider_DelegatesGetProviderType(t *testing.T) { - inner := &mockRoutableProvider{ - getProviderTypeFn: func(_ string) string { - return "openai" - }, - } - pipeline := NewPipeline() - guarded := NewGuardedProvider(inner, pipeline) - - if guarded.GetProviderType("gpt-4") != "openai" { - t.Errorf("expected 'openai', got %q", guarded.GetProviderType("gpt-4")) - } -} - -func TestGuardedProvider_ModelCount_UnknownWhenInnerDoesNotExposeCount(t *testing.T) { - inner := &mockRoutableProvider{} - pipeline := NewPipeline() - guarded := NewGuardedProvider(inner, pipeline) - - if got := guarded.ModelCount(); got != -1 { - t.Fatalf("ModelCount() = %d, want -1 for unknown count", got) - } -} - -func TestGuardedProvider_Passthrough_Delegates(t *testing.T) { - inner := &mockRoutableProvider{} - pipeline := NewPipeline() - guarded := NewGuardedProvider(inner, pipeline) - - resp, err := guarded.Passthrough(context.Background(), "openai", &core.PassthroughRequest{ - Method: http.MethodPost, - Endpoint: "responses", - Body: io.NopCloser(strings.NewReader(`{"foo":"bar"}`)), - Headers: http.Header{ - "Content-Type": {"application/json"}, - }, - }) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if inner.passthroughType != "openai" { - t.Fatalf("providerType = %q, want openai", inner.passthroughType) - } - if inner.passthroughReq == nil { - t.Fatal("passthroughReq = nil") - } - if inner.passthroughReq.Endpoint != "responses" { - t.Fatalf("Endpoint = %q, want responses", inner.passthroughReq.Endpoint) - } - body, readErr 
:= io.ReadAll(inner.passthroughReq.Body) - if readErr != nil { - t.Fatalf("failed to read passthrough body: %v", readErr) - } - if got := string(body); got != `{"foo":"bar"}` { - t.Fatalf("Body = %q, want request body", got) - } - if resp.StatusCode != http.StatusAccepted { - t.Fatalf("StatusCode = %d, want %d", resp.StatusCode, http.StatusAccepted) - } -} - func TestGuardedProvider_GuardrailError_BlocksRequest(t *testing.T) { inner := &mockRoutableProvider{} pipeline := NewPipeline() diff --git a/internal/guardrails/responses_message_apply.go b/internal/guardrails/responses_message_apply.go index 5035917d..73d4bd43 100644 --- a/internal/guardrails/responses_message_apply.go +++ b/internal/guardrails/responses_message_apply.go @@ -1,6 +1,7 @@ package guardrails import ( + "maps" "reflect" "strings" @@ -631,12 +632,5 @@ func cloneResponsesInterfacePart(part any) any { // intentionally shared; callers are expected to either preserve them as-is or // replace whole top-level values instead of mutating nested structures in place. 
func cloneStringAnyMap(src map[string]any) map[string]any { - if src == nil { - return nil - } - cloned := make(map[string]any, len(src)) - for key, value := range src { - cloned[key] = value - } - return cloned + return maps.Clone(src) } diff --git a/internal/providers/anthropic/request_translation.go b/internal/providers/anthropic/request_translation.go index 4c20ae29..6a48753d 100644 --- a/internal/providers/anthropic/request_translation.go +++ b/internal/providers/anthropic/request_translation.go @@ -706,10 +706,6 @@ func anthropicImageSource(raw, mediaTypeHint string) (*anthropicContentSource, e meta := raw[len("data:"):comma] tokens := strings.Split(meta, ";") - if len(tokens) == 0 { - return nil, core.NewInvalidRequestError("anthropic image data URL is missing a media type", nil) - } - mediaType := strings.TrimSpace(tokens[0]) if mediaType == "" { mediaType = strings.TrimSpace(mediaTypeHint) diff --git a/internal/providers/registry.go b/internal/providers/registry.go index abb60c38..8f3db391 100644 --- a/internal/providers/registry.go +++ b/internal/providers/registry.go @@ -267,6 +267,7 @@ func (r *ModelRegistry) RegisterProviderWithNameAndType(provider core.Provider, defer r.mu.Unlock() providerName = strings.TrimSpace(providerName) + providerType = strings.TrimSpace(providerType) if providerName == "" { if providerType != "" { providerName = providerType diff --git a/internal/providers/registry_test.go b/internal/providers/registry_test.go index d0f54c89..041b9297 100644 --- a/internal/providers/registry_test.go +++ b/internal/providers/registry_test.go @@ -2019,3 +2019,25 @@ func TestListPublicModels_HidesAudioOnlyModelsFromProvidersWithoutAudioSupport(t } } } + +// TestProviderByTypeAndNameTrimConfiguredValues verifies that configured provider +// names and types are normalized at registration so lookups succeed even when the +// configured value arrives padded with whitespace (e.g. from YAML or env vars). 
+func TestProviderByTypeAndNameTrimConfiguredValues(t *testing.T) { + registry := NewModelRegistry() + mock := &registryMockProvider{name: "padded"} + registry.RegisterProviderWithNameAndType(mock, " padded-name ", " openai ") + + if got := registry.ProviderByType("openai"); got != mock { + t.Fatalf("ProviderByType(openai) = %v, want the registered provider", got) + } + if got := registry.ProviderByName("padded-name"); got != mock { + t.Fatalf("ProviderByName(padded-name) = %v, want the registered provider", got) + } + if got := registry.GetProviderTypeForName("padded-name"); got != "openai" { + t.Fatalf("GetProviderTypeForName(padded-name) = %q, want %q", got, "openai") + } + if got := registry.GetProviderNameForType("openai"); got != "padded-name" { + t.Fatalf("GetProviderNameForType(openai) = %q, want %q", got, "padded-name") + } +} diff --git a/internal/providers/responses_adapter.go b/internal/providers/responses_adapter.go index 15649e2d..fad813fe 100644 --- a/internal/providers/responses_adapter.go +++ b/internal/providers/responses_adapter.go @@ -315,12 +315,7 @@ func normalizeResponsesToolChoiceForChat(choice any) any { } func cloneStringAnyMap(src map[string]any) map[string]any { - if src == nil { - return nil - } - dst := make(map[string]any, len(src)) - maps.Copy(dst, src) - return dst + return maps.Clone(src) } // ResponsesViaChat implements the Responses API by converting to/from Chat format. diff --git a/internal/providers/router.go b/internal/providers/router.go index 84eaacb9..3aa45870 100644 --- a/internal/providers/router.go +++ b/internal/providers/router.go @@ -914,59 +914,47 @@ func (r *Router) providerTypes() []string { return result } -// NativeFileProviderTypes returns the registered provider types that support -// native file operations. This inventory is independent of the public model -// catalog whenever the underlying lookup can expose provider types directly.
-func (r *Router) NativeFileProviderTypes() []string { +// providerTypesSupporting returns the registered provider types whose backing +// provider satisfies the given capability check. The inventory is independent of +// the public model catalog whenever the underlying lookup can expose provider +// types directly. +func (r *Router) providerTypesSupporting(supports func(core.Provider) bool) []string { providerTypes := r.providerTypes() result := make([]string, 0, len(providerTypes)) for _, providerType := range providerTypes { provider := r.providerByTypeRegistry(providerType) - if provider == nil { - continue - } - if _, ok := provider.(core.NativeFileProvider); !ok { - continue + if provider != nil && supports(provider) { + result = append(result, providerType) } - result = append(result, providerType) } return result } +// NativeFileProviderTypes returns the registered provider types that support +// native file operations. +func (r *Router) NativeFileProviderTypes() []string { + return r.providerTypesSupporting(func(p core.Provider) bool { + _, ok := p.(core.NativeFileProvider) + return ok + }) +} + // NativeBatchProviderTypes returns the registered provider types that support // native batch operations. func (r *Router) NativeBatchProviderTypes() []string { - providerTypes := r.providerTypes() - result := make([]string, 0, len(providerTypes)) - for _, providerType := range providerTypes { - provider := r.providerByTypeRegistry(providerType) - if provider == nil { - continue - } - if _, ok := provider.(core.NativeBatchProvider); !ok { - continue - } - result = append(result, providerType) - } - return result + return r.providerTypesSupporting(func(p core.Provider) bool { + _, ok := p.(core.NativeBatchProvider) + return ok + }) } // NativeResponseProviderTypes returns the registered provider types that // support native Responses lifecycle operations. 
func (r *Router) NativeResponseProviderTypes() []string { - providerTypes := r.providerTypes() - result := make([]string, 0, len(providerTypes)) - for _, providerType := range providerTypes { - provider := r.providerByTypeRegistry(providerType) - if provider == nil { - continue - } - if _, ok := provider.(core.NativeResponseLifecycleProvider); !ok { - continue - } - result = append(result, providerType) - } - return result + return r.providerTypesSupporting(func(p core.Provider) bool { + _, ok := p.(core.NativeResponseLifecycleProvider) + return ok + }) } // Passthrough routes an opaque provider-native request by provider type. diff --git a/internal/responsecache/stream_cache.go b/internal/responsecache/stream_cache.go index 7c12320a..74cac062 100644 --- a/internal/responsecache/stream_cache.go +++ b/internal/responsecache/stream_cache.go @@ -2,6 +2,7 @@ package responsecache import ( "bytes" + "maps" "net/http" "strings" @@ -260,14 +261,7 @@ func toJSONMap(value any) (map[string]any, error) { } func cloneJSONMap(src map[string]any) map[string]any { - if src == nil { - return nil - } - dst := make(map[string]any, len(src)) - for key, value := range src { - dst[key] = value - } - return dst + return maps.Clone(src) } func jsonNumberToInt(value any) (int, bool) { diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go index c61b6bc6..c543b140 100644 --- a/internal/server/handlers_test.go +++ b/internal/server/handlers_test.go @@ -1884,7 +1884,7 @@ func TestChatCompletion_UsesExplicitTranslatedRequestPatcher(t *testing.T) { }, } - patcher := guardrails.NewRequestPatcher(pipeline) + patcher := guardrails.NewWorkflowRequestPatcher(staticPipelineResolver{pipeline: pipeline}) e := echo.New() handler := newHandler(inner, nil, nil, nil, nil, nil, nil, patcher) @@ -1956,7 +1956,7 @@ func TestBatches_UsesExplicitGuardrailBatchPreparer(t *testing.T) { RequestCounts: core.BatchRequestCounts{Total: 1}, }, } - batchPreparer := 
guardrails.NewBatchPreparer(mock, pipeline) + batchPreparer := guardrails.NewWorkflowBatchPreparer(mock, staticPipelineResolver{pipeline: pipeline}) e := echo.New() handler := NewHandler(mock, nil, nil, nil) @@ -7331,3 +7331,12 @@ func TestIsNativeBatchResultsPending(t *testing.T) { t.Fatal("expected terminal anthropic batch not to be treated as pending") } } + +// staticPipelineResolver returns a fixed guardrails pipeline regardless of +// context, letting tests drive the production WorkflowRequestPatcher / +// WorkflowBatchPreparer with an explicit pipeline. +type staticPipelineResolver struct{ pipeline *guardrails.Pipeline } + +func (s staticPipelineResolver) PipelineForContext(context.Context) *guardrails.Pipeline { + return s.pipeline +} diff --git a/internal/server/native_response_service.go b/internal/server/native_response_service.go index 00bd0960..13f7dd1e 100644 --- a/internal/server/native_response_service.go +++ b/internal/server/native_response_service.go @@ -58,7 +58,7 @@ func (s *nativeResponseService) GetResponse(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } auditResponseEntry(c, providerType) return c.JSON(http.StatusOK, resp) @@ -90,7 +90,7 @@ func (s *nativeResponseService) ListResponseInputItems(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } if resp.Object == "" { resp.Object = "list" @@ -120,7 +120,7 @@ func (s *nativeResponseService) CancelResponse(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, 
"provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } normalizeCanceledResponse(resp, id, providerType) stored.Response = resp @@ -139,7 +139,7 @@ func (s *nativeResponseService) CancelResponse(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } normalizeCanceledResponse(resp, id, providerType) auditResponseEntry(c, providerType) @@ -185,7 +185,7 @@ func (s *nativeResponseService) DeleteResponse(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } if resp.ID == "" { resp.ID = id @@ -219,7 +219,7 @@ func (s *nativeResponseService) CountResponseInputTokens(c *echo.Context) error return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } if resp.Object == "" { resp.Object = "response.input_tokens" @@ -249,7 +249,7 @@ func (s *nativeResponseService) CompactResponse(c *echo.Context) error { return handleError(c, err) } if resp == nil { - return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "provider returned empty response", nil)) + return handleError(c, core.NewEmptyProviderResponseError(providerType)) } if resp.Object == "" { resp.Object = "response.compaction" diff --git a/internal/usage/extractor.go b/internal/usage/extractor.go index 2e662f9a..a1b48f39 100644 --- a/internal/usage/extractor.go +++ b/internal/usage/extractor.go @@ -132,12 +132,7 @@ func 
applyUsageCosts(entry *UsageEntry, provider, endpoint string, pricing ...*c // cloneRawData creates a shallow copy of the raw data map to prevent races // when the original map might be mutated after the entry is enqueued. func cloneRawData(src map[string]any) map[string]any { - if src == nil { - return nil - } - dst := make(map[string]any, len(src)) - maps.Copy(dst, src) - return dst + return maps.Clone(src) } // ExtractFromResponsesResponse extracts usage data from a ResponsesResponse. From 4a6b8040ba9d3a360af4e9d36d432984a7cd453c Mon Sep 17 00:00:00 2001 From: "Jakub A. W" Date: Wed, 1 Jul 2026 01:00:52 +0200 Subject: [PATCH 2/2] chore(guardrails): address PR review nitpicks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - clone.go: document why the cloneMessageContent fallback returns the original (shared) reference for unrecognized content shapes — chat content is normalized to nil/string/[]ContentPart before reaching it and guardrails replace whole values rather than mutating in place, so the branch is defensive and the shared reference is safe. - provider_harness_test.go: rename the test-only Options type to GuardedProviderOptions so it can't collide with a future package-level type in package guardrails. Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/guardrails/clone.go | 5 ++++ internal/guardrails/provider_harness_test.go | 10 ++++---- internal/guardrails/provider_test.go | 24 ++++++++++---------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/internal/guardrails/clone.go b/internal/guardrails/clone.go index 7f1afe5e..c3907267 100644 --- a/internal/guardrails/clone.go +++ b/internal/guardrails/clone.go @@ -46,6 +46,11 @@ func cloneMessageContent(content any) any { default: parts, ok := core.NormalizeContentParts(content) if !ok { + // Unrecognized content shapes cannot be deep-copied generically, so + // they are returned as-is. 
Guardrails replace whole content values + // rather than mutating them in place, so sharing the reference is + // safe; chat content is normalized to nil/string/[]ContentPart + // before reaching here, making this branch defensive. return value } return cloneContentParts(parts) diff --git a/internal/guardrails/provider_harness_test.go b/internal/guardrails/provider_harness_test.go index 0560c05b..8e2dc39e 100644 --- a/internal/guardrails/provider_harness_test.go +++ b/internal/guardrails/provider_harness_test.go @@ -19,19 +19,19 @@ import ( type GuardedProvider struct { inner core.RoutableProvider pipeline *Pipeline - options Options + options GuardedProviderOptions } -// Options mirrors the batch-processing toggle exercised by the tests. -type Options struct { +// GuardedProviderOptions mirrors the batch-processing toggle exercised by the tests. +type GuardedProviderOptions struct { EnableForBatchProcessing bool } func NewGuardedProvider(inner core.RoutableProvider, pipeline *Pipeline) *GuardedProvider { - return NewGuardedProviderWithOptions(inner, pipeline, Options{}) + return NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{}) } -func NewGuardedProviderWithOptions(inner core.RoutableProvider, pipeline *Pipeline, options Options) *GuardedProvider { +func NewGuardedProviderWithOptions(inner core.RoutableProvider, pipeline *Pipeline, options GuardedProviderOptions) *GuardedProvider { return &GuardedProvider{inner: inner, pipeline: pipeline, options: options} } diff --git a/internal/guardrails/provider_test.go b/internal/guardrails/provider_test.go index 39fcc3b6..3e32c16c 100644 --- a/internal/guardrails/provider_test.go +++ b/internal/guardrails/provider_test.go @@ -1230,7 +1230,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled(t *testing.T) { pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, 
Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1271,7 +1271,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_InputFile(t *testing pipeline := NewPipeline() gr, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(gr, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) _, err := guarded.CreateBatch(context.Background(), "mock", &core.BatchRequest{ InputFileID: "file_source", @@ -1307,7 +1307,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_InputFileCleansUpOnF pipeline := NewPipeline() gr, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(gr, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) _, err := guarded.CreateBatch(context.Background(), "mock", &core.BatchRequest{ InputFileID: "file_source", @@ -1326,7 +1326,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_SkipsEmbeddingsItems pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/embeddings", @@ -1356,7 +1356,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_TextOnlyContentArray pipeline := NewPipeline() g, _ := 
NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1409,7 +1409,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_RewritesStructuredTe return out, nil }, }, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1481,7 +1481,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_PreservesOpaqueChatF pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptInject, "guardrail system") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1559,7 +1559,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_RewritesChatContentW return out, nil }, }, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1613,7 +1613,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_PreservesOpaqueRespo pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptOverride, "guardrail instructions") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, 
Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/responses", @@ -1686,7 +1686,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_RewritesResponsesInp } pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/responses", Requests: []core.BatchRequestItem{ @@ -1735,7 +1735,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_NormalizesFullURLRes pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptOverride, "guardrail instructions") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions", @@ -1774,7 +1774,7 @@ func TestGuardedProvider_CreateBatch_BatchGuardrailsEnabled_PreservesSystemMessa pipeline := NewPipeline() g, _ := NewSystemPromptGuardrail("test", SystemPromptDecorator, "prefix") pipeline.Add(g, 0) - guarded := NewGuardedProviderWithOptions(inner, pipeline, Options{EnableForBatchProcessing: true}) + guarded := NewGuardedProviderWithOptions(inner, pipeline, GuardedProviderOptions{EnableForBatchProcessing: true}) req := &core.BatchRequest{ Endpoint: "/v1/chat/completions",