-
Notifications
You must be signed in to change notification settings - Fork 5
able to config bert scorer for gpu/cpu, as well as cache it #281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
| import random | ||
| import numpy as np | ||
| from abc import ABCMeta, abstractmethod | ||
| from bert_score import score as bert_score | ||
| from bert_score import BERTScorer | ||
| from omegaconf import ListConfig, OmegaConf | ||
|
|
||
| from align_system.utils import adm_utils | ||
|
|
@@ -25,7 +25,38 @@ | |
| ) | ||
|
|
||
|
|
||
| def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False): | ||
| _bert_scorer = None | ||
| _bert_scorer_device = None | ||
|
|
||
|
|
||
| def get_bert_scorer(device="auto", cache_scorer=True): | ||
| """Build (or fetch the cached) BERTScorer instance. | ||
|
|
||
| Args: | ||
| device: Device to run the scorer on ('cpu', 'cuda', etc.). | ||
| 'auto' (or None) selects 'cuda' when available, else 'cpu' | ||
| cache_scorer: If True, reuse a single scorer instance across calls | ||
| to avoid reloading the model | ||
|
|
||
| Returns: | ||
| BERTScorer instance on the requested device | ||
| """ | ||
| global _bert_scorer, _bert_scorer_device | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking we can arrive at a solution that doesn't require globals. Will include details on my suggested approach in the main review text |
||
|
|
||
| if device is None or device == "auto": | ||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
|
||
| if not cache_scorer: | ||
| return BERTScorer(lang="en", device=device) | ||
|
|
||
| if _bert_scorer is None or _bert_scorer_device != device: | ||
| _bert_scorer = BERTScorer(lang="en", device=device) | ||
| _bert_scorer_device = device | ||
|
|
||
| return _bert_scorer | ||
|
|
||
|
|
||
| def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_examples, score_adjustments=None, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly I'm not a fan of how we have to now pass around |
||
| """Common BERT similarity selection logic for all strategies. | ||
|
|
||
| Args: | ||
|
|
@@ -36,11 +67,14 @@ def bert_similarity_selection(candidates, texts_to_compare, reference_text, n_ex | |
| score_adjustments: Optional list of score adjustments (same length as candidates) | ||
| least_similar_examples: If True, selects least similar examples to approximate domain shift | ||
| between train and eval on train data only | ||
| bert_scorer_device: Device to run the BERT scorer on ('auto', 'cpu', 'cuda', etc.) | ||
| cache_bert_scorer: If True, reuse a single BERTScorer instance across calls | ||
|
|
||
| Returns: | ||
| List of selected candidates with 'similarity_score' field added | ||
| """ | ||
| _, _, scores = bert_score([reference_text] * len(texts_to_compare), texts_to_compare, lang="en") | ||
| scorer = get_bert_scorer(device=bert_scorer_device, cache_scorer=cache_bert_scorer) | ||
| _, _, scores = scorer.score([reference_text] * len(texts_to_compare), texts_to_compare) | ||
|
|
||
| if score_adjustments is not None: | ||
| for i, adjustment in enumerate(score_adjustments): | ||
|
|
@@ -72,7 +106,7 @@ def select_random_strategy(possible_examples, n_examples, **kwargs): | |
| return selected_with_scores | ||
|
|
||
|
|
||
| def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, **kwargs): | ||
| def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scenario_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): | ||
| """Scenario-based BERT similarity selection strategy""" | ||
| final_candidates = list({ex['scenario_description']: ex for ex in possible_examples}.values()) | ||
| possible_scenarios = [icl_sample["scenario_description"] for icl_sample in final_candidates] | ||
|
|
@@ -82,11 +116,13 @@ def select_scenario_bert_similarity_strategy(possible_examples, n_examples, scen | |
| possible_scenarios, | ||
| scenario_to_match, | ||
| n_examples, | ||
| least_similar_examples=least_similar_examples | ||
| least_similar_examples=least_similar_examples, | ||
| bert_scorer_device=bert_scorer_device, | ||
| cache_bert_scorer=cache_bert_scorer | ||
| ) | ||
|
|
||
|
|
||
| def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, **kwargs): | ||
| def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt_to_match, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): | ||
| """Prompt-based BERT similarity selection strategy""" | ||
| final_candidates = list({ex['prompt']: ex for ex in possible_examples}.values()) | ||
| possible_prompts = [icl_sample["prompt"] for icl_sample in final_candidates] | ||
|
|
@@ -96,11 +132,13 @@ def select_prompt_bert_similarity_strategy(possible_examples, n_examples, prompt | |
| possible_prompts, | ||
| prompt_to_match, | ||
| n_examples, | ||
| least_similar_examples=least_similar_examples | ||
| least_similar_examples=least_similar_examples, | ||
| bert_scorer_device=bert_scorer_device, | ||
| cache_bert_scorer=cache_bert_scorer | ||
| ) | ||
|
|
||
|
|
||
| def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs): | ||
| def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): | ||
| """Action-matching with BERT similarity selection strategy""" | ||
| action_types = set([action.action_type for action in actions]) | ||
| possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples] | ||
|
|
@@ -119,11 +157,13 @@ def select_matching_actions_strategy(possible_examples, n_examples, prompt_to_ma | |
| prompt_to_match, | ||
| n_examples, | ||
| score_adjustments, | ||
| least_similar_examples=least_similar_examples | ||
| least_similar_examples=least_similar_examples, | ||
| bert_scorer_device=bert_scorer_device, | ||
| cache_bert_scorer=cache_bert_scorer | ||
| ) | ||
|
|
||
|
|
||
| def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, **kwargs): | ||
| def select_matching_characters_strategy(possible_examples, n_examples, prompt_to_match, actions, least_similar_examples=False, bert_scorer_device="auto", cache_bert_scorer=True, **kwargs): | ||
| """Character-matching with BERT similarity selection strategy""" | ||
| action_chars = set([action.character_id for action in actions]) | ||
| possible_prompts = [icl_sample["prompt"] for icl_sample in possible_examples] | ||
|
|
@@ -142,7 +182,9 @@ def select_matching_characters_strategy(possible_examples, n_examples, prompt_to | |
| prompt_to_match, | ||
| n_examples, | ||
| score_adjustments, | ||
| least_similar_examples=least_similar_examples | ||
| least_similar_examples=least_similar_examples, | ||
| bert_scorer_device=bert_scorer_device, | ||
| cache_bert_scorer=cache_bert_scorer | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -460,7 +502,9 @@ def select_icl_examples(self, sys_kdma_name, scenario_description_to_match, prom | |
| scenario_to_match=scenario_description_to_match, | ||
| prompt_to_match=prompt_to_match, | ||
| actions=actions, | ||
| least_similar_examples=least_similar_examples | ||
| least_similar_examples=least_similar_examples, | ||
| bert_scorer_device=self.incontext_settings.get("bert_scorer_device", "auto"), | ||
| cache_bert_scorer=self.incontext_settings.get("cache_bert_scorer", True) | ||
| ) | ||
|
|
||
| if self.incontext_settings.get("most_similar_first", True): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer an implementation where retroactively editing configs isn't needed