Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions tools/semibin/bin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
<command detect_errors="exit_code"><![CDATA[
#import re
@FASTA_FILES@

#if $mod.select == 'history':
python '$__tool_directory__/convert.py' '$mod.model' &&

SemiBin2 bin
--input-fasta 'contigs.$input_fasta.ext'
--data '$data'
#if $mod.select == 'history'
--model '$mod.model'
--model 'model.pt'
#else
--environment '$mod.environment'
#end if
Expand All @@ -31,6 +35,10 @@ SemiBin2 bin
--compression none
@MIN_LEN@
--orf-finder '$orf_finder'

#if $mod.select == 'history'
&&
rm 'model.pt'
#end if
]]></command>
<inputs>
<expand macro="mode_fasta"/>
Expand All @@ -44,7 +52,7 @@ SemiBin2 bin
<expand macro="environment"/>
</when>
<when value="history">
<param argument="--model" type="data" format="h5" label="Trained semi-supervised deep learning model"/>
<param argument="--model" type="data" format="safetensors" label="Trained semi-supervised deep learning model"/>
</when>
</conditional>
<expand macro="min_len"/>
Expand Down Expand Up @@ -78,7 +86,7 @@ SemiBin2 bin
<param name="data" ftype="csv" value="data.csv"/>
<conditional name="mod">
<param name="select" value="history"/>
<param name="model" ftype="h5" value="model.h5"/>
<param name="model" ftype="safetensors" location="https://zenodo.org/records/17373295/files/model.safetensors"/>
</conditional>
<conditional name="min_len">
<param name="method" value="min-len"/>
Expand Down Expand Up @@ -141,7 +149,7 @@ SemiBin2 bin
<param name="data" ftype="csv" value="data.csv"/>
<conditional name="mod">
<param name="select" value="history"/>
<param name="model" ftype="h5" value="model.h5"/>
<param name="model" ftype="safetensors" location="https://zenodo.org/records/17373295/files/model.safetensors"/>
</conditional>
<conditional name="min_len">
<param name="method" value="min-len"/>
Expand Down Expand Up @@ -172,4 +180,4 @@ Outputs

]]></help>
<expand macro="citations"/>
</tool>
</tool>
2 changes: 1 addition & 1 deletion tools/semibin/concatenate_fasta.xml
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ Outputs

]]></help>
<expand macro="citations"/>
</tool>
</tool>
94 changes: 94 additions & 0 deletions tools/semibin/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import pickle
import sys

import torch
from safetensors.torch import load_file, save_file

# -------------------------------
# Metadata encoding/decoding
# -------------------------------


def encode_metadata(obj):
    """
    Recursively turn an arbitrary Python object into a tensor-only structure.

    - torch.Tensor: returned unchanged
    - dict: every value encoded recursively
    - list/tuple: converted to a dict keyed by stringified indices
      ("0", "1", ...) so decode_metadata can restore the sequence
    - anything else: pickled and stored as a uint8 byte tensor
    """
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, dict):
        return {key: encode_metadata(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return {str(index): encode_metadata(item) for index, item in enumerate(obj)}
    # Fallback: serialize the object and wrap the bytes in a uint8 tensor.
    raw = pickle.dumps(obj)
    return torch.tensor(list(raw), dtype=torch.uint8)


def decode_metadata(obj):
    """
    Recursively invert encode_metadata.

    - uint8 tensors hold pickled bytes and are unpickled back into the
      original Python object
    - other tensors are returned unchanged
    - dicts whose keys are all digit strings were lists/tuples before
      encoding and are converted back to lists
    - anything else is returned as-is

    NOTE(review): pickle.loads is only safe on trusted data; these tensors
    must come from a checkpoint produced by encode_metadata.
    """
    if isinstance(obj, torch.Tensor):
        if obj.dtype == torch.uint8:
            data = bytes(obj.tolist())
            return pickle.loads(data)
        return obj
    elif isinstance(obj, dict):
        # Convert dicts with all-digit keys back to lists. Guard against the
        # empty dict: all() is vacuously True on an empty iterable, which
        # would otherwise wrongly turn {} into [].
        if obj and all(k.isdigit() for k in obj.keys()):
            return [decode_metadata(obj[k]) for k in sorted(obj.keys(), key=int)]
        else:
            return {k: decode_metadata(v) for k, v in obj.items()}
    else:
        return obj

# -------------------------------
# Flatten/unflatten for SafeTensors
# -------------------------------


def flatten_dict(d, parent_key='', sep='/'):
    """
    Flatten a nested dict into a single-level dict whose keys are the
    nesting paths joined by *sep* (SafeTensors requires flat string keys).
    """
    def walk(prefix, mapping):
        # Depth-first traversal that yields (path, leaf_value) pairs.
        for key, value in mapping.items():
            path = f"{prefix}{sep}{key}" if prefix else key
            if isinstance(value, dict):
                yield from walk(path, value)
            else:
                yield path, value

    return dict(walk(parent_key, d))


def unflatten_dict(d, sep='/'):
    """
    Rebuild a nested dict from a flat dict whose keys are *sep*-joined
    paths (the inverse of flatten_dict).
    """
    nested = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        # Walk/create intermediate dicts down to the leaf's parent.
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# -------------------------------
# Save .pt as SafeTensors
# -------------------------------

if __name__ == "__main__":
FILE_PATH = sys.argv[1]
if FILE_PATH.endswith('.pt'):
checkpoint = torch.load("model.pt", map_location="cpu")
encoded = encode_metadata(checkpoint)
flat = flatten_dict(encoded)
save_file(flat, os.path.join(os.path.dirname(sys.argv[1]), "model.safetensors"))
print("Saved restorable SafeTensors file!")
else:
loaded_flat = load_file("model_restorable.safetensors")
loaded_nested = unflatten_dict(loaded_flat)
restored_checkpoint = decode_metadata(loaded_nested)
torch.save(restored_checkpoint, os.path.join(os.path.dirname(sys.argv[1]), "model.pt"))
print("Saved restored checkpoint as model_restored.pt!")
3 changes: 1 addition & 2 deletions tools/semibin/generate_cannot_links.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ SemiBin2
#if $ml_threshold:
--ml-threshold $ml_threshold
#end if
--cannot-name 'cannot'
--threads \${GALAXY_SLOTS:-1}
--processes \${GALAXY_SLOTS:-1}
]]></command>
Expand Down Expand Up @@ -136,4 +135,4 @@ Outputs

]]></help>
<expand macro="citations"/>
</tool>
</tool>
2 changes: 1 addition & 1 deletion tools/semibin/generate_sequence_features.xml
Original file line number Diff line number Diff line change
Expand Up @@ -433,4 +433,4 @@ Outputs

]]></help>
<expand macro="citations"/>
</tool>
</tool>
6 changes: 3 additions & 3 deletions tools/semibin/macros.xml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<?xml version="1.0"?>
<macros>
<token name="@TOOL_VERSION@">2.1.0</token>
<token name="@VERSION_SUFFIX@">1</token>
<token name="@TOOL_VERSION@">2.2.0</token>
<token name="@VERSION_SUFFIX@">0</token>
<token name="@PROFILE@">21.01</token>
<xml name="biotools">
<xrefs>
Expand Down Expand Up @@ -462,7 +462,7 @@ ln -s '$e' '${identifier}.bam' &&
</collection>
</xml>
<xml name="train_output">
<data name="model" format="h5" from_work_dir="output/model.h5" label="${tool.name} on ${on_string}: Semi-supervised deep learning model" />
<data name="model" format="safetensors" from_work_dir="output/model.safetensors" label="${tool.name} on ${on_string}: Semi-supervised deep learning model" />
</xml>
<xml name="cannot_link_output">
<data name="cannot" format="txt" from_work_dir="output/cannot/cannot.txt" label="${tool.name} on ${on_string}: Cannot-link constraints" />
Expand Down
Binary file removed tools/semibin/test-data/model.h5
Binary file not shown.
57 changes: 52 additions & 5 deletions tools/semibin/train.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ SemiBin2 train_semi
--ratio $min_len.ratio
#end if
--orf-finder '$orf_finder'

&& python '$__tool_directory__/convert.py' 'output/model.pt'
&& rm 'output/model.pt'
]]></command>
<inputs>
<conditional name="mode">
Expand Down Expand Up @@ -101,9 +104,53 @@ SemiBin2 train_semi
<param name="random_seed" value="0"/>
<param name="epoches" value="1"/>
<param name="batch_size" value="2048"/>
<output name="model" ftype="h5">
<output name="model" ftype="safetensors">
<assert_contents>
<has_size value="3119000" delta="2000" />
<has_size value="3115127" delta="2000" />
</assert_contents>
</output>
</test>
<test expect_num_outputs="1">
<conditional name="mode">
<param name="select" value="single"/>
<param name="input_fasta" ftype="fasta.bz2" value="input_single.fasta.bz2"/>
<param name="data" ftype="csv" value="data.csv"/>
<param name="data_split" ftype="csv" value="data_split.csv"/>
<param name="cannot_link" ftype="txt" value="cannot.txt"/>
</conditional>
<conditional name="min_len">
<param name="method" value="min-len"/>
<param name="min_len" value="2500" />
</conditional>
<param name="orf_finder" value="prodigal"/>
<param name="random_seed" value="0"/>
<param name="epoches" value="1"/>
<param name="batch_size" value="2048"/>
<output name="model" ftype="safetensors">
<assert_contents>
<has_size value="3115127" delta="2000" />
</assert_contents>
</output>
</test>
<test expect_num_outputs="1">
<conditional name="mode">
<param name="select" value="single"/>
<param name="input_fasta" ftype="fasta.gz" value="input_single.fasta.gz"/>
<param name="data" ftype="csv" value="data.csv"/>
<param name="data_split" ftype="csv" value="data_split.csv"/>
<param name="cannot_link" ftype="txt" value="cannot.txt"/>
</conditional>
<conditional name="min_len">
<param name="method" value="min-len"/>
<param name="min_len" value="2500" />
</conditional>
<param name="orf_finder" value="prodigal"/>
<param name="random_seed" value="0"/>
<param name="epoches" value="1"/>
<param name="batch_size" value="2048"/>
<output name="model" ftype="safetensors">
<assert_contents>
<has_size value="3115127" delta="2000" />
</assert_contents>
</output>
</test>
Expand Down Expand Up @@ -167,9 +214,9 @@ SemiBin2 train_semi
<param name="random_seed" value="0"/>
<param name="epoches" value="20"/>
<param name="batch_size" value="2048"/>
<output name="model" ftype="h5">
<output name="model" ftype="safetensors">
<assert_contents>
<has_size value="3119000" delta="2000" />
<has_size value="3115127" delta="2000" />
</assert_contents>
</output>
</test>
Expand All @@ -192,4 +239,4 @@ Outputs
@HELP_MODEL@
]]></help>
<expand macro="citations"/>
</tool>
</tool>
Loading