Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/mdio/converters/segy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ def _scan_for_headers(
"""Extract trace dimensions and index headers from the SEG-Y file.

This is an expensive operation.
It scans the SEG-Y file in chunks by using ProcessPoolExecutor
It scans the SEG-Y file in chunks by using ProcessPoolExecutor.

Note:
If grid_overrides are applied to the template before calling this function,
the chunk_size returned from get_grid_plan should match the template's chunk shape.
"""
full_chunk_shape = template.full_chunk_shape
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
Expand All @@ -167,13 +171,26 @@ def _scan_for_headers(
chunksize=full_chunk_shape,
grid_overrides=grid_overrides,
)

# Update template to match grid_plan results after grid overrides
if full_chunk_shape != chunk_size:
# TODO(Dmitriy): implement grid overrides
# https://github.com/TGSAI/mdio-python/issues/585
# The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
# support for grid overrides is implemented
err = "Support for changing full_chunk_shape in grid overrides is not yet implemented"
raise NotImplementedError(err)
logger.debug(
"Adjusting template chunk shape from %s to %s to match grid after overrides",
full_chunk_shape,
chunk_size,
)
template._var_chunk_shape = chunk_size

# Update dimensions if they don't match grid_plan results
actual_spatial_dims = tuple(dim.name for dim in segy_dimensions[:-1])
if template.spatial_dimension_names != actual_spatial_dims:
logger.debug(
"Adjusting template dimensions from %s to %s to match grid after overrides",
template.spatial_dimension_names,
actual_spatial_dims,
)
template._dim_names = actual_spatial_dims + (template.trace_domain,)

return segy_dimensions, segy_headers


Expand Down
27 changes: 24 additions & 3 deletions src/mdio/segy/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from numpy.typing import NDArray
from segy.arrays import HeaderArray

from mdio.builder.templates.base import AbstractDatasetTemplate


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,7 +269,8 @@ def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = n
header_names = []
for header_key in index_headers.dtype.names:
if header_key != "trace":
unique_headers[header_key] = np.sort(np.unique(index_headers[header_key]))
unique_vals = np.sort(np.unique(index_headers[header_key]))
unique_headers[header_key] = unique_vals
header_names.append(header_key)
total_depth += 1

Expand Down Expand Up @@ -302,6 +305,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""

Expand Down Expand Up @@ -378,11 +382,25 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate,
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)

return analyze_non_indexed_headers(index_headers)
# Filter out coordinate fields, keep only dimensions for trace indexing
coord_fields = set(template.coordinate_names) if template else set()
dim_fields = [name for name in index_headers.dtype.names if name != "trace" and name not in coord_fields]

# Create trace indices on dimension fields only
dim_headers = index_headers[dim_fields] if dim_fields else index_headers
dim_headers_with_trace = analyze_non_indexed_headers(dim_headers)

# Add trace field back to full headers
if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names:
trace_values = np.array(dim_headers_with_trace["trace"])
index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False)

return index_headers

def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]:
"""Insert dimension "trace" to the sample-1 dimension."""
Expand Down Expand Up @@ -434,6 +452,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)
Expand Down Expand Up @@ -471,6 +490,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)
Expand Down Expand Up @@ -532,6 +552,7 @@ def run(
index_names: Sequence[str],
grid_overrides: dict[str, bool],
chunksize: Sequence[int] | None = None,
template: AbstractDatasetTemplate | None = None,
) -> tuple[HeaderArray, tuple[str], tuple[int]]:
"""Run grid overrides and return result."""
for override in grid_overrides:
Expand All @@ -542,7 +563,7 @@ def run(
raise GridOverrideUnknownError(override)

function = self.commands[override].transform
index_headers = function(index_headers, grid_overrides=grid_overrides)
index_headers = function(index_headers, grid_overrides=grid_overrides, template=template)

function = self.commands[override].transform_index_names
index_names = function(index_names)
Expand Down
11 changes: 10 additions & 1 deletion src/mdio/segy/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,19 @@ def get_grid_plan( # noqa: C901, PLR0913
horizontal_coordinates,
chunksize=chunksize,
grid_overrides=grid_overrides,
template=template,
)
# Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides)
# Extract only the dimension names (not including non-dimension coordinates)
# After grid overrides, trace might have been added to horizontal_coordinates
transformed_spatial_dims = [
name for name in horizontal_coordinates if name in horizontal_dimensions or name == "trace"
]

dimensions = []
for dim_name in horizontal_dimensions:
for dim_name in transformed_spatial_dims:
if dim_name not in headers_subset.dtype.names:
continue
dim_unique = np.unique(headers_subset[dim_name])
dimensions.append(Dimension(coords=dim_unique, name=dim_name))

Expand Down
11 changes: 5 additions & 6 deletions tests/integration/test_import_streamer_grid_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import dask
import numpy as np
import numpy.testing as npt
import pytest
import xarray.testing as xrt
from tests.integration.conftest import get_segy_mock_4d_spec
Expand All @@ -28,12 +27,12 @@
os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true"


# TODO(Altay): Finish implementing these grid overrides.
# TODO(BrianMichell): Add non-binned back
# https://github.com/TGSAI/mdio-python/issues/612
@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
# @pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4}, {"HasDuplicates": True}])
@pytest.mark.parametrize("grid_override", [{"HasDuplicates": True}])
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
class TestImport4DNonReg: # pragma: no cover - tests is skipped
class TestImport4DNonReg:
"""Test for 4D segy import with grid overrides."""

def test_import_4d_segy( # noqa: PLR0913
Expand Down Expand Up @@ -67,7 +66,7 @@ def test_import_4d_segy( # noqa: PLR0913
assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples
assert ds.attrs["attributes"]["gridOverrides"] == grid_override

assert npt.assert_array_equal(ds["shot_point"], shots)
xrt.assert_duckarray_equal(ds["shot_point"], shots)
xrt.assert_duckarray_equal(ds["cable"], cables)

# assert grid.select_dim("trace") == Dimension(range(1, np.amax(receivers_per_cable) + 1), "trace")
Expand Down