Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ inference_kwargs:
incontext:
number: 5
method: bert_similarity
bert_scorer_device: auto
cache_bert_scorer: true
Comment on lines +9 to +10

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer an implementation where retroactively editing configs isn't needed

datasets:
MoralDesert: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_adept_high-1715105775-input-output.json
maximization: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_soartech_high-1716581856-input-output.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ inference_kwargs:
incontext:
number: 5
method: prompt_bert_similarity
bert_scorer_device: auto
cache_bert_scorer: true
leave_one_out_strategy: null
normalization: globalnorm
datasets:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ inference_kwargs:
incontext:
number: 5
method: prompt_bert_similarity
bert_scorer_device: auto
cache_bert_scorer: true
leave_one_out_strategy: null
normalization: null
sort_actions: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ inference_kwargs:
incontext:
number: 5
method: scenario_bert_similarity
# Device for the BERT scorer: 'auto' (cuda if available), 'cpu', or 'cuda'
bert_scorer_device: auto
# Reuse a single BERT scorer instance across calls (avoids reloading the model)
cache_bert_scorer: true
datasets:
MoralDesert: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_adept_high-1715105775-input-output.json
maximization: /data/shared/samba/integrated_results_metrics_eval/captured_dataset_for_chris/baseline_soartech_high-1716581856-input-output.json
Expand Down
68 changes: 56 additions & 12 deletions align_system/utils/incontext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
import numpy as np
from abc import ABCMeta, abstractmethod
from bert_score import score as bert_score
from bert_score import BERTScorer
from omegaconf import ListConfig, OmegaConf

from align_system.utils import adm_utils
Expand All @@ -25,7 +25,38 @@
)


def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False):
_bert_scorer = None
_bert_scorer_device = None


def get_bert_scorer(device="auto", cache_scorer=True):
"""Build (or fetch the cached) BERTScorer instance.

Args:
device: Device to run the scorer on ('cpu', 'cuda', etc.).
'auto' (or None) selects 'cuda' when available, else 'cpu'
cache_scorer: If True, reuse a single scorer instance across calls
to avoid reloading the model

Returns:
BERTScorer instance on the requested device
"""
global _bert_scorer, _bert_scorer_device

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking we can arrive at a solution that doesn't require globals. Will include details on my suggested approach in the main review text


if device is None or device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"

if not cache_scorer:
return BERTScorer(lang="en", device=device)

if _bert_scorer is None or _bert_scorer_device != device:
_bert_scorer = BERTScorer(lang="en", device=device)
_bert_scorer_device = device

return _bert_scorer


def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly I'm not a fan of how we have to now pass around bert_scorer_device and cache_bert_scorer parameters in several places

"""Common BERT similarity selection logic for all strategies.

Args:
Expand All @@ -36,11 +67,14 @@ def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_ex
score_adjustments: Optional list of score adjustments (same length as candidates)
least_similar_examples: If True, selects least similar examples to approximate domain shift
between train and eval on train data only
bert_scorer_device: Device to run the BERT scorer on ('auto', 'cpu', 'cuda', etc.)
cache_bert_scorer: If True, reuse a single BERTScorer instance across calls

Returns:
List of selected candidates with 'similarity_score' field added
"""
_, _, scores = bert_score([reference_text] * len(texts_to_compare), texts_to_compare, lang="en")
scorer = get_bert_scorer(device=bert_scorer_device, cache_scorer=cache_bert_scorer)
_, _, scores = scorer.score([reference_text] * len(texts_to_compare), texts_to_compare)

if score_adjustments is not None:
for i, adjustment in enumerate(score_adjustments):
Expand Down Expand Up @@ -72,7 +106,7 @@ def select_random_strategy(possible_examples, n_examples, **kwargs):
return selected_with_scores


def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, **kwargs):
def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs):
"""Scenario-based BERT similarity selection strategy"""
final_candidates = list({ex['scenario_description']: ex for ex in possible_examples}.values())
possible_scenarios = [icl_sample["scenario_description"] for icl_sample in final_candidates]
Expand All @@ -82,11 +116,13 @@ def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scen
possible_scenarios,
scenario_to_match,
n_examples,
least_similar_examples=least_similar_examples
least_similar_examples=least_similar_examples,
bert_scorer_device=bert_scorer_device,
cache_bert_scorer=cache_bert_scorer
)


def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, **kwargs):
def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs):
"""Prompt-based BERT similarity selection strategy"""
final_candidates = list({ex['prompt']: ex for ex in possible_examples}.values())
possible_prompts = [icl_sample["prompt"] for icl_sample in final_candidates]
Expand All @@ -96,11 +132,13 @@ def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt
possible_prompts,
prompt_to_match,
n_examples,
least_similar_examples=least_similar_examples
least_similar_examples=least_similar_examples,
bert_scorer_device=bert_scorer_device,
cache_bert_scorer=cache_bert_scorer
)


def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs):
def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs):
"""Action-matching with BERT similarity selection strategy"""
action_types = set([action.action_type for action in actions])
possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples]
Expand All @@ -119,11 +157,13 @@ def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_ma
prompt_to_match,
n_examples,
score_adjustments,
least_similar_examples=least_similar_examples
least_similar_examples=least_similar_examples,
bert_scorer_device=bert_scorer_device,
cache_bert_scorer=cache_bert_scorer
)


def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs):
def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs):
"""Character-matching with BERT similarity selection strategy"""
action_chars = set([action.character_id for action in actions])
possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples]
Expand All @@ -142,7 +182,9 @@ def select_matching_characters_strategy(possible_examples, n_examples, prompt_to
prompt_to_match,
n_examples,
score_adjustments,
least_similar_examples=least_similar_examples
least_similar_examples=least_similar_examples,
bert_scorer_device=bert_scorer_device,
cache_bert_scorer=cache_bert_scorer
)


Expand Down Expand Up @@ -460,7 +502,9 @@ def select_icl_examples(self, sys_kdma_name, scenario_description_to_match, prom
scenario_to_match=scenario_description_to_match,
prompt_to_match=prompt_to_match,
actions=actions,
least_similar_examples=least_similar_examples
least_similar_examples=least_similar_examples,
bert_scorer_device=self.incontext_settings.get("bert_scorer_device", "auto"),
cache_bert_scorer=self.incontext_settings.get("cache_bert_scorer", True)
)

if self.incontext_settings.get("most_similar_first", True):
Expand Down