Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Upcoming release 0.13
=====================

* ENH: Convenient load/save of interface inputs (https://github.com/nipy/nipype/pull/1591)
* TST: reduce the size of docker images & use tags for images (https://github.com/nipy/nipype/pull/1564)
* ENH: Implement missing inputs/outputs in FSL AvScale (https://github.com/nipy/nipype/pull/1563)
* FIX: Fix symlink test in copyfile (https://github.com/nipy/nipype/pull/1570, https://github.com/nipy/nipype/pull/1586)
Expand Down
39 changes: 37 additions & 2 deletions nipype/interfaces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from datetime import datetime as dt
from dateutil.parser import parse as parseutc
from warnings import warn

import simplejson as json

from .traits_extension import (traits, Undefined, TraitDictObject,
TraitListObject, TraitError,
Expand All @@ -56,6 +56,8 @@

iflogger = logging.getLogger('interface')

PY35 = sys.version_info >= (3, 5)

if runtime_profile:
try:
import psutil
Expand Down Expand Up @@ -758,14 +760,19 @@ class BaseInterface(Interface):
_additional_metadata = []
_redirect_x = False

def __init__(self, **inputs):
def __init__(self, from_file=None, **inputs):
if not self.input_spec:
raise Exception('No input_spec in class: %s' %
self.__class__.__name__)

self.inputs = self.input_spec(**inputs)
self.estimated_memory_gb = 1
self.num_threads = 1

if from_file is not None:
self.load_inputs_from_json(from_file, overwrite=False)


@classmethod
def help(cls, returnhelp=False):
""" Prints class help
Expand Down Expand Up @@ -1148,6 +1155,34 @@ def version(self):
self.__class__.__name__)
return self._version

def load_inputs_from_json(self, json_file, overwrite=True):
"""
A convenient way to load pre-set inputs from a JSON file.
"""

with open(json_file) as fhandle:
inputs_dict = json.load(fhandle)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this may be sufficient except for checks:

self.inputs.update(**inputs_dict)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm, I'm getting object does not have update. (?)

Anyways, I will use sets and get_traitsfree to implement this in a cleaner manner.


for key, newval in list(inputs_dict.items()):
if not hasattr(self.inputs, key):
continue
val = getattr(self.inputs, key)
if overwrite or not isdefined(val):
setattr(self.inputs, key, newval)

def save_inputs_to_json(self, json_file):
"""
A convenient way to save current inputs to a JSON file.
"""
inputs = self.inputs.get()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_dict = self.inputs.get_traitsfree() returns a dictionary after removing all undefined traits.

for key, val in list(inputs.items()):
if not isdefined(val):
inputs.pop(key, None)

iflogger.debug('saving inputs {}', inputs)
with open(json_file, 'w') as fhandle:
json.dump(inputs, fhandle, indent=4)


class Stream(object):
"""Function to capture stdout and stderr streams with timestamps
Expand Down
36 changes: 36 additions & 0 deletions nipype/interfaces/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,42 @@ def _run_interface(self, runtime):
nib.BaseInterface.input_spec = None
yield assert_raises, Exception, nib.BaseInterface

def test_BaseInterface_load_save_inputs():
tmp_dir = tempfile.mkdtemp()
tmp_json = os.path.join(tmp_dir, 'settings.json')

def _rem_undefined(indict):
for key, val in list(indict.items()):
if not nib.isdefined(val):
indict.pop(key, None)
return indict

class InputSpec(nib.TraitedSpec):
input1 = nib.traits.Int()
input2 = nib.traits.Float()
input3 = nib.traits.Bool()
input4 = nib.traits.Str()

class DerivedInterface(nib.BaseInterface):
input_spec = InputSpec

def __init__(self, **inputs):
super(DerivedInterface, self).__init__(**inputs)

inputs_dict = {'input1': 12, 'input3': True,
'input4': 'some string'}
bif = DerivedInterface(**inputs_dict)
bif.save_inputs_to_json(tmp_json)
bif2 = DerivedInterface()
bif2.load_inputs_from_json(tmp_json)
yield assert_equal, _rem_undefined(bif2.inputs.get()), inputs_dict

bif3 = DerivedInterface(from_file=tmp_json)
yield assert_equal, _rem_undefined(bif3.inputs.get()), inputs_dict

inputs_dict.update({'input4': 'some other string'})
bif4 = DerivedInterface(from_file=tmp_json, input4='some other string')
yield assert_equal, _rem_undefined(bif4.inputs.get()), inputs_dict

def assert_not_raises(fn, *args, **kwargs):
fn(*args, **kwargs)
Expand Down