Skip to content
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a5363f8
gitignore
max-mauermann Sep 3, 2024
8789dd9
first changes, embeddings get written to db
max-mauermann Sep 12, 2024
011a66a
embeddings are not inserted if an embedding with equal source + offse…
max-mauermann Sep 23, 2024
39d3fbe
added embeddings to gui, still needs translations and validation
max-mauermann Sep 24, 2024
5a1d18a
de/en translations + validation
max-mauermann Sep 25, 2024
f03fb27
rudimentary search script
max-mauermann Sep 25, 2024
c400de7
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Nov 5, 2024
855aeb4
embeddings gui (WIP)
max-mauermann Nov 5, 2024
e83945c
.
max-mauermann Nov 19, 2024
7a345b6
spectrogram shows now
max-mauermann Nov 19, 2024
ae5ed72
show search result spectrograms
max-mauermann Nov 28, 2024
01320a1
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Nov 28, 2024
f1008f3
play buttons
max-mauermann Nov 28, 2024
0dd876a
.
max-mauermann Nov 28, 2024
8374ecb
dynamically create result plots and max_results input
max-mauermann Nov 29, 2024
3c0355b
exporting results + en language
max-mauermann Dec 2, 2024
b5adfa5
toggle all button + progress
max-mauermann Dec 3, 2024
46170ee
correctly closing database + validation
max-mauermann Dec 4, 2024
72dc7e6
results are now on pages + sorting search results
max-mauermann Dec 10, 2024
6d76a81
changed to perch_hoplite repository and updated code
max-mauermann Jan 15, 2025
49acaab
updates
max-mauermann Jan 20, 2025
f6ac488
updates
max-mauermann Jan 21, 2025
8b6865b
updated cli params. score is now written to export file names
max-mauermann Jan 30, 2025
2cedacf
included perch in requirements
max-mauermann Jan 31, 2025
e759752
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Jan 31, 2025
a8e0149
.
max-mauermann Jan 31, 2025
c8495ec
small fixes and german translations
max-mauermann Jan 31, 2025
75b3a4d
added speed and bandpass filter config (at least for scripts, gui sti…
max-mauermann Feb 4, 2025
7ed800c
.
max-mauermann Feb 4, 2025
1aeca4b
spectrograms in gui are now using the speed/bandpass from the setting…
max-mauermann Feb 11, 2025
048757f
.
max-mauermann Feb 11, 2025
d9aea2e
replaced settings file with hoplite metadata
max-mauermann Feb 11, 2025
a228c5b
implemented crop mode for query sample selection.
max-mauermann Feb 12, 2025
7906294
implemented crop mode in gui and adjusted query sample selection
max-mauermann Feb 12, 2025
b9f75dd
german/english language
max-mauermann Feb 12, 2025
c8de379
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Feb 17, 2025
ba1ae3b
changed embeddings stuff to new structure
max-mauermann Feb 17, 2025
c083e1f
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Feb 17, 2025
a99304c
some smaller fixes
max-mauermann Feb 18, 2025
581e3a6
added embeddings to dependencies
max-mauermann Feb 19, 2025
e7680cf
Merge branch 'main' into embeddings-with-hoplite
max-mauermann Feb 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 132 additions & 104 deletions birdnet_analyzer/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,16 @@
import birdnet_analyzer.model as model
import birdnet_analyzer.utils as utils

SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))


def writeErrorLog(msg):
    """
    Appends an error message to the error log file.

    The message is written as one line to the file named by cfg.ERROR_LOG_FILE,
    which is opened in append mode so earlier entries are kept.

    Args:
        msg (str | Exception): The error message to be logged. Values that are not
            strings (for example a caught exception) are converted with str().
    """
    with open(cfg.ERROR_LOG_FILE, "a", encoding="utf-8") as elog:
        # f-string formatting instead of `msg + "\n"` so that exception objects
        # can be logged directly without raising a TypeError.
        elog.write(f"{msg}\n")


def saveAsEmbeddingsFile(results: dict[str, list[float]], fpath: str):
    """Write embeddings to file

    Each entry becomes one line: the timestamp key with "-" replaced by a tab,
    another tab, then the embedding values joined by commas.

    Args:
        results: A dictionary containing the embeddings at timestamp. Every value
            only needs to be iterable; each element is written via str()
            (presumably a NumPy vector in practice - TODO confirm against caller).
        fpath: The path for the embeddings file. An existing file is overwritten.
    """
    with open(fpath, "w", encoding="utf-8") as f:
        # Iterate over items() to avoid a second dictionary lookup per timestamp
        # and hand all lines to the buffered writer in one call.
        f.writelines(
            timestamp.replace("-", "\t") + "\t" + ",".join(map(str, values)) + "\n"
            for timestamp, values in results.items()
        )
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.db import interface as hoplite
from functools import partial
from tqdm import tqdm
import json

DATASET_NAME: str = "birdnet_analyzer_dataset"
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))

def analyzeFile(item):
def analyzeFile(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
"""Extracts the embeddings for a file.

Args:
Expand All @@ -51,15 +34,16 @@ def analyzeFile(item):

offset = 0
duration = cfg.FILE_SPLITTING_DURATION
fileLengthSeconds = int(audio.getAudioFileLength(fpath, cfg.SAMPLE_RATE))
results = {}
fileLengthSeconds = int(audio.getAudioFileLength(fpath))

# Start time
start_time = datetime.datetime.now()

# Status
print(f"Analyzing {fpath}", flush=True)

source_id = fpath

# Process each chunk
try:
while offset < fileLengthSeconds:
Expand All @@ -68,6 +52,7 @@ def analyzeFile(item):
samples = []
timestamps = []


for c in range(len(chunks)):
# Add to batch
samples.append(chunks[c])
Expand All @@ -90,11 +75,25 @@ def analyzeFile(item):
# Get timestamp
s_start, s_end = timestamps[i]

# Get prediction
embeddings = e[i]
s_start = s_start
s_end = s_end

# create the source in the database to
db._get_source_id(DATASET_NAME, source_id, insert=True)

# Check if embedding already exists
existing_embedding = db.get_embeddings_by_source(DATASET_NAME, source_id, np.array([s_start, s_end]))

if existing_embedding.size == 0:
# Get prediction
embeddings = e[i]

# Store embeddings
results[f"{s_start}-{s_end}"] = embeddings
# Store embeddings
embeddings_source = hoplite.EmbeddingSource(DATASET_NAME, source_id, np.array([s_start, s_end]))

# Insert into database
db.insert_embedding(embeddings, embeddings_source)
db.commit()

# Reset batch
samples = []
Expand All @@ -109,81 +108,56 @@ def analyzeFile(item):

return

# Save as embeddings file
try:
# We have to check if output path is a file or directory
if cfg.OUTPUT_PATH.rsplit(".", 1)[-1].lower() not in ["txt", "csv"]:
fpath = fpath.replace(cfg.INPUT_PATH, "")
fpath = fpath[1:] if fpath[0] in ["/", "\\"] else fpath

# Make target directory if it doesn't exist
fdir = os.path.join(cfg.OUTPUT_PATH, os.path.dirname(fpath))
os.makedirs(fdir, exist_ok=True)

saveAsEmbeddingsFile(
results, os.path.join(cfg.OUTPUT_PATH, fpath.rsplit(".", 1)[0] + ".birdnet.embeddings.txt")
)
else:
saveAsEmbeddingsFile(results, cfg.OUTPUT_PATH)

except Exception as ex:
# Write error log
print(f"Error: Cannot save embeddings for {fpath}.", flush=True)
utils.writeErrorLog(ex)

return

delta_time = (datetime.datetime.now() - start_time).total_seconds()
print("Finished {} in {:.2f} seconds".format(fpath, delta_time), flush=True)

def getDatabase(db_path: str):
    """Get the database object. Creates or opens the database.

    If nothing exists at db_path yet, the parent directory of db_path is created
    and a new SQLiteUsearchDB is created with the default usearch configuration
    for an embedding dimension of 1024. Otherwise SQLiteUsearchDB.create is
    called with the path only.

    Args:
        db_path: The path to the database.

    Returns:
        The database object.
    """

if __name__ == "__main__":
# Parse arguments
parser = argparse.ArgumentParser(description="Extract feature embeddings with BirdNET")
parser.add_argument(
"--i",
default=os.path.join(SCRIPT_DIR, "example/"),
help="Path to input file or folder. If this is a file, --o needs to be a file too.",
)
parser.add_argument(
"--o",
default=os.path.join(SCRIPT_DIR, "example/"),
help="Path to output file or folder. If this is a file, --i needs to be a file too.",
)
parser.add_argument(
"--overlap",
type=float,
default=0.0,
help="Overlap of prediction segments. Values in [0.0, 2.9]. Defaults to 0.0.",
)
parser.add_argument("--threads", type=int, default=4, help="Number of CPU threads.")
parser.add_argument(
"--batchsize", type=int, default=1, help="Number of samples to process at the same time. Defaults to 1."
)
parser.add_argument(
"--fmin",
type=int,
default=cfg.SIG_FMIN,
help="Minimum frequency for bandpass filter in Hz. Defaults to {} Hz.".format(cfg.SIG_FMIN),
)
parser.add_argument(
"--fmax",
type=int,
default=cfg.SIG_FMAX,
help="Maximum frequency for bandpass filter in Hz. Defaults to {} Hz.".format(cfg.SIG_FMAX),
)
if not os.path.exists(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
db = sqlite_usearch_impl.SQLiteUsearchDB.create(
db_path=db_path,
usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024) #TODO dont hardcode this
)
return db
return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)

def checkDatabaseSettingsFile(db_path: str):
    """Checks if the database settings file exists. If not, it creates it.

    The settings file is "birdnet_analyzer_settings.json" inside db_path. It
    records the bandpass limits and the audio speed taken from cfg. When the file
    already exists, its values are compared with the current cfg values.

    Args:
        db_path: The path to the database (the directory that holds the settings file).

    Raises:
        ValueError: If the stored settings are incomplete, malformed or differ
            from the current configuration.
    """
    settings_path = os.path.join(db_path, "birdnet_analyzer_settings.json")

    # Single source of truth for the keys that are both written and validated.
    current_settings = {
        "BANDPASS_FMIN": cfg.BANDPASS_FMIN,
        "BANDPASS_FMAX": cfg.BANDPASS_FMAX,
        "AUDIO_SPEED": cfg.AUDIO_SPEED,
    }

    if not os.path.exists(settings_path):
        with open(settings_path, "w", encoding="utf-8") as f:
            json.dump(current_settings, f, indent=4)
        return

    with open(settings_path, "r", encoding="utf-8") as f:
        stored_settings = json.load(f)

    # A missing key (or a file that is not a JSON object) counts as a mismatch
    # instead of surfacing as a KeyError/TypeError.
    matches = isinstance(stored_settings, dict) and all(
        key in stored_settings and stored_settings[key] == value for key, value in current_settings.items()
    )

    if not matches:
        raise ValueError("Database settings do not match current configuration.")

def run(input, database_path, overlap, threads, batchsize, audio_speed, fmin, fmax):
# Set paths relative to script path (requested in #3)
cfg.MODEL_PATH = os.path.join(SCRIPT_DIR, cfg.MODEL_PATH)
cfg.ERROR_LOG_FILE = os.path.join(SCRIPT_DIR, cfg.ERROR_LOG_FILE)

### Make sure to comment out appropriately if you are not using args. ###

# Set input and output path
cfg.INPUT_PATH = args.i
cfg.OUTPUT_PATH = args.o
cfg.INPUT_PATH = input

# Parse input files
if os.path.isdir(cfg.INPUT_PATH):
Expand All @@ -192,36 +166,90 @@ def analyzeFile(item):
cfg.FILE_LIST = [cfg.INPUT_PATH]

# Set overlap
cfg.SIG_OVERLAP = max(0.0, min(2.9, float(args.overlap)))
cfg.SIG_OVERLAP = max(0.0, min(2.9, float(overlap)))

cfg.AUDIO_SPEED = max(0.01, audio_speed)

# Set bandpass frequency range
cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(args.fmin)))
cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(args.fmax)))
cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))

# Set number of threads
if os.path.isdir(cfg.INPUT_PATH):
cfg.CPU_THREADS = max(1, int(args.threads))
cfg.CPU_THREADS = max(1, int(threads))
cfg.TFLITE_THREADS = 1
else:
cfg.CPU_THREADS = 1
cfg.TFLITE_THREADS = max(1, int(args.threads))
cfg.TFLITE_THREADS = max(1, int(threads))

cfg.CPU_THREADS = 1 # TODO: with the current implementation, we can't use more than 1 thread

# Set batch size
cfg.BATCH_SIZE = max(1, int(args.batchsize))
cfg.BATCH_SIZE = max(1, int(batchsize))

# Add config items to each file list entry.
# We have to do this for Windows which does not
# support fork() and thus each process has to
# have its own config. USE LINUX!
flist = [(f, cfg.getConfig()) for f in cfg.FILE_LIST]

db = getDatabase(database_path)
checkDatabaseSettingsFile(database_path)

# Analyze files
if cfg.CPU_THREADS < 2:
for entry in flist:
analyzeFile(entry)
for entry in tqdm(flist):
analyzeFile(entry, db)
else:
with Pool(cfg.CPU_THREADS) as p:
p.map(analyzeFile, flist)
tqdm(p.imap(partial(analyzeFile, db=db), flist))

db.db.close() #TODO: needed to close db connection and avoid having wal/shm files


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and pass the values on to run().
    parser = argparse.ArgumentParser(description="Extract feature embeddings with BirdNET")

    # Option table, kept in registration order (which is also the order shown by --help).
    cli_options = (
        (
            "--i",
            {
                "default": os.path.join(SCRIPT_DIR, "example/"),
                "help": "Path to input file or folder.",
            },
        ),
        (
            "--db",
            {
                # Plain relative path: unlike --i, this default is not joined with SCRIPT_DIR.
                "default": "example/hoplite-db/",
                "help": "Path to the Hoplite database. Defaults to example/hoplite-db/.",
            },
        ),
        (
            "--overlap",
            {
                "type": float,
                "default": 0.0,
                "help": "Overlap of prediction segments. Values in [0.0, 2.9]. Defaults to 0.0.",
            },
        ),
        (
            "--threads",
            {
                "type": int,
                "default": 4,
                "help": "Number of CPU threads.",
            },
        ),
        (
            "--batchsize",
            {
                "type": int,
                "default": 1,
                "help": "Number of samples to process at the same time. Defaults to 1.",
            },
        ),
        (
            "--audio_speed",
            {
                "type": float,
                "default": 1.0,
                "help": "Speed factor for audio playback. Values < 1.0 will slow down the audio, values > 1.0 will speed it up. Defaults to 1.0. Values cant go below 0.01.",
            },
        ),
        (
            "--fmin",
            {
                "type": int,
                "default": cfg.SIG_FMIN,
                "help": "Minimum frequency for bandpass filter in Hz. Defaults to {} Hz.".format(cfg.SIG_FMIN),
            },
        ),
        (
            "--fmax",
            {
                "type": int,
                "default": cfg.SIG_FMAX,
                "help": "Maximum frequency for bandpass filter in Hz. Defaults to {} Hz.".format(cfg.SIG_FMAX),
            },
        ),
    )

    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Keyword arguments follow the parameter order of run().
    run(
        input=args.i,
        database_path=args.db,
        overlap=args.overlap,
        threads=args.threads,
        batchsize=args.batchsize,
        audio_speed=args.audio_speed,
        fmin=args.fmin,
        fmax=args.fmax,
    )

# A few examples to test
# python3 embeddings.py --i example/ --db example/hoplite-db/ --threads 4
Expand Down
2 changes: 2 additions & 0 deletions birdnet_analyzer/gui/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import birdnet_analyzer.gui.segments as gs
import birdnet_analyzer.gui.review as review
import birdnet_analyzer.gui.species as species
import birdnet_analyzer.gui.embeddings as embeddings

gu.open_window(
[
Expand All @@ -15,5 +16,6 @@
gs.build_segments_tab,
review.build_review_tab,
species.build_species_tab,
embeddings.build_embeddings_tab,
]
)
10 changes: 9 additions & 1 deletion birdnet_analyzer/gui/assets/gui.css
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ footer {

#single-file-output td:first-of-type span {
text-align: center;
}
}



/* Element with id "embeddings-search-results": fixed height of 700px with a
   vertical scrollbar that is always shown (overflow-y: scroll). */
#embeddings-search-results {
overflow-y: scroll;
height: 700px;
}

Loading