Skip to content

Commit 6a72dc6

Browse files
glenn-jochereladco
authored andcommitted
Fix hyp_evolve.yaml indexing bug (ultralytics#6604)
* Fix `hyp_evolve.yaml` indexing bug Bug caused hyp_evolve.yaml to display latest generation result rather than best generation result. * Update plots.py * Update general.py * Update general.py * Update general.py
1 parent dc0ce61 commit 6a72dc6

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

utils/general.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
783783
LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
784784

785785

786-
def print_mutation(results, hyp, save_dir, bucket):
786+
def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
787787
evolve_csv = save_dir / 'evolve.csv'
788788
evolve_yaml = save_dir / 'hyp_evolve.yaml'
789789
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
@@ -803,21 +803,23 @@ def print_mutation(results, hyp, save_dir, bucket):
803803
with open(evolve_csv, 'a') as f:
804804
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
805805

806-
# Print to screen
807-
LOGGER.info(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
808-
LOGGER.info(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')
809-
810806
# Save yaml
811807
with open(evolve_yaml, 'w') as f:
812808
data = pd.read_csv(evolve_csv)
813809
data = data.rename(columns=lambda x: x.strip()) # strip keys
814-
i = np.argmax(fitness(data.values[:, :7])) #
810+
i = np.argmax(fitness(data.values[:, :4])) #
811+
generations = len(data)
815812
f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
816813
f'# Best generation: {i}\n' +
817-
f'# Last generation: {len(data) - 1}\n' +
814+
f'# Last generation: {generations - 1}\n' +
818815
'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
819816
'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
820-
yaml.safe_dump(hyp, f, sort_keys=False)
817+
yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
818+
819+
# Print to screen
820+
LOGGER.info(prefix + f'{generations} generations finished, current result:\n' +
821+
prefix + ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' +
822+
prefix + ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')
821823

822824
if bucket:
823825
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload

utils/plots.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *;
381381
j = np.argmax(f) # max fitness index
382382
plt.figure(figsize=(10, 12), tight_layout=True)
383383
matplotlib.rc('font', **{'size': 8})
384+
print(f'Best results from row {j} of {evolve_csv}:')
384385
for i, k in enumerate(keys[7:]):
385386
v = x[:, 7 + i]
386387
mu = v[j] # best single result

0 commit comments

Comments
 (0)