Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/template_registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ If you have a custom template class, register an instance so others can fetch it
```python
from typing import Any
from mdio.builder.template_registry import register_template
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


class MyTemplate(AbstractDatasetTemplate):
def __init__(self, domain: SeismicDataDomain = "time"):
super().__init__(domain)
Expand All @@ -82,6 +83,7 @@ class MyTemplate(AbstractDatasetTemplate):
def _load_dataset_attributes(self) -> dict[str, Any]:
return {"surveyType": "2D", "gatherType": "custom"}


# Make it available globally
registered_name = register_template(MyTemplate("time"))
print(registered_name) # "MyTemplateTime"
Expand Down
4 changes: 2 additions & 2 deletions src/mdio/builder/template_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import threading
from typing import TYPE_CHECKING

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.seismic_2d_poststack import Seismic2DPostStackTemplate
from mdio.builder.templates.seismic_2d_prestack_cdp import Seismic2DPreStackCDPTemplate
from mdio.builder.templates.seismic_2d_prestack_shot import Seismic2DPreStackShotTemplate
Expand All @@ -29,7 +29,7 @@
from mdio.builder.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate

if TYPE_CHECKING:
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(self, data_domain: SeismicDataDomain) -> None:
msg = "domain must be 'depth' or 'time'"
raise ValueError(msg)

self._spatial_dim_names: tuple[str, ...] = ()
self._dim_names: tuple[str, ...] = ()
self._physical_coord_names: tuple[str, ...] = ()
self._logical_coord_names: tuple[str, ...] = ()
Expand Down Expand Up @@ -99,8 +98,8 @@ def trace_domain(self) -> str:

@property
def spatial_dimension_names(self) -> tuple[str, ...]:
"""Returns the names of only the spatial dimensions."""
return copy.deepcopy(self._spatial_dim_names)
"""Returns the names of the dimensions without data domain (last axis)."""
return copy.deepcopy(self._dim_names[:-1])

@property
def dimension_names(self) -> tuple[str, ...]:
Expand All @@ -123,10 +122,18 @@ def coordinate_names(self) -> tuple[str, ...]:
return copy.deepcopy(self._physical_coord_names + self._logical_coord_names)

@property
def full_chunk_size(self) -> tuple[int, ...]:
"""Returns the chunk size for the variables."""
def full_chunk_shape(self) -> tuple[int, ...]:
"""Returns the chunk shape for the variables."""
return copy.deepcopy(self._var_chunk_shape)

@full_chunk_shape.setter
def full_chunk_shape(self, shape: tuple[int, ...]) -> None:
"""Sets the chunk shape for the variables."""
if len(shape) != len(self._dim_sizes):
msg = f"Chunk shape {shape} does not match dimension sizes {self._dim_sizes}"
raise ValueError(msg)
self._var_chunk_shape = shape

@property
@abstractmethod
def _name(self) -> str:
Expand Down Expand Up @@ -192,7 +199,7 @@ def _add_coordinates(self) -> None:
for name in self.coordinate_names:
self._builder.add_coordinate(
name=name,
dimensions=self._spatial_dim_names,
dimensions=self.spatial_dimension_names,
data_type=ScalarType.FLOAT64,
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
Expand All @@ -202,15 +209,15 @@ def _add_trace_mask(self) -> None:
"""Add trace mask variables."""
self._builder.add_variable(
name="trace_mask",
dimensions=self._spatial_dim_names,
dimensions=self.spatial_dimension_names,
data_type=ScalarType.BOOL,
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd), # also default in zarr3
coordinates=self.coordinate_names,
)

def _add_trace_headers(self, header_dtype: StructuredType) -> None:
"""Add trace mask variables."""
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self._var_chunk_shape[:-1]))
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self.full_chunk_shape[:-1]))
self._builder.add_variable(
name="headers",
dimensions=self.spatial_dimension_names,
Expand All @@ -226,7 +233,7 @@ def _add_variables(self) -> None:
A virtual method that can be overwritten by subclasses to add custom variables.
Uses the class field 'builder' to add variables to the dataset.
"""
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self._var_chunk_shape))
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self.full_chunk_shape))
unit = self.get_unit_by_key(self._default_variable_name)
self._builder.add_variable(
name=self.default_variable_name,
Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_2d_poststack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


Expand All @@ -12,8 +12,7 @@ class Seismic2DPostStackTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._spatial_dim_names = ("cdp",)
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("cdp", self._data_domain)
self._physical_coord_names = ("cdp_x", "cdp_y")
self._var_chunk_shape = (1024, 1024)

Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_2d_prestack_cdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import CdpGatherDomain
from mdio.builder.templates.types import SeismicDataDomain

Expand All @@ -18,8 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomai
msg = "gather_type must be 'offset' or 'angle'"
raise ValueError(msg)

self._spatial_dim_names = ("cdp", self._gather_domain)
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("cdp", self._gather_domain, self._data_domain)
self._physical_coord_names = ("cdp_x", "cdp_y")
self._var_chunk_shape = (16, 64, 1024)

Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_2d_prestack_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mdio.builder.schemas import compressors
from mdio.builder.schemas.dtype import ScalarType
from mdio.builder.schemas.v1.variable import CoordinateMetadata
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


Expand All @@ -15,8 +15,7 @@ class Seismic2DPreStackShotTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._spatial_dim_names = ("shot_point", "channel")
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("shot_point", "channel", self._data_domain)
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
self._logical_coord_names = ("gun",)
self._var_chunk_shape = (16, 32, 2048)
Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_3d_poststack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


Expand All @@ -12,8 +12,7 @@ class Seismic3DPostStackTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._spatial_dim_names = ("inline", "crossline")
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("inline", "crossline", self._data_domain)
self._physical_coord_names = ("cdp_x", "cdp_y")
self._var_chunk_shape = (128, 128, 128)

Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_3d_prestack_cdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import CdpGatherDomain
from mdio.builder.templates.types import SeismicDataDomain

Expand All @@ -18,8 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomai
msg = "gather_type must be 'offset' or 'angle'"
raise ValueError(msg)

self._spatial_dim_names = ("inline", "crossline", self._gather_domain)
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("inline", "crossline", self._gather_domain, self._data_domain)
self._physical_coord_names = ("cdp_x", "cdp_y")
self._var_chunk_shape = (8, 8, 32, 512)

Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_3d_prestack_coca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mdio.builder.schemas import compressors
from mdio.builder.schemas.dtype import ScalarType
from mdio.builder.schemas.v1.variable import CoordinateMetadata
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


Expand All @@ -15,8 +15,7 @@ class Seismic3DPreStackCocaTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._spatial_dim_names = ("inline", "crossline", "offset", "azimuth")
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("inline", "crossline", "offset", "azimuth", self._data_domain)
self._physical_coord_names = ("cdp_x", "cdp_y")
self._var_chunk_shape = (8, 8, 32, 1, 1024)

Expand Down
5 changes: 2 additions & 3 deletions src/mdio/builder/templates/seismic_3d_prestack_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mdio.builder.schemas import compressors
from mdio.builder.schemas.dtype import ScalarType
from mdio.builder.schemas.v1.variable import CoordinateMetadata
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.builder.templates.types import SeismicDataDomain


Expand All @@ -15,8 +15,7 @@ class Seismic3DPreStackShotTemplate(AbstractDatasetTemplate):
def __init__(self, data_domain: SeismicDataDomain):
super().__init__(data_domain=data_domain)

self._spatial_dim_names = ("shot_point", "cable", "channel")
self._dim_names = (*self._spatial_dim_names, self._data_domain)
self._dim_names = ("shot_point", "cable", "channel", self._data_domain)
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
self._logical_coord_names = ("gun",)
self._var_chunk_shape = (8, 1, 128, 2048)
Expand Down
14 changes: 7 additions & 7 deletions src/mdio/converters/segy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from xarray import Dataset as xr_Dataset

from mdio.builder.schemas import Dataset
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.core.dimension import Dimension
from mdio.segy.file import SegyFileArguments
from mdio.segy.file import SegyFileInfo
Expand Down Expand Up @@ -158,21 +158,21 @@ def _scan_for_headers(
This is an expensive operation.
It scans the SEG-Y file in chunks by using ProcessPoolExecutor
"""
full_chunk_size = template.full_chunk_size
full_chunk_shape = template.full_chunk_shape
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
segy_file_kwargs=segy_file_kwargs,
segy_file_info=segy_file_info,
return_headers=True,
template=template,
chunksize=full_chunk_size,
chunksize=full_chunk_shape,
grid_overrides=grid_overrides,
)
if full_chunk_size != chunk_size:
if full_chunk_shape != chunk_size:
# TODO(Dmitriy): implement grid overrides
# https://github.com/TGSAI/mdio-python/issues/585
# The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
# support for grid overrides is implemented
err = "Support for changing full_chunk_size in grid overrides is not yet implemented"
err = "Support for changing full_chunk_shape in grid overrides is not yet implemented"
raise NotImplementedError(err)
return segy_dimensions, segy_headers

Expand Down Expand Up @@ -439,7 +439,7 @@ def enhanced_add_variables() -> None:
original_add_variables()

# Now add the raw headers variable
chunk_shape = mdio_template._var_chunk_shape[:-1]
chunk_shape = mdio_template.full_chunk_shape[:-1]

# Create chunk grid metadata
chunk_metadata = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=chunk_shape))
Expand All @@ -448,7 +448,7 @@ def enhanced_add_variables() -> None:
mdio_template._builder.add_variable(
name="raw_headers",
long_name="Raw Headers",
dimensions=mdio_template._dim_names[:-1], # All dimensions except vertical
dimensions=mdio_template.spatial_dimension_names,
data_type=ScalarType.BYTES240,
compressor=Blosc(cname=BloscCname.zstd),
coordinates=None, # No coordinates as specified
Expand Down
2 changes: 1 addition & 1 deletion src/mdio/segy/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def parse_headers(
segy_file_kwargs: SegyFileArguments,
num_traces: int,
subset: list[str] | None = None,
subset: tuple[str, ...] | None = None,
block_size: int = 10000,
progress_bar: bool = True,
) -> HeaderArray:
Expand Down
4 changes: 2 additions & 2 deletions src/mdio/segy/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from numpy.typing import DTypeLike
from segy.arrays import HeaderArray

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.segy.file import SegyFileArguments
from mdio.segy.file import SegyFileInfo

Expand Down Expand Up @@ -56,7 +56,7 @@ def get_grid_plan( # noqa: C901, PLR0913
grid_overrides = {}

# Keep only dimension and non-dimension coordinates excluding the vertical axis
horizontal_dimensions = template.dimension_names[:-1]
horizontal_dimensions = template.spatial_dimension_names
horizontal_coordinates = horizontal_dimensions + template.coordinate_names
headers_subset = parse_headers(
segy_file_kwargs=segy_file_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_segy_spec_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from segy.schema import HeaderField
from segy.standards import get_segy_standard

from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
from mdio.builder.templates.base import AbstractDatasetTemplate
from mdio.converters.segy import _validate_spec_in_template


Expand Down
3 changes: 1 addition & 2 deletions tests/unit/v1/templates/test_seismic_2d_poststack.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ def test_configuration(self, data_domain: SeismicDataDomain) -> None:

# Template attributes
assert t._data_domain == data_domain
assert t._spatial_dim_names == ("cdp",)
assert t._dim_names == ("cdp", data_domain)
assert t._physical_coord_names == ("cdp_x", "cdp_y")
assert t._var_chunk_shape == (1024, 1024)
assert t.full_chunk_shape == (1024, 1024)

# Variables instantiated when build_dataset() is called
assert t._builder is None
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/v1/templates/test_seismic_2d_prestack_cdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ def test_configuration(self, data_domain: SeismicDataDomain, gather_domain: CdpG
t = Seismic2DPreStackCDPTemplate(data_domain, gather_domain)

# Template attributes for prestack CDP
assert t._spatial_dim_names == ("cdp", gather_domain)
assert t._dim_names == ("cdp", gather_domain, data_domain)
assert t._physical_coord_names == ("cdp_x", "cdp_y")
assert t._var_chunk_shape == (16, 64, 1024)
assert t.full_chunk_shape == (16, 64, 1024)

# Variables instantiated when build_dataset() is called
assert t._builder is None
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/v1/templates/test_seismic_2d_prestack_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,10 @@ def test_configuration(self) -> None:

# Template attributes for prestack shot
assert t._data_domain == "time"
assert t._spatial_dim_names == ("shot_point", "channel")
assert t._dim_names == ("shot_point", "channel", "time")
assert t._physical_coord_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
assert t._logical_coord_names == ("gun",)
assert t._var_chunk_shape == (16, 32, 2048)
assert t.full_chunk_shape == (16, 32, 2048)

# Variables instantiated when build_dataset() is called
assert t._builder is None
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/v1/templates/test_seismic_3d_poststack.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ def test_configuration(self, data_domain: SeismicDataDomain) -> None:

# Template attributes to be overridden by subclasses
assert t._data_domain == data_domain # Domain should be lowercased
assert t._spatial_dim_names == ("inline", "crossline")
assert t._dim_names == ("inline", "crossline", data_domain)
assert t._physical_coord_names == ("cdp_x", "cdp_y")
assert t._var_chunk_shape == (128, 128, 128)
assert t.full_chunk_shape == (128, 128, 128)

# Variables instantiated when build_dataset() is called
assert t._builder is None
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/v1/templates/test_seismic_3d_prestack_cdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ def test_configuration(self, data_domain: SeismicDataDomain, gather_domain: CdpG
t = Seismic3DPreStackCDPTemplate(data_domain, gather_domain)

# Template attributes for prestack CDP
assert t._spatial_dim_names == ("inline", "crossline", gather_domain)
assert t._dim_names == ("inline", "crossline", gather_domain, data_domain)
assert t._physical_coord_names == ("cdp_x", "cdp_y")
assert t._var_chunk_shape == (8, 8, 32, 512)
assert t.full_chunk_shape == (8, 8, 32, 512)

# Variables instantiated when build_dataset() is called
assert t._builder is None
Expand Down
Loading