Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Fixed
- Fixed issue with necessary columns from complex arguments dropped when interchanging dataframes [[#4324](https://github.com/plotly/plotly.py/pull/4324)]
- Fixed issue with px.imshow failing when facet_col is an earlier dimension than animation_frame for xarrays [[#4330](https://github.com/plotly/plotly.py/issues/4330)]
- Fixed issue with px.imshow failing when facet_col has string coordinates in xarrays [[#4329](https://github.com/plotly/plotly.py/issues/4329)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you create a new ## UNRELEASED section above ## [5.16.1] and move these there? They can also be combined and reference the PR, something like "Fixed two issues with px.imshow: [#4330], facet_col is an earlier dimension than animation_frame for xarrays, and [#4329], facet_col has string coordinates in xarrays [#4331]"


## [5.16.0] - 2023-08-11

Expand Down
10 changes: 7 additions & 3 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,18 @@ def imshow(
if xarray_imported and isinstance(img, xarray.DataArray):
dims = list(img.dims)
img_is_xarray = True
pop_indexes = []
if facet_col is not None:
facet_slices = img.coords[img.dims[facet_col]].values
_ = dims.pop(facet_col)
pop_indexes.append(facet_col)
facet_label = img.dims[facet_col]
if animation_frame is not None:
animation_slices = img.coords[img.dims[animation_frame]].values
_ = dims.pop(animation_frame)
pop_indexes.append(animation_frame)
animation_label = img.dims[animation_frame]
# Remove indices in sorted order.
for index in sorted(pop_indexes, reverse=True):
_ = dims.pop(index)
y_label, x_label = dims[0], dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
Expand Down Expand Up @@ -541,7 +545,7 @@ def imshow(
slice_label = (
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
)
col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices]
col_labels = [f"{slice_label}={i}" for i in facet_slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
for attr_name in ["height", "width"]:
if args[attr_name]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,40 @@ def test_imshow_xarray_slicethrough():
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_facet_col_string():
img = np.random.random((3, 4, 5))
da = xr.DataArray(
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
)
fig = px.imshow(da, facet_col="str_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_2"
assert fig.layout.yaxis.title.text == "dim_1"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_animation_frame_string():
img = np.random.random((3, 4, 5))
da = xr.DataArray(
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
)
fig = px.imshow(da, animation_frame="str_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_2"
assert fig.layout.yaxis.title.text == "dim_1"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))


def test_imshow_xarray_animation_facet_slicethrough():
img = np.random.random((3, 4, 5, 6))
da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2", "dim_3"])
fig = px.imshow(da, facet_col="dim_0", animation_frame="dim_1")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_3"
assert fig.layout.yaxis.title.text == "dim_2"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_3"]))


def test_imshow_labels_and_ranges():
fig = px.imshow(
[[1, 2], [3, 4], [5, 6]],
Expand Down