diff --git a/Justfile b/Justfile index 7e8b7fb1..7a7beebf 100644 --- a/Justfile +++ b/Justfile @@ -40,6 +40,9 @@ fmt: check-fuzz: cargo check --manifest-path crates/composefs/fuzz/Cargo.toml +# Run unit + non-privileged integration tests (no VM, no root) +test-all: test test-integration + # Run all checks (clippy + fmt + test + fuzz build) check: clippy check-feature-combos fmt-check test check-fuzz diff --git a/crates/composefs-ctl/src/lib.rs b/crates/composefs-ctl/src/lib.rs index 91a73efc..49db3cb2 100644 --- a/crates/composefs-ctl/src/lib.rs +++ b/crates/composefs-ctl/src/lib.rs @@ -22,8 +22,12 @@ pub use composefs_http; #[cfg(feature = "oci")] pub use composefs_oci; +#[cfg(any(feature = "oci", feature = "http"))] +use std::collections::HashMap; use std::io::Read; use std::path::Path; +#[cfg(any(feature = "oci", feature = "http"))] +use std::sync::Mutex; use std::{ffi::OsString, path::PathBuf}; #[cfg(feature = "oci")] @@ -35,9 +39,15 @@ use anyhow::{Context as _, Result}; use clap::{Parser, Subcommand, ValueEnum}; #[cfg(feature = "oci")] use comfy_table::{Table, presets::UTF8_FULL}; +#[cfg(any(feature = "oci", feature = "http"))] +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use rustix::fs::{CWD, Mode, OFlags}; use serde::Serialize; +#[cfg(any(feature = "oci", feature = "http"))] +use composefs::progress::{ + ComponentId, ProgressEvent, ProgressReporter, ProgressUnit, SharedReporter, +}; use composefs_boot::BootOps; #[cfg(feature = "oci")] use composefs_boot::write_boot; @@ -53,6 +63,94 @@ use composefs::{ tree::RegularFile, }; +/// An `indicatif`-backed [`ProgressReporter`] for use in the CLI. +/// +/// Renders per-component progress bars via [`MultiProgress`]. When a component +/// completes or is skipped the bar is removed; human-readable messages are +/// printed above the bar group via [`MultiProgress::println`]. +#[cfg(any(feature = "oci", feature = "http"))] +struct IndicatifReporter { + multi: MultiProgress, + bars: Mutex>, +} + +#[cfg(any(feature = "oci", feature = "http"))] +impl IndicatifReporter { + fn new() -> Self { + IndicatifReporter { + multi: MultiProgress::new(), + bars: Mutex::new(HashMap::new()), + } + } + + /// Build a shared reporter from this instance. + fn into_shared(self) -> SharedReporter { + Arc::new(self) + } +} + +#[cfg(any(feature = "oci", feature = "http"))] +impl std::fmt::Debug for IndicatifReporter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("IndicatifReporter").finish_non_exhaustive() + } +} + +#[cfg(any(feature = "oci", feature = "http"))] +impl ProgressReporter for IndicatifReporter { + fn report(&self, event: ProgressEvent) { + match event { + ProgressEvent::Started { id, total, unit } => { + let bar = if let Some(total) = total { + self.multi.add(ProgressBar::new(total)) + } else { + self.multi.add(ProgressBar::new_spinner()) + }; + let style = match unit { + ProgressUnit::Bytes => ProgressStyle::with_template( + "[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}", + ), + ProgressUnit::Items => ProgressStyle::with_template( + "[eta {eta}] {bar:40.cyan/blue} {pos:>7}/{len:7} objects {msg}", + ), + // Future unit variants fall back to a generic spinner. + _ => ProgressStyle::with_template( + "[eta {eta}] {bar:40.cyan/blue} {pos}/{len} {msg}", + ), + }; + bar.set_style( + style + .unwrap_or_else(|_| ProgressStyle::default_bar()) + .progress_chars("##-"), + ); + bar.set_message(id.to_string()); + self.bars.lock().unwrap().insert(id, bar); + } + ProgressEvent::Progress { id, fetched, .. } => { + if let Some(bar) = self.bars.lock().unwrap().get(&id) { + bar.set_position(fetched); + } + } + ProgressEvent::Done { id, .. } => { + if let Some(bar) = self.bars.lock().unwrap().remove(&id) { + bar.finish_and_clear(); + } + } + ProgressEvent::Skipped { id } => { + if let Some(bar) = self.bars.lock().unwrap().remove(&id) { + bar.finish_with_message("skipped"); + } + } + ProgressEvent::Message(msg) => { + let _ = self.multi.println(msg); + } + // `ProgressEvent` is #[non_exhaustive]: new variants added to the library + // will be silently ignored here until cfsctl is updated to handle them. + _ => {} + } + } +} + /// JSON output wrapper for `cfsctl fsck --json`. #[derive(Serialize)] struct FsckJsonOutput { @@ -986,8 +1084,10 @@ where // If no explicit name provided, use the image reference as the tag let tag_name = name.as_deref().unwrap_or(image); + let reporter: SharedReporter = IndicatifReporter::new().into_shared(); let opts = composefs_oci::PullOptions { local_fetch: local_fetch.into(), + progress: Some(reporter), ..Default::default() }; @@ -1244,10 +1344,158 @@ where } #[cfg(feature = "http")] Command::Fetch { url, name } => { - let (digest, verity) = composefs_http::download(&url, &name, Arc::clone(&repo)).await?; + let reporter: SharedReporter = IndicatifReporter::new().into_shared(); + let (digest, verity) = composefs_http::download( + &url, + &name, + Arc::clone(&repo), + composefs_http::DownloadOptions { + progress: Some(reporter), + }, + ) + .await?; println!("content {digest}"); println!("verity {}", verity.to_hex()); } } Ok(()) } + +#[cfg(test)] +#[cfg(any(feature = "oci", feature = "http"))] +mod tests { + use super::*; + use composefs::progress::{ProgressEvent, ProgressUnit}; + + // ── IndicatifReporter ──────────────────────────────────────────────────── + + /// A complete valid lifecycle (Started → Progress → Done) must not panic, + /// even without a real terminal (indicatif handles headless gracefully). + #[test] + fn test_indicatif_reporter_valid_lifecycle() { + let reporter = IndicatifReporter::new(); + // Message before any component + reporter.report(ProgressEvent::Message("starting pull".into())); + // Byte-tracked component + reporter.report(ProgressEvent::Started { + id: "sha256:abc".into(), + total: Some(1_000_000), + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Progress { + id: "sha256:abc".into(), + fetched: 500_000, + total: Some(1_000_000), + }); + reporter.report(ProgressEvent::Done { + id: "sha256:abc".into(), + transferred: 1_000_000, + }); + // Item-counted component (HTTP objects) + reporter.report(ProgressEvent::Started { + id: "objects:stream".into(), + total: Some(200), + unit: ProgressUnit::Items, + }); + reporter.report(ProgressEvent::Progress { + id: "objects:stream".into(), + fetched: 100, + total: Some(200), + }); + reporter.report(ProgressEvent::Done { + id: "objects:stream".into(), + transferred: 200, + }); + // Skipped component + reporter.report(ProgressEvent::Started { + id: "sha256:cached".into(), + total: None, + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Skipped { + id: "sha256:cached".into(), + }); + } + + /// Progress/Done events for an ID that was never `Started` must not panic. + /// + /// This guards against error-recovery paths where a `Started` event may + /// have been suppressed or the reporter was attached after the operation + /// began. + #[test] + fn test_indicatif_reporter_unknown_id_no_panic() { + let reporter = IndicatifReporter::new(); + // Progress for unknown ID — should silently ignore + reporter.report(ProgressEvent::Progress { + id: "ghost".into(), + fetched: 42, + total: None, + }); + // Done for unknown ID — should silently ignore + reporter.report(ProgressEvent::Done { + id: "ghost".into(), + transferred: 42, + }); + // Skipped for unknown ID — should silently ignore + reporter.report(ProgressEvent::Skipped { id: "ghost".into() }); + } + + /// A spinner-style bar (unknown total) must not panic. + #[test] + fn test_indicatif_reporter_spinner_lifecycle() { + let reporter = IndicatifReporter::new(); + // Started with unknown total → spinner + reporter.report(ProgressEvent::Started { + id: "layer:unknown-size".into(), + total: None, + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Progress { + id: "layer:unknown-size".into(), + fetched: 1024, + total: None, + }); + reporter.report(ProgressEvent::Done { + id: "layer:unknown-size".into(), + transferred: 2048, + }); + } + + /// Multiple concurrent components must not interfere with each other. + #[test] + fn test_indicatif_reporter_multiple_concurrent_components() { + let reporter = IndicatifReporter::new(); + // Start two layers in parallel + reporter.report(ProgressEvent::Started { + id: "layer:a".into(), + total: Some(100), + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Started { + id: "layer:b".into(), + total: Some(200), + unit: ProgressUnit::Bytes, + }); + // Interleaved progress + reporter.report(ProgressEvent::Progress { + id: "layer:a".into(), + fetched: 50, + total: Some(100), + }); + reporter.report(ProgressEvent::Progress { + id: "layer:b".into(), + fetched: 100, + total: Some(200), + }); + // Layer B finishes first + reporter.report(ProgressEvent::Done { + id: "layer:b".into(), + transferred: 200, + }); + // Layer A finishes + reporter.report(ProgressEvent::Done { + id: "layer:a".into(), + transferred: 100, + }); + } +} diff --git a/crates/composefs-http/Cargo.toml b/crates/composefs-http/Cargo.toml index d6838856..f7048327 100644 --- a/crates/composefs-http/Cargo.toml +++ b/crates/composefs-http/Cargo.toml @@ -15,7 +15,6 @@ anyhow = { version = "1.0.87", default-features = false } bytes = { version = "1.7.1", default-features = false } composefs = { workspace = true } hex = { version = "0.4.0", default-features = false } -indicatif = { version = "0.18.0", default-features = false } reqwest = { version = "0.13.0", features = ["zstd"] } sha2 = { version = "0.11.0", default-features = false } tokio = { version = "1.24.2", default-features = false } diff --git a/crates/composefs-http/src/lib.rs b/crates/composefs-http/src/lib.rs index cced211b..4c880db8 100644 --- a/crates/composefs-http/src/lib.rs +++ b/crates/composefs-http/src/lib.rs @@ -15,19 +15,53 @@ use std::{ use anyhow::{Result, bail}; use bytes::Bytes; use composefs::util::DigestWrite; -use indicatif::{ProgressBar, ProgressStyle}; use reqwest::{Client, Response, Url}; use sha2::{Digest, Sha256}; use tokio::task::JoinSet; +use composefs::progress::{ComponentId, NullReporter, ProgressEvent, ProgressUnit, SharedReporter}; use composefs::{ fsverity::FsVerityHashValue, repository::Repository, splitstream::SplitStreamReader, }; +/// Initial number of concurrent HTTP object fetch requests. +/// +/// Matches the default `SETTINGS_MAX_CONCURRENT_STREAMS` value from RFC 7540 +/// §6.5.2. This bounds the JoinSet backlog while new tasks are queued as +/// existing ones complete. +const INITIAL_CONCURRENT_REQUESTS: usize = 100; + +/// Options for a [`download`] operation. +#[derive(Default)] +pub struct DownloadOptions { + /// Progress reporter for this download operation. + /// + /// When `None`, all progress events are silently discarded. Supply a + /// [`SharedReporter`] implementation (e.g. an `indicatif`-backed renderer) + /// to receive [`ProgressEvent`]s as the download proceeds. + pub progress: Option, +} + +impl std::fmt::Debug for DownloadOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DownloadOptions") + .field( + "progress", + if self.progress.is_some() { + &"Some()" + } else { + &"None" + }, + ) + .finish() + } +} + struct Downloader { client: Client, repo: Arc>, url: Url, + reporter: SharedReporter, } impl Downloader { @@ -70,15 +104,6 @@ impl Downloader { } async fn ensure_stream(self: &Arc, name: &str) -> Result<(String, ObjectID)> { - let progress = ProgressBar::new(2); // the first object gets "ensured" twice - progress.set_style( - ProgressStyle::with_template( - "[eta {eta}] {bar:40.cyan/blue} Fetching {pos} / {len} splitstreams", - ) - .unwrap() - .progress_chars("##-"), - ); - // Ideally we'll get a symlink, but we might get the data directly let (data, is_symlink) = self.fetch("streams/", name).await?; let my_id = if is_symlink { @@ -86,7 +111,10 @@ impl Downloader { } else { self.repo.ensure_object_async(data.into()).await? }; - progress.inc(1); + + self.reporter.report(ProgressEvent::Message(format!( + "Fetching splitstreams for {name}" + ))); let mut objects_todo = HashSet::new(); @@ -99,9 +127,9 @@ impl Downloader { while let Some(id) = splitstreams_todo.pop() { // this is the slow part (downloads, writing to disk, etc.) if self.ensure_object(&id).await? { - progress.inc(1); - } else { - progress.dec_length(1); + self.reporter.report(ProgressEvent::Message(format!( + "Fetched splitstream {id:?}" + ))); } // this part is fast: it only touches the header @@ -111,7 +139,6 @@ impl Downloader { // This is the (normal) case if we encounter a splitstream we didn't see yet... None => { splitstreams_todo.push(verity.clone()); - progress.inc_length(1); } // This is the case where we've already been asked to fetch this stream. We'll @@ -143,25 +170,21 @@ impl Downloader { })?; } - progress.finish(); - - let progress = ProgressBar::new(objects_todo.len() as u64); - progress.set_style( - ProgressStyle::with_template( - "[eta {eta}] {bar:40.cyan/blue} Fetching {pos} / {len} objects", - ) - .unwrap() - .progress_chars("##-"), - ); + let objects_total = objects_todo.len() as u64; + let fetch_id = ComponentId::from(format!("objects:{name}")); + self.reporter.report(ProgressEvent::Started { + id: fetch_id.clone(), + total: Some(objects_total), + unit: ProgressUnit::Items, + }); // Fetch all the objects let mut set = JoinSet::>::new(); let mut iter = objects_todo.into_iter(); + let mut fetched: u64 = 0; - // Queue up 100 initial requests - // See SETTINGS_MAX_CONCURRENT_STREAMS in RFC 7540 - // We might actually want to increase this... - for id in iter.by_ref().take(100) { + // Queue up the initial batch of concurrent requests. + for id in iter.by_ref().take(INITIAL_CONCURRENT_REQUESTS) { let self_ = Arc::clone(self); set.spawn(async move { self_.ensure_object(&id).await }); } @@ -171,10 +194,12 @@ impl Downloader { while let Some(result) = set.join_next().await { if result?? { // a download - progress.inc(1); - } else { - // a not-download - progress.dec_length(1); + fetched += 1; + self.reporter.report(ProgressEvent::Progress { + id: fetch_id.clone(), + fetched, + total: Some(objects_total), + }); } if let Some(id) = iter.next() { @@ -183,18 +208,17 @@ impl Downloader { } } - progress.finish(); + self.reporter.report(ProgressEvent::Done { + id: fetch_id, + transferred: fetched, + }); // Now that we have all of the objects, we can verify that the merged-content of each // splitstream corresponds to its claimed body content checksum, if any... - let progress = ProgressBar::new(splitstreams.len() as u64); - progress.set_style( - ProgressStyle::with_template( - "[eta {eta}] {bar:40.cyan/blue} Verifying {pos} / {len} splitstreams", - ) - .unwrap() - .progress_chars("##-"), - ); + self.reporter.report(ProgressEvent::Message(format!( + "Verifying {} splitstreams", + splitstreams.len() + ))); let mut my_sha256 = None; // TODO: This can definitely happen in parallel... @@ -217,12 +241,8 @@ impl Downloader { if id == my_id { my_sha256 = Some(measured_checksum); } - - progress.inc(1); } - progress.finish(); - // We've definitely set this by now: `my_id` is in `splitstreams`. let my_sha256 = my_sha256.unwrap(); @@ -241,6 +261,7 @@ impl Downloader { /// * `url` - The base HTTP URL where the splitstream repository is hosted /// * `name` - The name of the splitstream to download (located under `streams/` on the server) /// * `repo` - The repository where downloaded objects will be stored +/// * `opts` - Download options including an optional progress reporter /// /// # Returns /// @@ -258,11 +279,15 @@ pub async fn download( url: &str, name: &str, repo: Arc>, + opts: DownloadOptions, ) -> Result<(String, ObjectID)> { + let reporter: SharedReporter = opts.progress.unwrap_or_else(|| Arc::new(NullReporter)); + let downloader = Arc::new(Downloader { client: Client::new(), repo, url: Url::parse(url)?, + reporter, }); downloader.ensure_stream(name).await diff --git a/crates/composefs-integration-tests/src/tests/cstor.rs b/crates/composefs-integration-tests/src/tests/cstor.rs index 0ae5e1b3..554cd79d 100644 --- a/crates/composefs-integration-tests/src/tests/cstor.rs +++ b/crates/composefs-integration-tests/src/tests/cstor.rs @@ -92,8 +92,14 @@ fn privileged_test_cstor_vs_skopeo_equivalence() -> Result<()> { // Import from the OCI directory via skopeo/tar path let skopeo_image_ref = format!("oci:{}:test", oci_path.display()); println!("Importing via skopeo/OCI: {}", skopeo_image_ref); - let (skopeo_pull_result, _skopeo_stats) = - composefs_oci::pull_image(&skopeo_repo, &skopeo_image_ref, None, None).await?; + let (skopeo_pull_result, _skopeo_stats) = composefs_oci::pull_image( + &skopeo_repo, + &skopeo_image_ref, + None, + None, + std::sync::Arc::new(composefs_oci::NullReporter), + ) + .await?; let (skopeo_config_digest, skopeo_config_verity) = skopeo_pull_result.into_config(); // Get layer maps from both configs diff --git a/crates/composefs-oci/Cargo.toml b/crates/composefs-oci/Cargo.toml index c721295f..0bdc4e3b 100644 --- a/crates/composefs-oci/Cargo.toml +++ b/crates/composefs-oci/Cargo.toml @@ -27,7 +27,7 @@ composefs-boot = { workspace = true, optional = true } containers-image-proxy = { version = "0.10", default-features = false } cstorage = { package = "composefs-storage", path = "../composefs-storage", version = "0.4.0", optional = true } hex = { version = "0.4.0", default-features = false } -indicatif = { version = "0.18.0", default-features = false, features = ["tokio"] } +indicatif = { version = "0.17.0", default-features = false } rustix = { version = "1.0.0", features = ["fs"] } serde = { version = "1.0", default-features = false, features = ["derive"] } thiserror = { version = "2.0.0", default-features = false } diff --git a/crates/composefs-oci/src/cstor.rs b/crates/composefs-oci/src/cstor.rs index 00cdc768..18eb759a 100644 --- a/crates/composefs-oci/src/cstor.rs +++ b/crates/composefs-oci/src/cstor.rs @@ -44,7 +44,6 @@ use std::sync::Arc; use anyhow::{Context, Result}; use base64::Engine; -use indicatif::{ProgressBar, ProgressStyle}; use composefs::{ INLINE_CONTENT_MAX_V0, @@ -61,6 +60,7 @@ use cstorage::{ pub use cstorage::init_if_helper; use crate::oci_image::manifest_identifier; +use crate::progress::{ComponentId, ProgressEvent, ProgressUnit, SharedReporter}; use crate::skopeo::{OCI_CONFIG_CONTENT_TYPE, OCI_MANIFEST_CONTENT_TYPE, TAR_LAYER_CONTENT_TYPE}; use crate::{ContentAndVerity, ImportStats, OciDigest, config_identifier, layer_identifier}; @@ -98,6 +98,7 @@ pub async fn import_from_containers_storage( zerocopy: bool, storage_root: Option<&std::path::Path>, additional_image_stores: &[&std::path::Path], + reporter: SharedReporter, ) -> Result<(CstorImportResult, ImportStats)> { // Check if we can access files directly or need a proxy if can_bypass_file_permissions() { @@ -119,6 +120,7 @@ pub async fn import_from_containers_storage( zerocopy, storage_root.as_deref(), &additional_image_stores, + reporter, ) }) .await @@ -132,7 +134,7 @@ pub async fn import_from_containers_storage( "storage_root and additional_image_stores are not supported in rootless mode" ); } - import_from_containers_storage_proxied(repo, image_id, reference, zerocopy).await + import_from_containers_storage_proxied(repo, image_id, reference, zerocopy, reporter).await } } @@ -147,6 +149,7 @@ fn import_from_containers_storage_direct( zerocopy: bool, storage_root: Option<&std::path::Path>, additional_image_stores: &[std::path::PathBuf], + reporter: SharedReporter, ) -> Result<(CstorImportResult, ImportStats)> { let mut stats = ImportStats::default(); let mut ctx = ImportContext::default(); @@ -218,43 +221,41 @@ fn import_from_containers_storage_direct( stats.layers = storage_layer_ids.len() as u64; - // Import each layer with progress bar - let progress = ProgressBar::new(storage_layer_ids.len() as u64); - progress.set_style( - ProgressStyle::default_bar() - .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}") - .expect("valid template") - .progress_chars("=>-"), - ); - let mut layer_refs = Vec::with_capacity(storage_layer_ids.len()); for (storage_layer_id, diff_id) in storage_layer_ids.iter().zip(diff_ids.iter()) { let content_id = layer_identifier(diff_id); - let diff_id_str: &str = diff_id.as_ref(); - let short_id = diff_id_str.get(..19).unwrap_or(diff_id_str); + let id = ComponentId::from(diff_id.to_string()); let layer_verity = if let Some(existing) = repo.has_stream(&content_id)? { - progress.set_message(format!("Already have {short_id}...")); + reporter.report(ProgressEvent::Skipped { id }); stats.layers_already_present += 1; existing } else { - progress.set_message(format!("Importing {short_id}...")); + reporter.report(ProgressEvent::Started { + id: id.clone(), + total: None, + unit: ProgressUnit::Bytes, + }); let (layer_store, layer) = stores .iter() .find_map(|s| Layer::open(s, storage_layer_id).ok().map(|l| (s, l))) .with_context(|| format!("Failed to open layer {}", storage_layer_id))?; let (verity, layer_stats) = import_layer_direct(repo, layer_store, &layer, diff_id, zerocopy, &mut ctx)?; + let bytes = layer_stats.new_bytes(); stats.merge(&layer_stats); + reporter.report(ProgressEvent::Done { + id, + transferred: bytes, + }); verity }; layer_refs.push((diff_id.clone(), layer_verity)); - progress.inc(1); } - progress.finish_with_message("Layers imported"); - finalize_import(repo, &image, &layer_refs, reference, &progress, stats) + reporter.report(ProgressEvent::Message("Layers imported".to_string())); + finalize_import(repo, &image, &layer_refs, reference, &reporter, stats) } /// Proxied (rootless) implementation of containers-storage import. @@ -266,6 +267,7 @@ async fn import_from_containers_storage_proxied( image_id: &str, reference: Option<&str>, zerocopy: bool, + reporter: SharedReporter, ) -> Result<(CstorImportResult, ImportStats)> { let mut stats = ImportStats::default(); let mut ctx = ImportContext::default(); @@ -306,15 +308,6 @@ async fn import_from_containers_storage_proxied( stats.layers = image_info.storage_layer_ids.len() as u64; - // Import each layer with progress bar - let progress = ProgressBar::new(image_info.storage_layer_ids.len() as u64); - progress.set_style( - ProgressStyle::default_bar() - .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}") - .expect("valid template") - .progress_chars("=>-"), - ); - let mut layer_refs = Vec::with_capacity(image_info.storage_layer_ids.len()); for (storage_layer_id, diff_id) in image_info @@ -323,15 +316,18 @@ async fn import_from_containers_storage_proxied( .zip(image_info.layer_diff_ids.iter()) { let content_id = layer_identifier(diff_id); - let diff_id_str: &str = diff_id.as_ref(); - let short_id = diff_id_str.get(..19).unwrap_or(diff_id_str); + let id = ComponentId::from(diff_id.to_string()); let layer_verity = if let Some(existing) = repo.has_stream(&content_id)? { - progress.set_message(format!("Already have {short_id}...")); + reporter.report(ProgressEvent::Skipped { id }); stats.layers_already_present += 1; existing } else { - progress.set_message(format!("Importing {short_id}...")); + reporter.report(ProgressEvent::Started { + id: id.clone(), + total: None, + unit: ProgressUnit::Bytes, + }); let (verity, layer_stats) = import_layer_proxied( repo, &mut proxy, @@ -342,14 +338,19 @@ async fn import_from_containers_storage_proxied( &mut ctx, ) .await?; + let bytes = layer_stats.new_bytes(); stats.merge(&layer_stats); + reporter.report(ProgressEvent::Done { + id, + transferred: bytes, + }); verity }; layer_refs.push((diff_id.clone(), layer_verity)); - progress.inc(1); } - progress.finish_with_message("Layers imported"); + + reporter.report(ProgressEvent::Message("Layers imported".to_string())); // Config and manifest metadata don't have restrictive file permissions, // so we can read them directly without the proxy. @@ -362,7 +363,7 @@ async fn import_from_containers_storage_proxied( // Shutdown the proxy before the blocking finalization proxy.shutdown().await.context("Failed to shutdown proxy")?; - finalize_import(repo, &image, &layer_refs, reference, &progress, stats) + finalize_import(repo, &image, &layer_refs, reference, &reporter, stats) } /// Create config + manifest splitstreams, generate the EROFS image, and tag. @@ -378,7 +379,7 @@ fn finalize_import( image: &Image, layer_refs: &[(OciDigest, ObjectID)], reference: Option<&str>, - progress: &ProgressBar, + reporter: &SharedReporter, stats: ImportStats, ) -> Result<(CstorImportResult, ImportStats)> { // Read the raw config JSON bytes from metadata @@ -391,10 +392,14 @@ fn finalize_import( let content_id = config_identifier(&config_digest); let config_verity = if let Some(existing) = repo.has_stream(&content_id)? { - progress.println(format!("Already have config {config_digest}")); + reporter.report(ProgressEvent::Message(format!( + "Already have config {config_digest}" + ))); existing } else { - progress.println(format!("Creating config splitstream {config_digest}")); + reporter.report(ProgressEvent::Message(format!( + "Creating config splitstream {config_digest}" + ))); let mut writer = repo.create_stream(OCI_CONFIG_CONTENT_TYPE)?; for (diff_id, verity) in layer_refs { @@ -414,10 +419,14 @@ fn finalize_import( let manifest_content_id = manifest_identifier(&manifest_digest); let manifest_verity = if let Some(existing) = repo.has_stream(&manifest_content_id)? { - progress.println(format!("Already have manifest {manifest_digest}")); + reporter.report(ProgressEvent::Message(format!( + "Already have manifest {manifest_digest}" + ))); existing } else { - progress.println(format!("Creating manifest splitstream {manifest_digest}")); + reporter.report(ProgressEvent::Message(format!( + "Creating manifest splitstream {manifest_digest}" + ))); let mut writer = repo.create_stream(OCI_MANIFEST_CONTENT_TYPE)?; let config_ref_key = format!("config:{config_digest}"); diff --git a/crates/composefs-oci/src/lib.rs b/crates/composefs-oci/src/lib.rs index f91cfedd..10aab08c 100644 --- a/crates/composefs-oci/src/lib.rs +++ b/crates/composefs-oci/src/lib.rs @@ -19,6 +19,8 @@ pub mod image; pub mod layer; pub mod oci_image; pub mod oci_layout; +/// Re-exported from [`composefs::progress`]; use that path directly in new code. +pub mod progress; pub mod skopeo; pub mod tar; @@ -70,6 +72,7 @@ pub use oci_image::{ oci_fsck, oci_fsck_image, remove_referrer, remove_referrers_for_subject, resolve_ref, tag_image, untag_image, }; +pub use progress::{ComponentId, NullReporter, ProgressEvent, ProgressReporter, SharedReporter}; pub use skopeo::pull_image; /// Statistics from an image import operation. @@ -227,7 +230,7 @@ pub enum LocalFetchOpt { /// /// Use `Default::default()` for the common case (skopeo transport, no /// containers-storage import). -#[derive(Debug, Default)] +#[derive(Default)] pub struct PullOptions<'a> { /// Image proxy configuration passed to skopeo (ignored for /// `containers-storage:` references when `local_fetch` is not @@ -248,6 +251,32 @@ pub struct PullOptions<'a> { /// `additionalimagestore=` option in containers/storage. /// Only relevant when `local_fetch` is not [`Disabled`](LocalFetchOpt::Disabled). pub additional_image_stores: &'a [&'a std::path::Path], + + /// Progress reporter for this pull operation. + /// + /// When `None`, all progress events are silently discarded. Supply a + /// [`SharedReporter`] implementation (e.g. an `indicatif`-backed renderer) + /// to receive [`ProgressEvent`]s as the pull proceeds. + pub progress: Option, +} + +impl<'a> std::fmt::Debug for PullOptions<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PullOptions") + .field("img_proxy_config", &self.img_proxy_config) + .field("local_fetch", &self.local_fetch) + .field("storage_root", &self.storage_root) + .field("additional_image_stores", &self.additional_image_stores) + .field( + "progress", + if self.progress.is_some() { + &"Some()" + } else { + &"None" + }, + ) + .finish() + } } /// Result of a pull operation. @@ -367,6 +396,10 @@ pub async fn pull( reference: Option<&str>, opts: PullOptions<'_>, ) -> Result> { + let reporter: SharedReporter = opts + .progress + .unwrap_or_else(|| std::sync::Arc::new(NullReporter)); + #[cfg(feature = "containers-storage")] if opts.local_fetch != LocalFetchOpt::Disabled && let Some(image_id) = cstor::parse_containers_storage_ref(imgref) @@ -380,6 +413,7 @@ pub async fn pull( zerocopy, opts.storage_root, opts.additional_image_stores, + reporter, ) .await?; return Ok(PullResult { @@ -392,7 +426,7 @@ pub async fn pull( } let (result, stats) = - skopeo::pull_image(repo, imgref, reference, opts.img_proxy_config).await?; + skopeo::pull_image(repo, imgref, reference, opts.img_proxy_config, reporter).await?; Ok(crate::PullResult { manifest_digest: result.manifest_digest, manifest_verity: result.manifest_verity, @@ -2037,4 +2071,204 @@ mod test { "EROFS should contain hostname" ); } + + // ── Progress API integration tests ─────────────────────────────────────── + + /// Create a minimal OCI layout directory with one (empty) tar layer. + /// + /// Returns the path to the OCI layout directory. The image is pinned to + /// the current host platform so `import_oci_layout` can resolve it. + /// + /// The layer is an empty tar archive (valid tar, zero entries), which is + /// sufficient to exercise the `import_layer_from_file` progress path. + fn make_test_oci_layout(parent: &std::path::Path) -> std::path::PathBuf { + use cap_std_ext::cap_std; + use containers_image_proxy::oci_spec::image::{ + Arch, ConfigBuilder, ImageConfigurationBuilder, Os, PlatformBuilder, RootFsBuilder, + }; + use ocidir::OciDir; + + let oci_dir = parent.join("oci-layout"); + std::fs::create_dir_all(&oci_dir).unwrap(); + let dir = + cap_std::fs::Dir::open_ambient_dir(&oci_dir, cap_std::ambient_authority()).unwrap(); + let ocidir = OciDir::ensure(dir).unwrap(); + + let mut manifest = ocidir.new_empty_manifest().unwrap().build().unwrap(); + let mut config = ImageConfigurationBuilder::default() + .architecture(Arch::default()) + .os(Os::default()) + .rootfs( + RootFsBuilder::default() + .typ("layers") + .diff_ids(Vec::::new()) + .build() + .unwrap(), + ) + .config(ConfigBuilder::default().build().unwrap()) + .build() + .unwrap(); + + // Create an empty tar layer (finish the builder immediately without adding any entries) + let layer = ocidir + .create_layer(None) + .unwrap() + .into_inner() + .unwrap() + .complete() + .unwrap(); + ocidir.push_layer(&mut manifest, &mut config, layer, "layer", None); + + let platform = PlatformBuilder::default() + .architecture(Arch::default()) + .os(Os::default()) + .build() + .unwrap(); + ocidir + .insert_manifest_and_config(manifest, config, None, platform) + .unwrap(); + + oci_dir + } + + /// Pulling a fresh OCI layout image (no prior cache) must emit at least one + /// `Started` event per layer and a matching `Done` event, via the + /// `import_oci_layout` fast path. + /// + /// This is the primary integration test for the progress API: it verifies + /// that the oci_layout fast path actually emits events (previously it + /// emitted none). + #[tokio::test] + async fn test_oci_layout_pull_emits_started_and_done() { + use crate::oci_layout::import_oci_layout; + use crate::progress::ProgressEvent; + use crate::progress::test_support::RecordingReporter; + use composefs::fsverity::Sha256HashValue; + use composefs::test::TestRepo; + + let layout_dir = tempfile::tempdir().unwrap(); + let layout_path = make_test_oci_layout(layout_dir.path()); + + let test_repo = TestRepo::::new(); + let repo = &test_repo.repo; + let recorder = std::sync::Arc::new(RecordingReporter::new()); + let reporter: crate::progress::SharedReporter = + std::sync::Arc::clone(&recorder) as crate::progress::SharedReporter; + + import_oci_layout(repo, &layout_path, None, reporter) + .await + .expect("import_oci_layout should succeed"); + + let events = recorder.events(); + + // There must be at least one Started event + let started_count = events + .iter() + .filter(|e| matches!(e, ProgressEvent::Started { .. })) + .count(); + assert!( + started_count >= 1, + "expected at least one Started event, got {started_count} (total events: {})", + events.len() + ); + + // Every Started must have a matching Done or Skipped + let started_ids: std::collections::HashSet = events + .iter() + .filter_map(|e| { + if let ProgressEvent::Started { id, .. } = e { + Some(id.as_str().to_owned()) + } else { + None + } + }) + .collect(); + for started_id in &started_ids { + let has_terminal = events.iter().any(|e| match e { + ProgressEvent::Done { id, .. } | ProgressEvent::Skipped { id } => { + id.as_str() == started_id + } + _ => false, + }); + assert!( + has_terminal, + "Started for '{started_id}' has no matching Done or Skipped" + ); + } + } + + /// Re-importing the same OCI layout (layers already cached) must emit + /// `Skipped` events rather than `Started`/`Done`. + #[tokio::test] + async fn test_oci_layout_reimport_emits_skipped() { + use crate::oci_layout::import_oci_layout; + use crate::progress::test_support::RecordingReporter; + use crate::progress::{NullReporter, ProgressEvent}; + use composefs::fsverity::Sha256HashValue; + use composefs::test::TestRepo; + + let layout_dir = tempfile::tempdir().unwrap(); + let layout_path = make_test_oci_layout(layout_dir.path()); + + let test_repo = TestRepo::::new(); + let repo = &test_repo.repo; + + // First import (populates cache) + let null: crate::progress::SharedReporter = std::sync::Arc::new(NullReporter); + import_oci_layout(repo, &layout_path, None, null) + .await + .expect("first import should succeed"); + + // Second import (everything already cached) + let recorder = std::sync::Arc::new(RecordingReporter::new()); + let reporter: crate::progress::SharedReporter = + std::sync::Arc::clone(&recorder) as crate::progress::SharedReporter; + import_oci_layout(repo, &layout_path, None, reporter) + .await + .expect("second import should succeed"); + + let events = recorder.events(); + + // On reimport, layers are cached: expect Skipped, not Done + let done_count = events + .iter() + .filter(|e| matches!(e, ProgressEvent::Done { .. })) + .count(); + let skipped_count = events + .iter() + .filter(|e| matches!(e, ProgressEvent::Skipped { .. })) + .count(); + assert_eq!( + done_count, 0, + "no Done events expected on reimport (layers cached), got {done_count}" + ); + assert!( + skipped_count >= 1, + "expected at least one Skipped on reimport, got {skipped_count}" + ); + } + + /// The `import_oci_layout` function with `NullReporter` (via `SharedReporter` + /// wrapping `NullReporter`) must not panic now that it uses the reporter internally. + /// + /// This verifies the zero-overhead default path still works correctly. + #[tokio::test] + async fn test_import_oci_layout_with_null_reporter_does_not_panic() { + use crate::oci_layout::import_oci_layout; + use crate::progress::NullReporter; + use composefs::fsverity::Sha256HashValue; + use composefs::test::TestRepo; + + let layout_dir = tempfile::tempdir().unwrap(); + let layout_path = make_test_oci_layout(layout_dir.path()); + + let test_repo = TestRepo::::new(); + let repo = &test_repo.repo; + + // NullReporter: zero overhead, no events collected + let reporter: crate::progress::SharedReporter = std::sync::Arc::new(NullReporter); + import_oci_layout(repo, &layout_path, None, reporter) + .await + .expect("import_oci_layout with NullReporter should not panic"); + } } diff --git a/crates/composefs-oci/src/oci_layout.rs b/crates/composefs-oci/src/oci_layout.rs index 08bb72f6..2c831cf0 100644 --- a/crates/composefs-oci/src/oci_layout.rs +++ b/crates/composefs-oci/src/oci_layout.rs @@ -34,6 +34,7 @@ use composefs::repository::{ObjectStoreMethod, Repository}; use crate::layer::{decompress_async, import_tar_async, is_tar_media_type, store_blob_async}; use crate::oci_image::manifest_identifier; +use crate::progress::{ComponentId, ProgressEvent, ProgressRead, ProgressUnit, SharedReporter}; use crate::skopeo::OCI_BLOB_CONTENT_TYPE; use crate::skopeo::{OCI_CONFIG_CONTENT_TYPE, OCI_MANIFEST_CONTENT_TYPE}; use crate::{ImportStats, config_identifier, layer_identifier}; @@ -71,12 +72,15 @@ fn resolve_manifest(ocidir: &OciDir, tag: Option<&str>) -> Result( repo: &Arc>, layout_path: &Path, layout_tag: Option<&str>, + reporter: SharedReporter, ) -> Result<(PullResult, ImportStats)> { // Open the OCI layout directory let dir = cap_std::fs::Dir::open_ambient_dir(layout_path, cap_std::ambient_authority()) @@ -93,11 +97,17 @@ pub async fn import_oci_layout( // Import config and layers let config_descriptor = manifest.config(); let layers = manifest.layers(); + reporter.report(ProgressEvent::Message(format!( + "Importing {} layers from OCI layout", + layers.len() + ))); let (config_digest, config_verity, layer_refs, stats) = - import_config_and_layers(repo, &ocidir, layers, config_descriptor) + import_config_and_layers(repo, &ocidir, layers, config_descriptor, &reporter) .await .with_context(|| format!("Failed to import config {}", config_descriptor.digest()))?; + reporter.report(ProgressEvent::Message("Storing manifest".to_string())); + // Store the manifest let manifest_content_id = manifest_identifier(&manifest_digest); let manifest_verity = if let Some(verity) = repo.has_stream(&manifest_content_id)? { @@ -146,6 +156,7 @@ async fn import_config_and_layers( ocidir: &OciDir, manifest_layers: &[Descriptor], config_descriptor: &Descriptor, + reporter: &SharedReporter, ) -> Result<(OciDigest, ObjectID, Vec<(OciDigest, ObjectID)>, ImportStats)> { let config_digest: OciDigest = config_descriptor.digest().clone(); let content_id = config_identifier(&config_digest); @@ -187,6 +198,13 @@ async fn import_config_and_layers( layer_refs.len() ); + // Emit Skipped for each cached layer so callers can close any open progress bars + for (diff_id, _) in &layer_refs { + reporter.report(ProgressEvent::Skipped { + id: ComponentId::from(diff_id.to_string()), + }); + } + return Ok((config_digest, config_id, layer_refs, ImportStats::default())); } @@ -216,17 +234,26 @@ async fn import_config_and_layers( let diff_id = (*diff_id).clone(); let repo = Arc::clone(repo); let permit = Arc::clone(&sem).acquire_owned().await?; + let reporter = Arc::clone(reporter); let layer_file = ocidir .read_blob(descriptor) .with_context(|| format!("Opening layer blob {}", descriptor.digest()))?; let media_type = descriptor.media_type().clone(); + let layer_size = descriptor.size(); layer_tasks.spawn(async move { let _permit = permit; - let (verity, layer_stats) = - import_layer_from_file(&repo, &diff_id, layer_file, &media_type).await?; + let (verity, layer_stats) = import_layer_from_file( + &repo, + &diff_id, + layer_file, + &media_type, + layer_size, + &reporter, + ) + .await?; anyhow::Ok((idx, diff_id, verity, layer_stats)) }); } @@ -270,30 +297,55 @@ async fn import_config_and_layers( } /// Import a single layer by streaming from a file handle. +/// +/// Emits `Started`/`Done` (or `Skipped`) progress events via `reporter`. async fn import_layer_from_file( repo: &Arc>, diff_id: &OciDigest, layer_file: std::fs::File, media_type: &MediaType, + layer_size: u64, + reporter: &SharedReporter, ) -> Result<(ObjectID, ImportStats)> { let content_id = layer_identifier(diff_id); + let id = ComponentId::from(diff_id.to_string()); if let Some(layer_id) = repo.has_stream(&content_id)? { debug!("Already have layer {diff_id}"); + reporter.report(ProgressEvent::Skipped { id }); return Ok((layer_id, ImportStats::default())); } debug!("Importing layer {diff_id}"); - - // Convert std::fs::File to tokio::fs::File for async I/O - let async_file = tokio::fs::File::from_std(layer_file); + reporter.report(ProgressEvent::Started { + id: id.clone(), + total: Some(layer_size), + unit: ProgressUnit::Bytes, + }); + + // Wrap the file reader to emit Progress events as compressed bytes are read. + // This sits before decompression so `fetched` tracks bytes-on-disk, + // matching the `total` from the descriptor size above. + // + // The watch channel provides backpressure: if the renderer is slow, intermediate + // byte counts are coalesced rather than queued, keeping the I/O path non-blocking. + let (async_file, progress_driver) = ProgressRead::new( + tokio::fs::File::from_std(layer_file), + Arc::clone(reporter), + id.clone(), + Some(layer_size), + ); let (object_id, layer_stats) = if is_tar_media_type(media_type) { + // Run the progress driver concurrently with the import. let reader = decompress_async(async_file, media_type)?; - import_tar_async(repo.clone(), reader).await? + let (result, ()) = tokio::join!(import_tar_async(repo.clone(), reader), progress_driver); + result? } else { - // Non-tar blob: store as object and create splitstream wrapper - let (object_id, size, method) = store_blob_async(repo, async_file).await?; + // Non-tar blob: store as object and create splitstream wrapper. + // Run the progress driver concurrently with the blob store. + let (store_result, ()) = tokio::join!(store_blob_async(repo, async_file), progress_driver); + let (object_id, size, method) = store_result?; let mut stats = ImportStats::default(); match method { @@ -318,18 +370,27 @@ async fn import_layer_from_file( stream.add_external_size(size); stream.write_reference(object_id)?; let stream_id = repo.write_stream(stream, &content_id, None)?; + reporter.report(ProgressEvent::Done { + id, + transferred: size, + }); return Ok((stream_id, stats)); }; // Register the stream with its content identifier repo.register_stream(&object_id, &content_id, None).await?; + reporter.report(ProgressEvent::Done { + id, + transferred: layer_size, + }); Ok((object_id, layer_stats)) } #[cfg(test)] mod tests { use super::*; + use crate::progress::NullReporter; #[test] fn test_parse_oci_layout_ref() { @@ -417,7 +478,8 @@ mod tests { .unwrap(); let repo = std::sync::Arc::new(repo); - let result = import_oci_layout(&repo, layout_path, None).await; + let reporter = std::sync::Arc::new(NullReporter); + let result = import_oci_layout(&repo, layout_path, None, reporter).await; let err = result.expect_err("should fail with no matching platform"); let err_msg = format!("{err:#}"); assert!( diff --git a/crates/composefs-oci/src/progress.rs b/crates/composefs-oci/src/progress.rs new file mode 100644 index 00000000..3fc618ea --- /dev/null +++ b/crates/composefs-oci/src/progress.rs @@ -0,0 +1,6 @@ +// Progress types now live in the core `composefs` crate. +// Re-export everything from there so existing code keeps compiling while +// callers migrate their imports. +#[cfg(any(test, feature = "test"))] +pub use composefs::progress::test_support; +pub use composefs::progress::*; diff --git a/crates/composefs-oci/src/skopeo.rs b/crates/composefs-oci/src/skopeo.rs index f4b93d4b..bd5a7269 100644 --- a/crates/composefs-oci/src/skopeo.rs +++ b/crates/composefs-oci/src/skopeo.rs @@ -18,7 +18,6 @@ use containers_image_proxy::{ ConvertedLayerInfo, ImageProxy, ImageProxyConfig, ImageReference, OpenedImage, Transport, }; use fn_error_context::context; -use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use rustix::process::geteuid; use tokio::{io::AsyncReadExt, sync::Semaphore, task::JoinSet}; @@ -33,6 +32,7 @@ use crate::{ layer::{decompress_async, import_tar_async, is_tar_media_type, store_blob_async}, layer_identifier, oci_image::{manifest_identifier, tag_image}, + progress::{ComponentId, ProgressEvent, ProgressRead, ProgressUnit, SharedReporter}, }; /// Result of pulling an OCI image. @@ -75,7 +75,7 @@ struct ImageOp { repo: Arc>, proxy: ImageProxy, img: OpenedImage, - progress: MultiProgress, + reporter: SharedReporter, transport: Transport, } @@ -84,6 +84,7 @@ impl ImageOp { repo: &Arc>, image_ref: &ImageReference, img_proxy_config: Option, + reporter: SharedReporter, ) -> Result { // Fail fast if the repository is not writable, before starting // the image proxy or doing any network I/O. @@ -142,12 +143,11 @@ impl ImageOp { .open_image_ref(image_ref) .await .context("Opening image")?; - let progress = MultiProgress::new(); Ok(ImageOp { repo: Arc::clone(repo), proxy, img, - progress, + reporter, transport, }) } @@ -165,8 +165,9 @@ impl ImageOp { let content_id = layer_identifier(diff_id); if let Some(layer_id) = self.repo.has_stream(&content_id)? { - self.progress - .println(format!("Already have layer {diff_id}"))?; + self.reporter.report(ProgressEvent::Skipped { + id: ComponentId::from(diff_id.to_string()), + }); Ok((layer_id, ImportStats::default())) } else { // Otherwise, we need to fetch it... @@ -197,21 +198,40 @@ impl ImageOp { // See https://github.com/containers/containers-image-proxy-rs/issues/71 let blob_reader = blob_reader.take(descriptor.size()); - let bar = self.progress.add(ProgressBar::new(descriptor.size())); - bar.set_style(ProgressStyle::with_template("[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}") - .unwrap() - .progress_chars("##-")); - let progress = bar.wrap_async_read(blob_reader); - self.progress.println(format!("Fetching layer {diff_id}"))?; + let id = ComponentId::from(diff_id.to_string()); + self.reporter.report(ProgressEvent::Started { + id: id.clone(), + total: Some(descriptor.size()), + unit: ProgressUnit::Bytes, + }); + + // Wrap the blob reader to emit Progress events as compressed bytes are read. + // This sits before decompression so `fetched` tracks bytes-over-the-wire, + // matching the `total` from the descriptor size above. + // + // The watch channel provides backpressure: if the renderer is slow, intermediate + // byte counts are coalesced rather than queued, keeping the I/O path non-blocking. + let (blob_reader, progress_driver) = ProgressRead::new( + blob_reader, + Arc::clone(&self.reporter), + id.clone(), + Some(descriptor.size()), + ); let media_type = descriptor.media_type(); let (object_id, layer_stats) = if is_tar_media_type(media_type) { - // Tar layers: decompress and split into a splitstream - let reader = decompress_async(progress, media_type)?; - import_tar_async(self.repo.clone(), reader).await? + // Tar layers: decompress and split into a splitstream. + // Run the progress driver concurrently with the import. + let reader = decompress_async(blob_reader, media_type)?; + let (result, ()) = + tokio::join!(import_tar_async(self.repo.clone(), reader), progress_driver); + result? } else { - // Non-tar layers (OCI artifacts): stream raw bytes to object store - let (object_id, size, method) = store_blob_async(&self.repo, progress).await?; + // Non-tar layers (OCI artifacts): stream raw bytes to object store. + // Run the progress driver concurrently with the blob store. + let (store_result, ()) = + tokio::join!(store_blob_async(&self.repo, blob_reader), progress_driver); + let (object_id, size, method) = store_result?; driver.await?; let mut stats = ImportStats::default(); @@ -237,6 +257,10 @@ impl ImageOp { stream.add_external_size(size); stream.write_reference(object_id)?; let stream_id = self.repo.write_stream(stream, &content_id, None)?; + self.reporter.report(ProgressEvent::Done { + id, + transferred: size, + }); return Ok((stream_id, stats)); }; @@ -249,6 +273,11 @@ impl ImageOp { .register_stream(&object_id, &content_id, None) .await?; + self.reporter.report(ProgressEvent::Done { + id, + transferred: descriptor.size(), + }); + Ok((object_id, layer_stats)) } } @@ -268,8 +297,9 @@ impl ImageOp { if let Some(config_id) = self.repo.has_stream(&content_id)? { // We already got this config - need to read the layer refs and diff_ids from it - self.progress - .println(format!("Already have container config {config_digest}"))?; + self.reporter.report(ProgressEvent::Message(format!( + "Already have container config {config_digest}" + ))); let (data, named_refs) = crate::oci_image::read_external_splitstream( &self.repo, @@ -310,8 +340,9 @@ impl ImageOp { )) } else { // We need to add the config to the repo - self.progress - .println(format!("Fetching config {config_digest}"))?; + self.reporter.report(ProgressEvent::Message(format!( + "Fetching config {config_digest}" + ))); let (mut config, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?; let config = async move { @@ -433,12 +464,14 @@ impl ImageOp { let manifest_content_id = manifest_identifier(&manifest_digest); let manifest_verity = if let Some(verity) = self.repo.has_stream(&manifest_content_id)? { - self.progress - .println(format!("Already have manifest {manifest_digest}"))?; + self.reporter.report(ProgressEvent::Message(format!( + "Already have manifest {manifest_digest}" + ))); verity } else { - self.progress - .println(format!("Storing manifest {manifest_digest}"))?; + self.reporter.report(ProgressEvent::Message(format!( + "Storing manifest {manifest_digest}" + ))); let mut splitstream = self.repo.create_stream(OCI_MANIFEST_CONTENT_TYPE)?; @@ -483,6 +516,7 @@ pub async fn pull_image( imgref: &str, reference: Option<&str>, img_proxy_config: Option, + reporter: SharedReporter, ) -> Result<(PullResult, ImportStats)> { // Fail fast if the repository is not writable, before doing any I/O. repo.ensure_writable()?; @@ -494,10 +528,10 @@ pub async fn pull_image( let (result, stats) = if image_ref.transport == Transport::OciDir { let (path_str, layout_tag) = crate::oci_layout::parse_oci_layout_ref(&image_ref.name); let layout_path = std::path::Path::new(path_str); - crate::oci_layout::import_oci_layout(repo, layout_path, layout_tag).await? + crate::oci_layout::import_oci_layout(repo, layout_path, layout_tag, reporter).await? } else { // Standard path: use skopeo proxy for other transports - let op = Arc::new(ImageOp::new(repo, &image_ref, img_proxy_config).await?); + let op = Arc::new(ImageOp::new(repo, &image_ref, img_proxy_config, reporter).await?); op.pull() .await .with_context(|| format!("Unable to pull container image {imgref}"))? @@ -534,7 +568,8 @@ pub async fn pull( reference: Option<&str>, img_proxy_config: Option, ) -> Result<(OciDigest, ObjectID, ImportStats)> { - let (result, stats) = pull_image(repo, imgref, reference, img_proxy_config).await?; + let reporter = Arc::new(crate::progress::NullReporter); + let (result, stats) = pull_image(repo, imgref, reference, img_proxy_config, reporter).await?; let (config_digest, config_verity) = result.into_config(); Ok((config_digest, config_verity, stats)) } diff --git a/crates/composefs/fuzz/Cargo.lock b/crates/composefs/fuzz/Cargo.lock index e8640e0a..3bbaf91c 100644 --- a/crates/composefs/fuzz/Cargo.lock +++ b/crates/composefs/fuzz/Cargo.lock @@ -66,7 +66,7 @@ dependencies = [ [[package]] name = "composefs" -version = "0.3.0" +version = "0.4.0" dependencies = [ "anyhow", "composefs-ioctls", @@ -97,7 +97,7 @@ dependencies = [ [[package]] name = "composefs-ioctls" -version = "0.3.0" +version = "0.4.0" dependencies = [ "rustix", "thiserror", diff --git a/crates/composefs/src/lib.rs b/crates/composefs/src/lib.rs index 38d55e1f..ebd74307 100644 --- a/crates/composefs/src/lib.rs +++ b/crates/composefs/src/lib.rs @@ -14,6 +14,7 @@ pub mod fs; pub mod fsverity; pub mod mount; pub mod mountcompat; +pub mod progress; pub mod repository; pub mod splitstream; pub mod tree; diff --git a/crates/composefs/src/progress.rs b/crates/composefs/src/progress.rs new file mode 100644 index 00000000..4787105b --- /dev/null +++ b/crates/composefs/src/progress.rs @@ -0,0 +1,565 @@ +//! Progress reporting API for pull and download operations. +//! +//! Library crates emit [`ProgressEvent`]s through a [`ProgressReporter`] trait +//! object. The default implementation, [`NullReporter`], discards all events +//! at zero cost. Callers such as `cfsctl` supply their own implementation +//! (e.g. an `indicatif`-backed renderer) via [`PullOptions::progress`]. + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, ReadBuf}; + +/// Identity of a component being tracked, typically an OCI layer diff_id or +/// an HTTP object path. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ComponentId(String); + +impl ComponentId { + /// Return the underlying string slice. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl> From for ComponentId { + fn from(s: S) -> Self { + ComponentId(s.into()) + } +} + +impl std::fmt::Display for ComponentId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +/// The unit of measurement for a progress component. +/// +/// Progress events may track either raw bytes (for layer downloads) or an +/// abstract item count (for object fetches where individual sizes are unknown). +/// Renderers should adapt their display accordingly. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum ProgressUnit { + /// The `fetched`/`total` fields count bytes. + Bytes, + /// The `fetched`/`total` fields count discrete items (e.g. objects). + Items, +} + +/// Events emitted during a pull or download operation. +#[derive(Debug, Clone, PartialEq)] +#[non_exhaustive] +pub enum ProgressEvent { + /// A new component (layer/object) has started being fetched. + Started { + /// Identifier for this component. + id: ComponentId, + /// Total amount to transfer (bytes or items depending on `unit`), if known. + total: Option, + /// Unit of measurement for `total` and subsequent `Progress` events. + unit: ProgressUnit, + }, + /// Progress update for a component. + Progress { + /// Identifier for this component. + id: ComponentId, + /// Amount transferred so far (bytes or items depending on the `Started` unit). + fetched: u64, + /// Total amount (bytes or items), if known. + total: Option, + }, + /// A component was skipped because it was already present. + /// + /// This event may be emitted without a preceding [`ProgressEvent::Started`] + /// when the component is determined to be cached before any download begins. + /// Renderers must handle this case gracefully. + Skipped { + /// Identifier for the skipped component. + id: ComponentId, + }, + /// A component completed successfully. + Done { + /// Identifier for this component. + id: ComponentId, + /// Amount actually transferred (bytes or items per the `Started` unit). + transferred: u64, + }, + /// A human-readable status message (replaces progress-bar text lines). + Message(String), +} + +/// Receives progress events from a pull or download operation. +/// +/// Implementations must be `Send + Sync` so they can be shared across async +/// tasks. All methods take `&self` so that the reporter can be held behind an +/// `Arc` without requiring interior mutability beyond what the implementation +/// itself manages (typically a `Mutex`). +pub trait ProgressReporter: Send + Sync { + /// Handle a single progress event. + fn report(&self, event: ProgressEvent); +} + +/// A no-op reporter that discards all events. +/// +/// This is the default when no reporter is provided. Because it has no +/// branches or allocations it compiles away entirely in release builds. +#[derive(Debug, Default)] +pub struct NullReporter; + +impl ProgressReporter for NullReporter { + #[inline] + fn report(&self, _event: ProgressEvent) {} +} + +/// Convenience type alias for a shared, type-erased progress reporter. +pub type SharedReporter = Arc; + +/// An [`AsyncRead`] wrapper that tracks bytes read via a `watch` channel. +/// +/// The reader itself is intentionally minimal: it only increments a counter and +/// publishes it through a non-blocking [`tokio::sync::watch`] channel on each +/// successful read. This keeps the hot I/O path free from any reporter logic. +/// +/// Backpressure is handled by the watch channel itself: if the progress +/// renderer is slow, intermediate byte counts are coalesced — the sender +/// never blocks waiting for the receiver to catch up. +/// +/// Use [`ProgressRead::new`] to construct the reader and its companion driver +/// future. The driver must run concurrently with the read (e.g. via +/// `tokio::join!`) to actually emit [`ProgressEvent::Progress`] events. +/// +/// Place this wrapper *before* any decompressor so that the `fetched` counter +/// reflects compressed bytes-over-the-wire, matching the `total` from the +/// preceding [`ProgressEvent::Started`] event. +pub struct ProgressRead { + inner: R, + /// Non-blocking sender; updating it on every read is fine. + tx: tokio::sync::watch::Sender, +} + +impl std::fmt::Debug for ProgressRead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProgressRead") + .field("inner", &self.inner) + .field("bytes_read", &*self.tx.borrow()) + .finish_non_exhaustive() + } +} + +impl ProgressRead { + /// Wrap `inner` and return `(reader, driver)`. + /// + /// The driver is a future that translates raw byte counts into + /// [`ProgressEvent::Progress`] events via `reporter`. It completes when + /// the reader is dropped (i.e. the channel closes). Run it concurrently: + /// + /// ```ignore + /// let (reader, driver) = ProgressRead::new(blob, reporter, id, total); + /// let decompressor = decompress_async(reader, media_type)?; + /// let (import_result, ()) = tokio::join!(import_tar_async(repo, decompressor), driver); + /// ``` + /// + /// `total` should match the value passed to the preceding `Started` event + /// so the renderer can compute a meaningful percentage. + pub fn new( + inner: R, + reporter: SharedReporter, + id: ComponentId, + total: Option, + ) -> (Self, impl Future) { + let (tx, mut rx) = tokio::sync::watch::channel(0u64); + let driver = async move { + while rx.changed().await.is_ok() { + let fetched = *rx.borrow_and_update(); + reporter.report(ProgressEvent::Progress { + id: id.clone(), + fetched, + total, + }); + } + }; + (Self { inner, tx }, driver) + } +} + +impl AsyncRead for ProgressRead { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let before = buf.filled().len(); + let result = Pin::new(&mut self.inner).poll_read(cx, buf); + if let Poll::Ready(Ok(())) = &result { + let n = (buf.filled().len() - before) as u64; + if n > 0 { + // Overflow-safe: update by adding the delta. Errors are + // ignored — if the driver has already dropped its receiver + // (e.g. the pull was cancelled), we simply stop sending. + self.tx.send_modify(|v| *v += n); + } + } + result + } +} + +// Bring `Future` into scope for the `impl Future` return type. +use std::future::Future; + +#[cfg(any(test, feature = "test"))] +pub mod test_support { + //! Test helpers for verifying progress event sequences. + + use std::sync::Mutex; + + use super::{ProgressEvent, ProgressReporter}; + + /// A [`ProgressReporter`] that records all events for later inspection. + /// + /// Useful in unit tests to assert that the correct sequence of events + /// was emitted during a pull or download operation. + pub struct RecordingReporter { + events: Mutex>, + } + + impl std::fmt::Debug for RecordingReporter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RecordingReporter") + .field("events", &self.events.lock().unwrap().len()) + .finish() + } + } + + impl Default for RecordingReporter { + fn default() -> Self { + Self { + events: Mutex::new(Vec::new()), + } + } + } + + impl RecordingReporter { + /// Create a new empty recorder. + pub fn new() -> Self { + Self::default() + } + + /// Return a snapshot of all events recorded so far. + pub fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + } + + impl ProgressReporter for RecordingReporter { + fn report(&self, event: ProgressEvent) { + self.events.lock().unwrap().push(event); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use super::test_support::RecordingReporter; + use super::*; + + // ── NullReporter ──────────────────────────────────────────────────────── + + /// Calling `report` on `NullReporter` with every variant must not panic. + #[test] + fn test_null_reporter_does_not_panic() { + let reporter = NullReporter; + reporter.report(ProgressEvent::Started { + id: "layer1".into(), + total: Some(1024), + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Progress { + id: "layer1".into(), + fetched: 512, + total: Some(1024), + }); + reporter.report(ProgressEvent::Skipped { + id: "layer2".into(), + }); + reporter.report(ProgressEvent::Done { + id: "layer1".into(), + transferred: 1024, + }); + reporter.report(ProgressEvent::Message("done".to_string())); + } + + // ── ComponentId ───────────────────────────────────────────────────────── + + /// `ComponentId` can be constructed from `&str` and `String`, and its + /// `Display` impl round-trips the inner value. + #[test] + fn test_component_id_conversions() { + let cases = [ + "sha256:abc123", + "objects:my-stream", + "", + "docker://quay.io/foo:latest", + ]; + for input in cases { + let from_str: ComponentId = input.into(); + let from_string: ComponentId = input.to_string().into(); + assert_eq!( + from_str.as_str(), + input, + "ComponentId::from(&str) should store value" + ); + assert_eq!( + from_string.as_str(), + input, + "ComponentId::from(String) should store value" + ); + assert_eq!(from_str.to_string(), input, "Display should round-trip"); + assert_eq!(from_str, from_string, "both constructors should be equal"); + } + } + + /// `ComponentId` implements `Hash` + `Eq` correctly, so it works as a + /// `HashMap` key — which `IndicatifReporter` relies on. + #[test] + fn test_component_id_hash_map_key() { + let mut map: HashMap = HashMap::new(); + let id: ComponentId = "layer1".into(); + map.insert(id.clone(), 42); + + assert_eq!( + map.get(&ComponentId::from("layer1")), + Some(&42), + "lookup by equal ComponentId should succeed" + ); + assert_eq!( + map.get(&ComponentId::from("layer2")), + None, + "lookup by different ComponentId should return None" + ); + + // Ensure remove also works (used in IndicatifReporter on Done/Skipped) + let removed = map.remove(&id); + assert_eq!(removed, Some(42)); + assert!(map.is_empty()); + } + + // ── ProgressEvent ──────────────────────────────────────────────────────── + + /// Every `ProgressEvent` variant must implement `Debug` without panicking. + #[test] + fn test_progress_event_debug_all_variants() { + let events = [ + ProgressEvent::Started { + id: "x".into(), + total: Some(100), + unit: ProgressUnit::Bytes, + }, + ProgressEvent::Started { + id: "y".into(), + total: None, + unit: ProgressUnit::Items, + }, + ProgressEvent::Progress { + id: "x".into(), + fetched: 50, + total: Some(100), + }, + ProgressEvent::Skipped { id: "z".into() }, + ProgressEvent::Done { + id: "x".into(), + transferred: 100, + }, + ProgressEvent::Message("status update".into()), + ]; + for event in &events { + let debug = format!("{event:?}"); + assert!(!debug.is_empty(), "Debug output must not be empty"); + } + } + + /// `ProgressEvent` must be `Clone` and the clone must have the same + /// `Debug` representation as the original. + #[test] + fn test_progress_event_clone() { + let event = ProgressEvent::Started { + id: "layer".into(), + total: Some(1000), + unit: ProgressUnit::Bytes, + }; + let cloned = event.clone(); + assert_eq!( + format!("{event:?}"), + format!("{cloned:?}"), + "Clone should produce an identical value" + ); + } + + // ── RecordingReporter ──────────────────────────────────────────────────── + + /// `RecordingReporter` captures events in order and returns them via + /// `events()`. + #[test] + fn test_recording_reporter_captures_events_in_order() { + let reporter = RecordingReporter::new(); + reporter.report(ProgressEvent::Message("hello".into())); + reporter.report(ProgressEvent::Started { + id: "c1".into(), + total: Some(100), + unit: ProgressUnit::Bytes, + }); + reporter.report(ProgressEvent::Done { + id: "c1".into(), + transferred: 100, + }); + + let events = reporter.events(); + assert_eq!(events.len(), 3, "all three events should be recorded"); + assert!( + matches!(&events[0], ProgressEvent::Message(m) if m == "hello"), + "first event should be Message" + ); + assert!( + matches!(&events[1], ProgressEvent::Started { id, .. } if id.as_str() == "c1"), + "second event should be Started for c1" + ); + assert!( + matches!(&events[2], ProgressEvent::Done { id, .. } if id.as_str() == "c1"), + "third event should be Done for c1" + ); + } + + /// `SharedReporter = Arc` must be safely usable + /// from multiple threads simultaneously. + #[test] + fn test_shared_reporter_is_send_sync() { + let inner = Arc::new(RecordingReporter::new()); + let handles: Vec<_> = (0..4u32) + .map(|i| { + let r = Arc::clone(&inner); + std::thread::spawn(move || { + r.report(ProgressEvent::Message(format!("thread {i}"))); + }) + }) + .collect(); + for handle in handles { + handle.join().expect("thread should not panic"); + } + assert_eq!( + inner.events().len(), + 4, + "all four threads should have recorded their event" + ); + } + + // ── ProgressUnit ───────────────────────────────────────────────────────── + + /// Both `ProgressUnit` variants must be accessible and `Debug`-able. + #[test] + fn test_progress_unit_variants() { + let bytes = ProgressUnit::Bytes; + let items = ProgressUnit::Items; + assert_ne!(bytes, items); + assert!(!format!("{bytes:?}").is_empty()); + assert!(!format!("{items:?}").is_empty()); + } + + // ── ProgressRead ───────────────────────────────────────────────────────── + + /// Helper: run `ProgressRead` over `data` with a concurrent driver task, + /// and return all recorded `Progress` events. + async fn run_progress_read( + data: Vec, + id: ComponentId, + total: Option, + ) -> Vec { + use tokio::io::AsyncReadExt; + + let reporter = Arc::new(test_support::RecordingReporter::new()); + let cursor = tokio::io::BufReader::new(std::io::Cursor::new(data)); + let (mut reader, driver) = + ProgressRead::new(cursor, Arc::clone(&reporter) as SharedReporter, id, total); + // Spawn the driver so it runs independently. When the reader is + // dropped (after read_to_end), the watch sender closes and the driver + // task completes on its own. + let driver_handle = tokio::spawn(driver); + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + // Drop the reader explicitly so the watch sender closes, which lets + // the driver task observe channel closure and exit. + drop(reader); + driver_handle.await.unwrap(); + reporter.events() + } + + /// `ProgressRead` emits at least one `Progress` event when non-empty data + /// is read. Every byte goes through the watch channel, so any non-empty + /// read must produce at least one event. + #[tokio::test] + async fn test_progress_read_emits_events() { + let id: ComponentId = "test-layer".into(); + let total: u64 = 1024; + let data = vec![0u8; total as usize]; + let events = run_progress_read(data, id.clone(), Some(total)).await; + + let progress_events: Vec<_> = events + .iter() + .filter(|e| matches!(e, ProgressEvent::Progress { .. })) + .collect(); + + assert!( + !progress_events.is_empty(), + "expected at least one Progress event" + ); + // All events must carry the correct id and total + for event in &progress_events { + if let ProgressEvent::Progress { + id: eid, + total: etot, + .. + } = event + { + assert_eq!(eid, &id); + assert_eq!(*etot, Some(total)); + } + } + // The last Progress event must report fetched == total + if let Some(ProgressEvent::Progress { fetched, .. }) = progress_events.last() { + assert_eq!( + *fetched, total, + "last Progress event should have fetched == total" + ); + } + } + + /// `ProgressRead` with a zero-length source emits no `Progress` events + /// since the watch value never changes from its initial state. + #[tokio::test] + async fn test_progress_read_empty_source_no_events() { + let events = run_progress_read(vec![], "empty".into(), Some(0)).await; + assert!( + events.is_empty(), + "no events should be emitted for an empty source" + ); + } + + /// Every byte is sent through the watch channel, so even a single byte + /// should produce exactly one `Progress` event. + #[tokio::test] + async fn test_progress_read_single_byte_one_event() { + let events = run_progress_read(vec![42u8], "single".into(), Some(1)).await; + let progress_count = events + .iter() + .filter(|e| matches!(e, ProgressEvent::Progress { .. })) + .count(); + assert_eq!( + progress_count, 1, + "single byte should produce exactly one Progress event" + ); + } +}