Add batch querying support #198
Open
CoolJosh0221 wants to merge 2 commits into
…WeightedMeta

Selecting N samples previously cost N training rounds, since make_query() returns one entry id per call. This adds a batch path built on the standardized _get_scores() contract (ntucllab#197), while keeping make_query()'s single-int public contract and the per-entry (entry_id, label) observer callback untouched.

- QueryStrategy.make_query_batch(batch_size): default = stable descending top-k of _get_scores(); returns np.ndarray of distinct entry ids, most preferred first. TypeError for non-integer batch_size; ValueError for batch_size < 1, an empty pool, or batch_size > n_unlabeled (no silent clamp). Ties break deterministically (stable sort), so make_query_batch(1) may differ from make_query() only at ties.
- Semantic overrides where top-k is unfaithful:
  - RandomSampling: uniform sampling without replacement.
  - CoreSet: true iterative k-center greedy (Sener & Savarese 2018) with a running min-distance vector; honors metric and transformer.
  - EpsilonUncertaintySampling: Binomial(batch_size, epsilon) exploration picks drawn from the complement of the top-uncertainty picks, so the batch is always exactly batch_size distinct ids.
  - ALBL and VarianceReduction: explicit NotImplementedError (inherently sequential / no per-sample scoring).
- DiversityWeightedMeta: wraps any score-based strategy so batches are not just top-k with near-duplicate redundancy. Greedy utility (1 - lmbda) * s_norm + lmbda * d_norm, where s_norm is a monotone min-max of the base scores (rank-faithful: never re-interprets score semantics, so confidence-flavored scores like HintSVM's are handled by construction) and d is the min distance to already-selected batch members. First pick = base argmax. Optional candidate_pool_size cap.
- Dataset.update_batch(entry_ids, labels): validates lengths, rejects duplicate ids, empty input is a no-op; applies labels through the existing per-entry update() path so observers (ALBL, QUIRE, QBC) see exactly the same incremental notification stream as sequential calls.
- CoreSet._get_scores: cdist -> sklearn pairwise_distances, fixing a latent crash on sparse feature matrices (identical dense results).
- Tests: 81 new (batch contract across all strategies, override semantics, diversity guarantee with a near-duplicate control fixture, rank-faithfulness with negative/adversarial-magnitude mocks, sparse inputs, update_batch equivalence with sequential updates for QUIRE/QBC/ALBL observers, error paths). Full suite: 217 passing.
- Docs: Sphinx entry, README batch-querying section, and examples/batch_query_plot.py (12 vs 120 training rounds for the same 120-label budget on the diabetes dataset).
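For reviewers who want to see the greedy utility from the DiversityWeightedMeta bullet in isolation, here is a minimal sketch. It is not the code in this PR: the function names `diversity_greedy` and `_min_max` are hypothetical, Euclidean distance is assumed, and normalizing the distance term with the same min-max rescale as the scores is an assumption about how d_norm is derived.

```python
import numpy as np


def _min_max(values):
    """Monotone rescale to [0, 1]; a constant vector maps to all zeros."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        return np.zeros_like(values)
    return (values - values.min()) / span


def diversity_greedy(features, scores, batch_size, lmbda=0.5):
    """Hypothetical helper: greedy batch selection trading score for diversity.

    Returns positions into `features` / `scores`, most preferred first.
    """
    features = np.asarray(features, dtype=float)
    if not 1 <= batch_size <= len(features):
        raise ValueError("batch_size must be between 1 and the pool size")

    s_norm = _min_max(scores)
    # First pick is the base argmax, as the commit message states.
    selected = [int(np.argmax(s_norm))]
    # Running minimum distance from every candidate to the selected set
    # (Euclidean here; the metric is an assumption of this sketch).
    min_dist = np.linalg.norm(features - features[selected[0]], axis=1)

    while len(selected) < batch_size:
        utility = (1.0 - lmbda) * s_norm + lmbda * _min_max(min_dist)
        utility[selected] = -np.inf  # never pick the same entry twice
        pick = int(np.argmax(utility))
        selected.append(pick)
        new_dist = np.linalg.norm(features - features[pick], axis=1)
        min_dist = np.minimum(min_dist, new_dist)

    return np.array(selected)
```

With lmbda=0 the sketch reduces to plain top-k over the scores; raising lmbda pushes later picks away from entries that sit close to earlier ones, which is the near-duplicate case the tests describe.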
scikit-learn deprecated `multi_class` in 1.5 and removed it in 1.7, so `LogisticRegression(..., multi_class="ovr")` now raises TypeError. On top of that, `solver="liblinear"` no longer performs one-vs-rest for multiclass data (n_classes >= 3): it raises and directs callers to OneVsRestClassifier. Together these broke 24 tests under scikit-learn 1.8.

All changes are behavior-preserving:

- Drop `multi_class="ovr"` wherever it was paired with `solver="liblinear"`; liblinear only ever did one-vs-rest, so removing it is a no-op.
- Drop `'multi_class': 'multinomial'` from the default logreg_param of MaximumLossReductionMaximalConfidence; multinomial is now the default for the newton-cg solver.
- EER iris tests: wrap in SklearnProbaAdapter(OneVsRestClassifier(LogisticRegression(solver="liblinear"))) to retain one-vs-rest. Verified to reproduce the exact recorded query sequences, so no assertion values changed.
- LogisticRegression / SklearnAdapter delegation tests: use the default solver with max_iter=1000 on both sides (the tests only assert wrapper-vs-sklearn equality, so the solver is incidental; max_iter avoids a ConvergenceWarning on unscaled iris).
- Update the CostSensitiveReferencePairEncoding docstring example.

Pre-existing issue independent of the batch-querying work; fails on master too. 217 passed (was 24 failed, 193 passed).
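The one-vs-rest wrapping used for the EER iris tests can be illustrated with scikit-learn alone. This is a sketch rather than the test code from this commit: libact's SklearnProbaAdapter is left out so the snippet depends only on scikit-learn, and the variable names are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)

# Explicit one-vs-rest: one binary liblinear model is fitted per class,
# instead of relying on the removed multi_class="ovr" argument.
ovr = OneVsRestClassifier(LogisticRegression(solver="liblinear"))
ovr.fit(X, y)

# Per-class probabilities for the first five samples.
proba = ovr.predict_proba(X[:5])
n_binary_models = len(ovr.estimators_)
```

Each binary problem handed to liblinear has two classes, so the multiclass restriction described above does not apply to the inner estimator.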
Summary
- Add QueryStrategy.make_query_batch(batch_size) with default score-based top-k selection.
- Add Dataset.update_batch(entry_ids, labels) while preserving existing per-entry observer callback behavior.
- Add DiversityWeightedMeta for diversity-aware batch selection over existing query strategies.

Notes
- make_query() keeps returning a single entry id; existing public single-query behavior is unchanged.
- Dataset.update_batch() validates length, dimensionality, and duplicate entry ids before applying per-entry updates.
- ALBL and VarianceReduction raise NotImplementedError from make_query_batch() because their existing semantics are not directly batch-compatible.

Tests
python -m unittest -v
MPLBACKEND=Agg MPLCONFIGDIR=/tmp/libact-mpl python examples/batch_query_plot.py
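As a reading aid for the default selection named in the Summary, the stable descending top-k contract could be sketched as below. This is an illustration under stated assumptions, not the PR's implementation: `top_k_batch` is a hypothetical standalone function that takes ids and scores directly instead of calling _get_scores(), and in this sketch ties keep the order in which entries appear in the input.

```python
import numbers

import numpy as np


def top_k_batch(entry_ids, scores, batch_size):
    """Hypothetical helper: stable descending top-k over parallel id/score arrays."""
    if not isinstance(batch_size, numbers.Integral):
        raise TypeError("batch_size must be an integer")

    entry_ids = np.asarray(entry_ids)
    scores = np.asarray(scores, dtype=float)

    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if entry_ids.size == 0:
        raise ValueError("the unlabeled pool is empty")
    if batch_size > entry_ids.size:
        # No silent clamp: asking for more than the pool holds is an error.
        raise ValueError("batch_size exceeds the number of unlabeled entries")

    # Stable sort on negated scores gives descending order while keeping
    # tied entries in their original relative order.
    order = np.argsort(-scores, kind="stable")
    return entry_ids[order[:batch_size]]
```

For example, top_k_batch([10, 11, 12, 13], [0.2, 0.9, 0.9, 0.1], 3) returns the ids 11, 12, 10: the two tied entries come first in input order, followed by the next-highest score.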