diff --git a/.gitignore b/.gitignore index c36c3871..1f15d22c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ dist coverage playwright-report -__screenshots__ \ No newline at end of file +__screenshots__ + +# Copied from ../tracker/dist (see packages/server/package.json copytracker script) +packages/server/app/tracker/tracker.js \ No newline at end of file diff --git a/README.md b/README.md index dedd7bdf..7f0d3f51 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,26 @@ Counterscale.trackPageview(); The deployment URL can always be changed to go behind a custom domain you own. [More here](https://developers.cloudflare.com/workers/configuration/routing/custom-domains/). +### Allowed Origins: Restricting Which Sites Can Report + +By default, a Counterscale deployment records hits from any site that loads your tracker (or posts to `/collect`). To limit recording to specific domains, set the `TRACKER_ALLOWED_ORIGINS` environment variable on the Cloudflare Worker to a comma-separated list of allowed domains: + +``` +TRACKER_ALLOWED_ORIGINS=example.com,myblog.io,acme.dev +``` + +- Each entry matches the specified domain **and any of its subdomains**, you only need the parent domain. `example.com` already covers `blog.example.com`, `app.example.com`, and any other subdomain, so there's no need to list them separately. Lookalikes such as `notexample.com` are not matched. Entries may be bare hostnames or include a scheme (`https://example.com`), both are treated the same. +- When set, hits whose origin isn't on the list are silently ignored: the tracker still returns a normal response, but no data is recorded. +- Leave it empty (the default) or set it to `*` to disable the allowlist and record from any origin. + +You can set this variable in one of two ways: + +- **Cloudflare dashboard:** Workers & Pages → Counterscale worker → Settings → Variables and Secrets → add or set `TRACKER_ALLOWED_ORIGINS`, then redeploy. +- **From source:** edit the `vars` block in `packages/server/wrangler.json` and redeploy. + +> [!NOTE] +> This is a best-effort filter. The signals it checks (`Origin`, `Referer`, and the reported hostname) are supplied by the client and can be spoofed by non-browser tools, so it deters casual or accidental cross-site reporting rather than a determined attacker. Recorded data is also partitioned by site ID, which further limits the impact of unwanted hits. + ## CLI Commands Counterscale provides a command-line interface (CLI) to help you install, configure, and manage your deployment. diff --git a/packages/cli/src/commands/install.ts b/packages/cli/src/commands/install.ts index a98bc3eb..d4754a9a 100644 --- a/packages/cli/src/commands/install.ts +++ b/packages/cli/src/commands/install.ts @@ -218,6 +218,11 @@ export async function install( // If --advanced is true, prompt the user for worker name and analytics dataset name. // Otherwise, stick to the default values read from the server package. if (opts.advanced) { + log.warn( + "If you previously installed with a custom worker name or analytics dataset, " + + "re-enter the same values below — accepting the defaults will repoint your " + + "deployment at a fresh dataset and your existing analytics will appear empty.", + ); ({ workerName, analyticsDataset } = await promptProjectConfig( workerName, analyticsDataset, diff --git a/packages/cli/src/lib/__tests__/config.test.ts b/packages/cli/src/lib/__tests__/config.test.ts index 647cbe0f..af152617 100644 --- a/packages/cli/src/lib/__tests__/config.test.ts +++ b/packages/cli/src/lib/__tests__/config.test.ts @@ -254,7 +254,7 @@ describe("CLI Functions", () => { "/target/wrangler.json", initialConfig, "new-worker", - "new-dataset", + "newDataset", ); // Verify writeFileSync was called with the correct arguments @@ -271,8 +271,11 @@ describe("CLI Functions", () => { // Verify worker name and dataset were updated expect(writtenConfig.name).toBe("new-worker"); expect(writtenConfig.analytics_engine_datasets[0].dataset).toBe( - "new-dataset", + "newDataset", ); + // Verify CF_DATASET_NAME var was set so the dashboard's SQL + // read path uses the same dataset as the AE binding. + expect(writtenConfig.vars.CF_DATASET_NAME).toBe("newDataset"); // Verify paths were made absolute expect(writtenConfig.build.cwd).toMatch(/^\//); // Should start with / @@ -294,7 +297,7 @@ describe("CLI Functions", () => { "/target/wrangler.json", initialConfig, "new-worker", - "new-dataset", + "newDataset", accountId, ); @@ -319,7 +322,7 @@ describe("CLI Functions", () => { "/target/wrangler.json", initialConfig, "new-worker", - "new-dataset", + "newDataset", ); const writtenConfig = JSON.parse( diff --git a/packages/cli/src/lib/config.ts b/packages/cli/src/lib/config.ts index 61e04866..728cfb1e 100644 --- a/packages/cli/src/lib/config.ts +++ b/packages/cli/src/lib/config.ts @@ -123,6 +123,11 @@ export function readInitialServerConfig() { * converted to be absolute. This makes it so that the `wrangler deploy` command can be * run from any directory. */ +// Mirrors the validation in AnalyticsEngineAPI's constructor — the dataset +// name is interpolated into raw SQL on the read path, and only matching names +// will pass the server-side guard. +export const DATASET_NAME_PATTERN = /^[A-Za-z0-9_]+$/; + export async function stageDeployConfig( targetPath: string, initialDeployConfig: ReturnType, @@ -130,6 +135,12 @@ export async function stageDeployConfig( analyticsDataset: string, accountId?: string, ): Promise { + if (!DATASET_NAME_PATTERN.test(analyticsDataset)) { + throw new Error( + `Invalid Analytics Engine dataset name: ${analyticsDataset}. Only letters, digits, and underscores are allowed.`, + ); + } + const serverPkgDir = getServerPkgDir(); const outDeployConfig = makePathsAbsolute( @@ -138,6 +149,10 @@ export async function stageDeployConfig( ); outDeployConfig.name = workerName; outDeployConfig.analytics_engine_datasets[0].dataset = analyticsDataset; + outDeployConfig.vars = { + ...(outDeployConfig.vars ?? {}), + CF_DATASET_NAME: analyticsDataset, + }; if (accountId) { outDeployConfig.account_id = accountId; diff --git a/packages/server/.dev.vars.example b/packages/server/.dev.vars.example index 8946785f..91fd650d 100644 --- a/packages/server/.dev.vars.example +++ b/packages/server/.dev.vars.example @@ -6,3 +6,4 @@ CF_PASSWORD_HASH='' CF_JWT_SECRET='' CF_AUTH_ENABLED='' CF_STORAGE_ENABLED='' +CF_DATASET_NAME='' diff --git a/packages/server/app/analytics/__tests__/collect.test.ts b/packages/server/app/analytics/__tests__/collect.test.ts index c70f0f6d..8b262f0f 100644 --- a/packages/server/app/analytics/__tests__/collect.test.ts +++ b/packages/server/app/analytics/__tests__/collect.test.ts @@ -427,3 +427,151 @@ describe("collectRequestHandler", () => { expect(blobs[14]).toBe(""); // utm_content (empty) }); }); + +describe("collectRequestHandler allowlist enforcement", () => { + const UA = + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36"; + + function buildRequest(opts: { + h: string; + r?: string; + origin?: string; + referer?: string; + }) { + const headers: Record = { "user-agent": UA }; + if (opts.origin) headers["origin"] = opts.origin; + if (opts.referer) headers["referer"] = opts.referer; + return { + method: "GET", + url: + "https://example.com/collect?" + + new URLSearchParams({ + sid: "example", + h: opts.h, + p: "/", + r: opts.r ?? "", + ht: "1", + }).toString(), + headers: { + get: (header: string) => headers[header], + }, + }; + } + + function makeEnv(allowed?: string) { + return { + WEB_COUNTER_AE: { writeDataPoint: vi.fn() }, + TRACKER_ALLOWED_ORIGINS: allowed, + } as unknown as Env; + } + + test("writes when h matches a listed origin via subdomain", () => { + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ h: "https://docs.pmux.io" }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("silently drops (200 gif, no write) when h is not allowed", () => { + const env = makeEnv("pmux.io"); + const response = collectRequestHandler( + buildRequest({ h: "https://evil.com" }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).not.toHaveBeenCalled(); + expect(response.status).toBe(200); + expect(response.headers.get("Content-Type")).toBe("image/gif"); + // Drop path must NOT set Last-Modified: the tracker uses it for + // cookieless visit counting, and updating it on a dropped hit would + // corrupt session state. + expect(response.headers.get("Last-Modified")).toBeNull(); + }); + + test("writes when h is a bare hostname (no scheme) that is allowed", () => { + const env = makeEnv("pmux.io"); + collectRequestHandler(buildRequest({ h: "docs.pmux.io" }) as any, env); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("writes when only the Referer header is present and allowed", () => { + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ + h: "https://docs.pmux.io", + referer: "https://docs.pmux.io/guide", + }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("does not drop legit traffic from an opaque 'null' Origin", () => { + // Sandboxed iframes send Origin: null; this must not block an + // otherwise-allowed hit. + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ h: "https://docs.pmux.io", origin: "null" }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("'*' in the allowlist disables enforcement (allow all)", () => { + const env = makeEnv("*"); + collectRequestHandler( + buildRequest({ h: "https://anything.com" }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("drops when h is allowed but the Origin header is not", () => { + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ + h: "https://docs.pmux.io", + origin: "https://evil.com", + }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).not.toHaveBeenCalled(); + }); + + test("writes when h, Origin, and Referer headers all match", () => { + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ + h: "https://docs.pmux.io", + origin: "https://docs.pmux.io", + referer: "https://docs.pmux.io/guide", + }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("ignores the analytics referrer (r) param for enforcement", () => { + // r is the visitor's traffic source, not the embedding page; it must + // not be validated against the allowlist. + const env = makeEnv("pmux.io"); + collectRequestHandler( + buildRequest({ + h: "https://docs.pmux.io", + r: "https://google.com", + }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); + + test("writes for any host when the allowlist is unset (opt-in)", () => { + const env = makeEnv(undefined); + collectRequestHandler( + buildRequest({ h: "https://anything.com" }) as any, + env, + ); + expect(env.WEB_COUNTER_AE.writeDataPoint).toHaveBeenCalled(); + }); +}); diff --git a/packages/server/app/analytics/__tests__/query.test.ts b/packages/server/app/analytics/__tests__/query.test.ts index c21d8d6b..c56bda8c 100644 --- a/packages/server/app/analytics/__tests__/query.test.ts +++ b/packages/server/app/analytics/__tests__/query.test.ts @@ -682,3 +682,50 @@ describe("intervalToSql", () => { }); }); }); + +describe("AnalyticsEngineAPI dataset name", () => { + let fetch: Mock; + beforeEach(() => { + fetch = global.fetch = vi.fn(); + fetch.mockResolvedValue(createFetchResponse({ data: [] })); + vi.useFakeTimers(); + }); + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + test("defaults to metricsDataset when no dataset arg is provided", () => { + const api = new AnalyticsEngineAPI("acct", "tok"); + expect(api.dataset).toBe("metricsDataset"); + }); + + test("defaults to metricsDataset when empty string is provided", () => { + const api = new AnalyticsEngineAPI("acct", "tok", ""); + expect(api.dataset).toBe("metricsDataset"); + }); + + test("uses the provided custom dataset name", () => { + const api = new AnalyticsEngineAPI("acct", "tok", "counterscaleMetrics"); + expect(api.dataset).toBe("counterscaleMetrics"); + }); + + test("rejects invalid dataset names", () => { + expect( + () => new AnalyticsEngineAPI("acct", "tok", "bad name; DROP"), + ).toThrow(/Invalid Analytics Engine dataset name/); + }); + + test("getCounts emits SQL referencing the custom dataset", async () => { + const api = new AnalyticsEngineAPI( + "acct", + "tok", + "counterscaleMetrics", + ); + await api.getCounts("site1", "7d"); + expect(fetch).toHaveBeenCalled(); + const body = fetch.mock.calls[0][1].body as string; + expect(body).toContain("FROM counterscaleMetrics"); + expect(body).not.toContain("FROM metricsDataset"); + }); +}); diff --git a/packages/server/app/analytics/collect.ts b/packages/server/app/analytics/collect.ts index 193b9cec..a8576fb4 100644 --- a/packages/server/app/analytics/collect.ts +++ b/packages/server/app/analytics/collect.ts @@ -1,6 +1,11 @@ import type { AnalyticsEngineDataset } from "@cloudflare/workers-types"; import { IDevice, UAParser } from "ua-parser-js"; import { maskBrowserVersion } from "~/lib/utils"; +import { + parseAllowedOrigins, + extractHost, + isHostAllowed, +} from "~/lib/allowedOrigins"; // Cookieless visitor/session tracking // Uses the approach described here: https://notes.normally.com/cookieless-unique-visitor-counts/ @@ -142,6 +147,18 @@ export function collectRequestHandler( return new Response("Missing siteId", { status: 400 }); } + // Optional allowlist enforcement. When TRACKER_ALLOWED_ORIGINS is set, drop + // (silently) any hit whose reported host or Origin/Referer header isn't an + // allowed origin. CORS can't gate the worker, so we enforce it here. + // Best-effort: these signals are client-controlled and spoofable. + const allowedOrigins = parseAllowedOrigins(env.TRACKER_ALLOWED_ORIGINS); + if ( + allowedOrigins.length > 0 && + !requestIsAllowed(request, params, allowedOrigins) + ) { + return trackingGifResponse(); + } + const userAgent = request.headers.get("user-agent") || undefined; const parsedUserAgent = new UAParser(userAgent); @@ -210,7 +227,37 @@ export function collectRequestHandler( writeDataPoint(env.WEB_COUNTER_AE, data); - // encode 1x1 transparent gif + return trackingGifResponse(nextLastModifiedDate); +} + +/** + * Returns true if the request's reported host (h) and any present + * Origin/Referer header all resolve to an allowed origin. Requires at least + * one usable signal. The analytics referrer (params.r) is intentionally NOT + * checked — it's the visitor's traffic source, not the embedding page. + */ +function requestIsAllowed( + request: Request, + params: { [key: string]: string }, + allowedOrigins: string[], +): boolean { + const candidates = [ + params.h, + request.headers.get("origin"), + request.headers.get("referer"), + ]; + const hosts = candidates + .map((candidate) => extractHost(candidate)) + .filter((host): host is string => host !== null); + + return ( + hosts.length > 0 && + hosts.every((host) => isHostAllowed(host, allowedOrigins)) + ); +} + +/** Encodes the 1x1 transparent tracking gif response. */ +function trackingGifResponse(nextLastModifiedDate?: Date): Response { const gif = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"; const gifData = atob(gif); const gifLength = gifData.length; diff --git a/packages/server/app/analytics/query.ts b/packages/server/app/analytics/query.ts index 5810f7eb..d66482af 100644 --- a/packages/server/app/analytics/query.ts +++ b/packages/server/app/analytics/query.ts @@ -166,9 +166,12 @@ function filtersToSql(filters: SearchFilters) { * See: https://developers.cloudflare.com/analytics/analytics-engine/sql-reference/ */ +export const DEFAULT_DATASET_NAME = "metricsDataset"; + export class AnalyticsEngineAPI { cfApiToken: string; cfAccountId: string; + dataset: string; defaultHeaders: { "content-type": string; "X-Source": string; @@ -176,10 +179,22 @@ export class AnalyticsEngineAPI { }; defaultUrl: string; - constructor(cfAccountId: string, cfApiToken: string) { + constructor( + cfAccountId: string, + cfApiToken: string, + dataset?: string, + ) { this.cfAccountId = cfAccountId; this.cfApiToken = cfApiToken; + const resolved = dataset || DEFAULT_DATASET_NAME; + if (!/^[A-Za-z0-9_]+$/.test(resolved)) { + throw new Error( + `Invalid Analytics Engine dataset name: ${resolved}`, + ); + } + this.dataset = resolved; + this.defaultUrl = `https://api.cloudflare.com/client/v4/accounts/${this.cfAccountId}/analytics_engine/sql`; this.defaultHeaders = { "content-type": "application/json;charset=UTF-8", @@ -246,7 +261,7 @@ export class AnalyticsEngineAPI { /* output as UTC */ toDateTime(_bucket, 'Etc/UTC') as bucket - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= toDateTime('${localStartTime.format("YYYY-MM-DD HH:mm:ss")}') AND timestamp < toDateTime('${localEndTime.format("YYYY-MM-DD HH:mm:ss")}') AND ${ColumnMappings.siteId} = '${siteId}' @@ -352,7 +367,7 @@ export class AnalyticsEngineAPI { SELECT SUM(_sample_interval) as count, ${ColumnMappings.newVisitor} as isVisitor, ${ColumnMappings.bounce} as isBounce - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= ${startIntervalSql} AND timestamp < ${endIntervalSql} ${filterStr} AND ${siteIdColumn} = '${siteId}' @@ -416,7 +431,7 @@ export class AnalyticsEngineAPI { const _column = ColumnMappings[column]; const query = ` SELECT ${_column}, SUM(_sample_interval) as count - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= ${startIntervalSql} AND timestamp < ${endIntervalSql} AND ${ColumnMappings.newVisitor} = 1 AND ${ColumnMappings.siteId} = '${siteId}' @@ -492,7 +507,7 @@ export class AnalyticsEngineAPI { ${ColumnMappings.newVisitor} as isVisitor, ${ColumnMappings.bounce} as isBounce, ${columnsStrWithAliases} - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= toDateTime('${startDateTimeSql}') AND timestamp < toDateTime('${endDateTimeSql}') GROUP BY timestamp, ${ColumnMappings.siteId}, @@ -584,7 +599,7 @@ export class AnalyticsEngineAPI { SELECT ${_column}, ${ColumnMappings.newVisitor} as isVisitor, SUM(_sample_interval) as count - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= ${startIntervalSql} AND timestamp < ${endIntervalSql} AND ${ColumnMappings.newVisitor} = 0 AND ${ColumnMappings.siteId} = '${siteId}' @@ -889,7 +904,7 @@ export class AnalyticsEngineAPI { const query = ` SELECT SUM(_sample_interval) as count, ${ColumnMappings.siteId} as siteId - FROM metricsDataset + FROM ${this.dataset} WHERE timestamp >= ${startIntervalSql} AND timestamp < ${endIntervalSql} GROUP BY siteId ORDER BY count DESC @@ -936,7 +951,7 @@ export class AnalyticsEngineAPI { SELECT MIN(timestamp) as earliestEvent, ${ColumnMappings.bounce} as isBounce - FROM metricsDataset + FROM ${this.dataset} WHERE ${ColumnMappings.siteId} = '${siteId}' GROUP by isBounce `; diff --git a/packages/server/app/lib/__tests__/allowedOrigins.test.ts b/packages/server/app/lib/__tests__/allowedOrigins.test.ts new file mode 100644 index 00000000..4351204c --- /dev/null +++ b/packages/server/app/lib/__tests__/allowedOrigins.test.ts @@ -0,0 +1,135 @@ +import { describe, expect, test } from "vitest"; +import { + parseAllowedOrigins, + extractHost, + isHostAllowed, +} from "../allowedOrigins"; + +describe("parseAllowedOrigins", () => { + test("returns empty list for undefined", () => { + expect(parseAllowedOrigins(undefined)).toEqual([]); + }); + + test("returns empty list for empty string", () => { + expect(parseAllowedOrigins("")).toEqual([]); + }); + + test("splits comma-separated entries and trims whitespace", () => { + expect(parseAllowedOrigins("foo.com, bar.com ,baz.com")).toEqual([ + "foo.com", + "bar.com", + "baz.com", + ]); + }); + + test("strips scheme, port, path, and wildcard prefix; lowercases", () => { + expect( + parseAllowedOrigins( + "https://Foo.com, http://bar.com:8080/path, *.baz.com", + ), + ).toEqual(["foo.com", "bar.com", "baz.com"]); + }); + + test("drops empty entries from trailing/duplicate commas", () => { + expect(parseAllowedOrigins("foo.com,,bar.com,")).toEqual([ + "foo.com", + "bar.com", + ]); + }); + + test("treats a lone '*' as allow-all (empty list, enforcement off)", () => { + expect(parseAllowedOrigins("*")).toEqual([]); + expect(parseAllowedOrigins("https://*")).toEqual([]); + }); + + test("drops a '*' entry but keeps real hosts alongside it", () => { + expect(parseAllowedOrigins("*, foo.com")).toEqual(["foo.com"]); + }); +}); + +describe("extractHost", () => { + test("extracts hostname from a full URL", () => { + expect(extractHost("https://docs.pmux.io")).toBe("docs.pmux.io"); + }); + + test("extracts hostname from a URL with path and query", () => { + expect(extractHost("https://docs.pmux.io/guide?x=1")).toBe( + "docs.pmux.io", + ); + }); + + test("handles a bare hostname without scheme", () => { + expect(extractHost("example.com")).toBe("example.com"); + }); + + test("lowercases the host", () => { + expect(extractHost("https://Docs.PMUX.io")).toBe("docs.pmux.io"); + }); + + test("returns null for empty or nullish input", () => { + expect(extractHost("")).toBeNull(); + expect(extractHost(null)).toBeNull(); + expect(extractHost(undefined)).toBeNull(); + }); + + test("treats the opaque 'null' origin as absent", () => { + // Sandboxed iframes / file:// pages send the literal Origin: null. + expect(extractHost("null")).toBeNull(); + }); + + test("rejects candidates carrying userinfo (anti-spoofing)", () => { + // new URL("https://evil.com@pmux.io/").hostname is "pmux.io"; without + // this guard an attacker-controlled `h` could impersonate an allowed host. + expect(extractHost("https://evil.com@pmux.io/")).toBeNull(); + expect(extractHost("https://pmux.io@evil.com/")).toBeNull(); + }); + + test("handles a bare host with a port without corrupting it", () => { + expect(extractHost("localhost:3000")).toBe("localhost"); + }); + + test("does not corrupt a bracketed IPv6 host", () => { + expect(extractHost("[::1]")).toBe("[::1]"); + }); + + test("returns null for schemes with no host (data:, javascript:)", () => { + expect(extractHost("data:text/html,hi")).toBeNull(); + expect(extractHost("javascript:alert(1)")).toBeNull(); + }); +}); + +describe("isHostAllowed", () => { + const allowed = parseAllowedOrigins( + "shiftinbits.com,constellationdev.io,pmux.io", + ); + + test("matches an exact host", () => { + expect(isHostAllowed("pmux.io", allowed)).toBe(true); + }); + + test("matches a subdomain of a listed host", () => { + expect(isHostAllowed("docs.pmux.io", allowed)).toBe(true); + expect(isHostAllowed("app.constellationdev.io", allowed)).toBe(true); + }); + + test("matches a deep subdomain", () => { + expect(isHostAllowed("a.b.pmux.io", allowed)).toBe(true); + }); + + test("rejects a sibling domain that merely ends with the name", () => { + expect(isHostAllowed("evil-pmux.io", allowed)).toBe(false); + expect(isHostAllowed("notpmux.io", allowed)).toBe(false); + }); + + test("rejects an unlisted host", () => { + expect(isHostAllowed("example.com", allowed)).toBe(false); + }); + + test("rejects null host", () => { + expect(isHostAllowed(null, allowed)).toBe(false); + }); + + test("rejects any host against an empty allowlist", () => { + expect(isHostAllowed("pmux.io", [])).toBe(false); + }); +}); diff --git a/packages/server/app/lib/allowedOrigins.ts b/packages/server/app/lib/allowedOrigins.ts new file mode 100644 index 00000000..ab9daf02 --- /dev/null +++ b/packages/server/app/lib/allowedOrigins.ts @@ -0,0 +1,86 @@ +// Helpers for the TRACKER_ALLOWED_ORIGINS allowlist. +// +// Matching is host-based (not scheme-based): an entry like "pmux.io" matches +// the host itself and any subdomain (e.g. "docs.pmux.io"). Enforcement of this +// list lives in the /collect handler, since CORS headers cannot actually block +// who loads or POSTs to the worker. + +/** + * Normalize a single allowlist entry to a bare, lowercase hostname. + * Strips scheme, leading "*." wildcard, port, and path. + */ +function normalizeEntry(entry: string): string { + return entry + .trim() + .toLowerCase() + .replace(/^[a-z][a-z0-9+.-]*:\/\//, "") // scheme:// + .replace(/^\*\./, "") // wildcard prefix + .split("/")[0] // path + .split(":")[0]; // port +} + +/** + * Parse the comma-separated TRACKER_ALLOWED_ORIGINS env var into a list of + * normalized hostnames. Returns [] when unset/empty (enforcement is opt-in). + * A lone "*" entry is dropped so that TRACKER_ALLOWED_ORIGINS="*" reads as + * "allow all" (empty list = enforcement off) rather than silently blocking + * every host (no real hostname equals "*"). + */ +export function parseAllowedOrigins(value: string | undefined): string[] { + return (value ?? "") + .split(",") + .map(normalizeEntry) + .filter((host) => host.length > 0 && host !== "*"); +} + +function tryParseUrl(value: string): URL | null { + try { + return new URL(value); + } catch { + return null; + } +} + +/** + * Extract a lowercase hostname from a candidate that may be a full URL + * (e.g. "https://docs.pmux.io/x") or a bare host (e.g. "example.com:3000"). + * Returns null for empty/nullish input, the opaque "null" origin, schemes + * with no host (data:, javascript:), or any candidate carrying userinfo + * (e.g. "https://evil.com@pmux.io/", whose host parses as the allowed one). + */ +export function extractHost( + candidate: string | null | undefined, +): string | null { + if (!candidate) return null; + const value = candidate.trim(); + // The literal "null" is what browsers send as the Origin for sandboxed + // iframes / file:// pages — treat it as no signal, not a host named "null". + if (value.length === 0 || value.toLowerCase() === "null") return null; + + // Parse as-is first; fall back to assuming a bare host when there's no + // scheme, or when the value parsed as a scheme with no authority (e.g. + // "localhost:3000" parses as scheme "localhost", "data:..."/"javascript:..." + // have no host). The bare-host parse also handles ports and bracketed IPv6. + let url = tryParseUrl(value); + if (!url || url.hostname === "") { + url = tryParseUrl(`https://${value}`); + } + if (!url) return null; + + // Reject userinfo: "https://evil.com@pmux.io/" has hostname "pmux.io", so + // an attacker-controlled value could otherwise impersonate an allowed host. + if (url.username !== "" || url.password !== "") return null; + + return url.hostname ? url.hostname.toLowerCase() : null; +} + +/** + * True if the host exactly equals a listed origin or is a subdomain of one. + */ +export function isHostAllowed(host: string | null, allowed: string[]): boolean { + if (!host) return false; + const lower = host.toLowerCase(); + return allowed.some( + (entry) => lower === entry || lower.endsWith(`.${entry}`), + ); +} diff --git a/packages/server/app/load-context.ts b/packages/server/app/load-context.ts index 6e7fa07a..90d4cdc5 100644 --- a/packages/server/app/load-context.ts +++ b/packages/server/app/load-context.ts @@ -25,6 +25,7 @@ export const getLoadContext: GetLoadContext = ({ context }) => { const analyticsEngine = new AnalyticsEngineAPI( context.cloudflare.env.CF_ACCOUNT_ID, context.cloudflare.env.CF_BEARER_TOKEN, + context.cloudflare.env.CF_DATASET_NAME, ); return { diff --git a/packages/server/app/routes/$script.ts b/packages/server/app/routes/$script.ts index 997e24c5..91cee274 100644 --- a/packages/server/app/routes/$script.ts +++ b/packages/server/app/routes/$script.ts @@ -1,19 +1,18 @@ import type { LoaderFunctionArgs } from "react-router"; -export async function loader({ params, context, request }: LoaderFunctionArgs) { - const requestedScript = params.script; +// Bundled at build time via Vite's ?raw suffix. The source file lives at +// app/tracker/tracker.js and is populated by the `copytracker` npm script. +import trackerSource from "../tracker/tracker.js?raw"; +export async function loader({ params, context }: LoaderFunctionArgs) { + const requestedScript = params.script; if (!requestedScript || !requestedScript.endsWith(".js")) { return new Response("Not Found", { status: 404 }); } const customScriptName = context.cloudflare.env.CF_TRACKER_SCRIPT_NAME; const defaultScriptName = "tracker"; - - // Extract the base name without extension for comparison const requestedBaseName = requestedScript.replace(".js", ""); - - // Check if requested script matches either default or custom name const isDefaultScript = requestedBaseName === defaultScriptName; const isCustomScript = customScriptName && requestedBaseName === customScriptName; @@ -22,12 +21,16 @@ export async function loader({ params, context, request }: LoaderFunctionArgs) { return new Response("Script not found", { status: 404 }); } - try { - const url = new URL(request.url); - const trackerUrl = `${url.protocol}//${url.host}/tracker.js`; - return await context.cloudflare.env.ASSETS.fetch(trackerUrl) - } catch (error) { - console.error("Error serving tracker script:", error); - return new Response("Error serving script", { status: 500 }); - } + // The script is served with a wildcard ACAO. CORS cannot gate who loads a + //