Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions stanza/pipeline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ def build_default_config_option(model_specs):
downloading all models
"""
# handle case when processor variants are used
if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs):
model_spec_0 = model_specs[0]
variant_pkgs = PROCESSOR_VARIANTS.get(model_spec_0.processor)
if variant_pkgs is not None and any(ms.package in variant_pkgs for ms in model_specs):
if len(model_specs) > 1:
raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_spec.processor))
return f"{model_specs[0].processor}_with_{model_specs[0].package}", True
raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_spec_0.processor))
return f"{model_spec_0.processor}_with_{model_spec_0.package}", True
# handle case when identity is specified as lemmatizer
elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs):
if any(ms.processor == LEMMA and ms.package == 'identity' for ms in model_specs):
if len(model_specs) > 1:
raise IllegalPackageError("Identity processor selected for lemma, but multiple packages requested")
return f"{LEMMA}_use_identity", True
Expand All @@ -114,6 +116,7 @@ def filter_variants(model_specs):
# given a language and models path, build a default configuration
def build_default_config(resources, lang, model_dir, load_list):
default_config = {}
join_path = os.path.join # localize to avoid global name lookup
for processor, model_specs in load_list:
option = build_default_config_option(model_specs)
if option is not None:
Expand All @@ -122,41 +125,35 @@ def build_default_config(resources, lang, model_dir, load_list):
default_config[option[0]] = option[1]
continue

model_paths = [os.path.join(model_dir, lang, processor, model_spec.package + '.pt') for model_spec in model_specs]
dependencies = [model_spec.dependencies for model_spec in model_specs]

# Special case for NER: load multiple models at once
# The pattern will be:
# a list of ner_model_path
# a list of ner_dependencies
# where each item in ner_dependencies is a map
# the map may contain forward_charlm_path, backward_charlm_path, or any other deps
# The user will be able to override the defaults using a semicolon separated string
# TODO: at least use the same config pattern for all other models
# Build the shared directory prefix for this processor once, then derive each model's .pt path from it
mpfx = join_path(model_dir, lang, processor)
model_paths = [f"{mpfx}/{ms.package}.pt" for ms in model_specs]
dependencies = [ms.dependencies for ms in model_specs]

if processor == NER:
default_config[f"{processor}_model_path"] = model_paths
dependency_paths = []
for dependency_block in dependencies:
if not dependency_block:
for dep_block, ms in zip(dependencies, model_specs):
if not dep_block:
dependency_paths.append({})
continue
dependency_paths.append({})
for dependency in dependency_block:
dep_processor, dep_model = dependency
dependency_paths[-1][f"{dep_processor}_path"] = os.path.join(model_dir, lang, dep_processor, dep_model + '.pt')
dep_map = {}
for dep_processor, dep_model in dep_block:
dep_map[f"{dep_processor}_path"] = f"{join_path(model_dir, lang, dep_processor)}/{dep_model}.pt"
dependency_paths.append(dep_map)
default_config[f"{processor}_dependencies"] = dependency_paths
continue

if len(model_specs) > 1:
raise IllegalPackageError("Specified multiple packages for {}, which currently only handles one package".format(processor))

default_config[f"{processor}_model_path"] = model_paths[0]
if not dependencies[0]: continue
for dependency in dependencies[0]:
dep_processor, dep_model = dependency
default_config[f"{processor}_{dep_processor}_path"] = os.path.join(
model_dir, lang, dep_processor, dep_model + '.pt'
)
deps_0 = dependencies[0]
if not deps_0:
continue
mpfx0 = join_path(model_dir, lang, processor) # NOTE(review): intended as a cached path prefix, but mpfx0 is never read afterwards; the loop below builds each path from dep_processor instead
for dep_processor, dep_model in deps_0:
default_config[f"{processor}_{dep_processor}_path"] = f"{join_path(model_dir, lang, dep_processor)}/{dep_model}.pt"

return default_config

Expand Down