diff --git a/proseco/guide.py b/proseco/guide.py index 42884d96..4cefd3a6 100644 --- a/proseco/guide.py +++ b/proseco/guide.py @@ -402,12 +402,18 @@ def index_combinations(n, m): choose_m = min(len(stage_cands), self.n_guide) n_tries = 0 - for n_tries, comb in enumerate( - index_combinations(len(stage_cands), choose_m), start=1 - ): + for comb in index_combinations(len(stage_cands), choose_m): cands = stage_cands[ list(comb) ] # (note that [(1,2)] is not the same as list((1,2)) + + # If there are any include_ids, then the selected stars must include them. + # If not, skip this combination (skipped combinations do not count toward n_tries). + if self.include_ids and not set(self.include_ids).issubset(cands["id"]): + continue + + n_tries += 1 + n_pass, n_tests = run_select_checks( cands ) # This function knows how many tests get run diff --git a/proseco/tests/test_guide.py b/proseco/tests/test_guide.py index 0c916be9..e3317727 100644 --- a/proseco/tests/test_guide.py +++ b/proseco/tests/test_guide.py @@ -22,6 +22,7 @@ get_ax_range, get_guide_catalog, get_pixmag_for_offset, + run_select_checks, ) from ..report_guide import make_report from .test_common import DARK40, OBS_INFO, STD_INFO, mod_std_info @@ -609,6 +610,45 @@ def test_guides_include_bad(): assert "cannot include star id=20" in str(err) +def test_guides_include_close(): + """ + Test force-including stars that would otherwise not be selected due to + clustering. 
+ """ + stars = StarsTable.empty() + + stars.add_fake_constellation( + mag=[7.0, 7.0, 7.0, 7.0, 7.0], id=[25, 26, 27, 28, 29], size=2000, n_stars=5 + ) + + stars.add_fake_star(mag=11.0, yang=100, zang=100, id=21) + stars.add_fake_star(mag=11.0, yang=-100, zang=-100, id=22) + stars.add_fake_star(mag=11.0, yang=100, zang=-100, id=23) + stars.add_fake_star(mag=11.0, yang=-100, zang=100, id=24) + + cat1 = get_guide_catalog(**mod_std_info(n_guide=5), stars=stars) + + # Run the cluster checks and confirm all 3 pass + cat1_pass, _ = run_select_checks(cat1) + assert cat1_pass == 3 + + # Confirm that only the bright stars are used + assert np.count_nonzero(cat1["mag"] == 7.0) == 5 + + # Force-include the 4 faint stars, which are also clustered close together + include_ids = [21, 22, 23, 24] + cat2 = get_guide_catalog( + **mod_std_info(n_guide=5), stars=stars, include_ids_guide=include_ids + ) + + # Run the cluster checks and confirm all 3 fail + cat2_pass, _ = run_select_checks(cat2) + assert cat2_pass == 0 + assert np.all(np.isin(include_ids, cat2["id"])) + # The 4 included stars fill 4 of the 5 slots, so only one bright star is used + assert np.count_nonzero(cat2["mag"] == 7.0) == 1 + + @pytest.mark.parametrize("dither", dither_cases) def test_edge_star(dither): """