diff --git a/birdnet_analyzer/analyze/__init__.py b/birdnet_analyzer/analyze/__init__.py index a844f982..4b3b7a6e 100644 --- a/birdnet_analyzer/analyze/__init__.py +++ b/birdnet_analyzer/analyze/__init__.py @@ -63,6 +63,7 @@ def main(): # Load species list from location filter or provided list cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = args.lat, args.lon, args.week cfg.LOCATION_FILTER_THRESHOLD = args.sf_thresh + cfg.TOP_N = args.top_n if cfg.LATITUDE == -1 and cfg.LONGITUDE == -1: if not args.slist: diff --git a/birdnet_analyzer/analyze/utils.py b/birdnet_analyzer/analyze/utils.py index 76692903..75de16c7 100644 --- a/birdnet_analyzer/analyze/utils.py +++ b/birdnet_analyzer/analyze/utils.py @@ -93,11 +93,10 @@ def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_p start, end = timestamp.split("-", 1) for c in result[timestamp]: - if c[1] > cfg.MIN_CONFIDENCE and (not cfg.SPECIES_LIST or c[0] in cfg.SPECIES_LIST): - selection_id += 1 - label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] - code = cfg.CODES[c[0]] if c[0] in cfg.CODES else c[0] - rstring += f"{selection_id}\tSpectrogram 1\t1\t{start}\t{end}\t{low_freq}\t{high_freq}\t{label.split('_', 1)[-1]}\t{code}\t{c[1]:.4f}\t{afile_path}\t{start}\n" + selection_id += 1 + label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] + code = cfg.CODES[c[0]] if c[0] in cfg.CODES else c[0] + rstring += f"{selection_id}\tSpectrogram 1\t1\t{start}\t{end}\t{low_freq}\t{high_freq}\t{label.split('_', 1)[-1]}\t{code}\t{c[1]:.4f}\t{afile_path}\t{start}\n" # Write result string to file out_string += rstring @@ -133,11 +132,10 @@ def generate_audacity(timestamps: list[str], result: dict[str, list], result_pat rstring = "" for c in result[timestamp]: - if c[1] > cfg.MIN_CONFIDENCE and (not cfg.SPECIES_LIST or c[0] in cfg.SPECIES_LIST): - label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] - ts = timestamp.replace("-", "\t") - lbl = label.replace("_", ", ") - rstring += f"{ts}\t{lbl}\t{c[1]:.4f}\n" + label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] + ts = timestamp.replace("-", "\t") + lbl = label.replace("_", ", ") + rstring += f"{ts}\t{lbl}\t{c[1]:.4f}\n" # Write result string to file out_string += rstring @@ -169,23 +167,22 @@ def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_ start, end = timestamp.split("-", 1) for c in result[timestamp]: - if c[1] > cfg.MIN_CONFIDENCE and (not cfg.SPECIES_LIST or c[0] in cfg.SPECIES_LIST): - label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] - rstring += "{},{},{},{},{},{},{},{:.4f},{:.4f},{:.4f},{},{},{}\n".format( - parent_folder.rstrip("/"), - folder_name, - filename, - start, - float(end) - float(start), - label.split("_", 1)[0], - label.split("_", 1)[-1], - c[1], - cfg.LATITUDE, - cfg.LONGITUDE, - cfg.WEEK, - cfg.SIG_OVERLAP, - (1.0 - cfg.SIGMOID_SENSITIVITY) + 1.0, - ) + label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] + rstring += "{},{},{},{},{},{},{},{:.4f},{:.4f},{:.4f},{},{},{}\n".format( + parent_folder.rstrip("/"), + folder_name, + filename, + start, + float(end) - float(start), + label.split("_", 1)[0], + label.split("_", 1)[-1], + c[1], + cfg.LATITUDE, + cfg.LONGITUDE, + cfg.WEEK, + cfg.SIG_OVERLAP, + (1.0 - cfg.SIGMOID_SENSITIVITY) + 1.0, + ) # Write result string to file out_string += rstring @@ -214,10 +211,8 @@ def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str for c in result[timestamp]: start, end = timestamp.split("-", 1) - - if c[1] > cfg.MIN_CONFIDENCE and (not cfg.SPECIES_LIST or c[0] in cfg.SPECIES_LIST): - label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] - rstring += f"{start},{end},{label.split('_', 1)[0]},{label.split('_', 1)[-1]},{c[1]:.4f},{afile_path}\n" + label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(c[0])] + rstring += f"{start},{end},{label.split('_', 1)[0]},{label.split('_', 1)[-1]},{c[1]:.4f},{afile_path}\n" # Write result string to file out_string += rstring @@ -582,11 +577,20 @@ def analyze_file(item): pred = p[i] # Assign scores to labels - p_labels = zip(cfg.LABELS, pred, strict=True) + p_labels = [ + p + for p in zip(cfg.LABELS, pred, strict=True) + if (cfg.TOP_N or p[1] >= cfg.MIN_CONFIDENCE) + and (not cfg.SPECIES_LIST or p[0] in cfg.SPECIES_LIST) + ] # Sort by score p_sorted = sorted(p_labels, key=operator.itemgetter(1), reverse=True) + if cfg.TOP_N: + p_sorted = p_sorted[: cfg.TOP_N] + + # TODO hier schon top n oder min conf raussortieren # Store top 5 results and advance indices results[str(s_start) + "-" + str(s_end)] = p_sorted diff --git a/birdnet_analyzer/cli.py b/birdnet_analyzer/cli.py index 145e4b58..a342d2a5 100644 --- a/birdnet_analyzer/cli.py +++ b/birdnet_analyzer/cli.py @@ -109,9 +109,9 @@ def species_args(): ) p.add_argument( "--sf_thresh", - type=lambda a: max(0.01, min(0.99, float(a))), + type=lambda a: max(0.0001, min(0.99, float(a))), default=cfg.LOCATION_FILTER_THRESHOLD, - help="Minimum species occurrence frequency threshold for location filter. Values in [0.01, 0.99].", + help="Minimum species occurrence frequency threshold for location filter. Values in [0.0001, 0.99].", ) return p @@ -220,8 +220,8 @@ def min_conf_args(): p.add_argument( "--min_conf", default=cfg.MIN_CONFIDENCE, - type=lambda a: max(0.01, min(0.99, float(a))), - help="Minimum confidence threshold. Values in [0.01, 0.99].", + type=lambda a: max(0.00001, min(0.99, float(a))), + help="Minimum confidence threshold. Values in [0.00001, 0.99].", ) return p @@ -337,6 +337,12 @@ def __call__(self, parser, args, values, option_string=None): help="Skip files that have already been analyzed.", ) + parser.add_argument( + "--top_n", + type=lambda a: max(1, int(a)), + help="Saves only the top N predictions for each segment independent of their score. Threshold will be ignored.", + ) + return parser @@ -490,9 +496,9 @@ def species_parser(): def train_parser(): """ Creates an argument parser for training a custom classifier with BirdNET. - The parser includes arguments for various training parameters such as input data path, crop mode, - output path, number of epochs, batch size, validation split ratio, learning rate, hidden units, - dropout rate, mixup, upsampling ratio and mode, model format, model save mode, cache mode and file, + The parser includes arguments for various training parameters such as input data path, crop mode, + output path, number of epochs, batch size, validation split ratio, learning rate, hidden units, + dropout rate, mixup, upsampling ratio and mode, model format, model save mode, cache mode and file, and hyperparameter tuning options. Returns: argparse.ArgumentParser: Configured argument parser for training a custom classifier. diff --git a/birdnet_analyzer/config.py b/birdnet_analyzer/config.py index a920304f..60a1c9c4 100644 --- a/birdnet_analyzer/config.py +++ b/birdnet_analyzer/config.py @@ -48,6 +48,9 @@ BANDPASS_FMIN: int = 0 BANDPASS_FMAX: int = 15000 +# Top N species to display in selection table, ignored if set to None +TOP_N = None + # Audio speed AUDIO_SPEED: float = 1.0 diff --git a/birdnet_analyzer/gui/analysis.py b/birdnet_analyzer/gui/analysis.py index ad53ba14..ea5f0a6f 100644 --- a/birdnet_analyzer/gui/analysis.py +++ b/birdnet_analyzer/gui/analysis.py @@ -35,6 +35,8 @@ def analyze_file_wrapper(entry): def run_analysis( input_path: str, output_path: str | None, + use_top_n: bool, + top_n: int, confidence: float, sensitivity: float, overlap: float, @@ -96,6 +98,7 @@ def run_analysis( cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, -1 if use_yearlong else week cfg.LOCATION_FILTER_THRESHOLD = sf_thresh cfg.SKIP_EXISTING_RESULTS = skip_existing + cfg.TOP_N = top_n if use_top_n else None if species_list_choice == gu._CUSTOM_SPECIES: if not species_list_file or not species_list_file.name: diff --git a/birdnet_analyzer/gui/multi_file.py b/birdnet_analyzer/gui/multi_file.py index d615790a..f1ae835e 100644 --- a/birdnet_analyzer/gui/multi_file.py +++ b/birdnet_analyzer/gui/multi_file.py @@ -14,6 +14,8 @@ def run_batch_analysis( output_path, + use_top_n, + top_n, confidence, sensitivity, overlap, @@ -52,6 +54,8 @@ def run_batch_analysis( return run_analysis( None, output_path, + use_top_n, + top_n, confidence, sensitivity, overlap, @@ -129,9 +133,16 @@ def select_directory_wrapper(): # Nishant - Function modified for For Folder se show_progress=False, ) - confidence_slider, sensitivity_slider, overlap_slider, audio_speed_slider, fmin_number, fmax_number = ( - gu.sample_sliders() - ) + ( + use_top_n, + top_n_input, + confidence_slider, + sensitivity_slider, + overlap_slider, + audio_speed_slider, + fmin_number, + fmax_number, + ) = gu.sample_sliders() ( species_list_radio, @@ -199,6 +210,8 @@ def select_directory_wrapper(): # Nishant - Function modified for For Folder se inputs = [ output_directory_predict_state, + use_top_n, + top_n_input, confidence_slider, sensitivity_slider, overlap_slider, diff --git a/birdnet_analyzer/gui/single_file.py b/birdnet_analyzer/gui/single_file.py index 49888e1d..b00d9302 100644 --- a/birdnet_analyzer/gui/single_file.py +++ b/birdnet_analyzer/gui/single_file.py @@ -11,6 +11,8 @@ def run_single_file_analysis( input_path, + use_top_n, + top_n, confidence, sensitivity, overlap, @@ -42,6 +44,8 @@ def run_single_file_analysis( result_filepath = run_analysis( input_path, None, + use_top_n, + top_n, confidence, sensitivity, overlap, @@ -97,9 +101,16 @@ def build_single_analysis_tab(): ) audio_path_state = gr.State() - confidence_slider, sensitivity_slider, overlap_slider, audio_speed_slider, fmin_number, fmax_number = ( - gu.sample_sliders(False) - ) + ( + use_top_n, + top_n_input, + confidence_slider, + sensitivity_slider, + overlap_slider, + audio_speed_slider, + fmin_number, + fmax_number, + ) = gu.sample_sliders(False) ( species_list_radio, @@ -147,6 +158,8 @@ def try_generate_spectrogram(audio_path, generate_spectrogram): inputs = [ audio_path_state, + use_top_n, + top_n_input, confidence_slider, sensitivity_slider, overlap_slider, diff --git a/birdnet_analyzer/gui/utils.py b/birdnet_analyzer/gui/utils.py index 2f7bb439..d36eeefb 100644 --- a/birdnet_analyzer/gui/utils.py +++ b/birdnet_analyzer/gui/utils.py @@ -190,11 +190,11 @@ def build_footer(): f"""