Skip to content

Commit 496dc79

Browse files
feat: Add tooltip to Altair agent portrayal (#2795)
1 parent e1f9780 commit 496dc79

File tree

3 files changed

+74
-93
lines changed

3 files changed

+74
-93
lines changed

mesa/examples/basic/boltzmann_wealth_model/app.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
def agent_portrayal(agent):
1616
return AgentPortrayalStyle(
17-
color=agent.wealth
18-
) # we are using a colormap to translate wealth to color
17+
color=agent.wealth,
18+
tooltip={"Agent ID": agent.unique_id, "Wealth": agent.wealth},
19+
)
1920

2021

2122
model_params = {
@@ -37,23 +38,6 @@ def agent_portrayal(agent):
3738
}
3839

3940

40-
def post_process(chart):
41-
"""Post-process the Altair chart to add a colorbar legend."""
42-
chart = chart.encode(
43-
color=alt.Color(
44-
"color:N",
45-
scale=alt.Scale(scheme="viridis", domain=[0, 10]),
46-
legend=alt.Legend(
47-
title="Wealth",
48-
orient="right",
49-
type="gradient",
50-
gradientLength=200,
51-
),
52-
),
53-
)
54-
return chart
55-
56-
5741
model = BoltzmannWealth(50, 10, 10)
5842

5943
# The SpaceRenderer is responsible for drawing the model's space and agents.
@@ -63,11 +47,13 @@ def post_process(chart):
6347
renderer = SpaceRenderer(model, backend="altair")
6448
# Can customize the grid appearance.
6549
renderer.draw_structure(grid_color="black", grid_dash=[6, 2], grid_opacity=0.3)
66-
renderer.draw_agents(agent_portrayal=agent_portrayal, cmap="viridis", vmin=0, vmax=10)
67-
68-
# The post_process function is used to modify the Altair chart after it has been created.
69-
# It can be used to add legends, colorbars, or other visual elements.
70-
renderer.post_process = post_process
50+
renderer.draw_agents(
51+
agent_portrayal=agent_portrayal,
52+
cmap="viridis",
53+
vmin=0,
54+
vmax=10,
55+
legend_title="Wealth",
56+
)
7157

7258
# Creates a line plot component from the model's "Gini" datacollector.
7359
GiniPlot = make_plot_component("Gini")
@@ -81,4 +67,4 @@ def post_process(chart):
8167
model_params=model_params,
8268
name="Boltzmann Wealth Model",
8369
)
84-
page # noqa
70+
page # noqa

mesa/visualization/backends/altair_backend.py

Lines changed: 60 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# noqa: D100
21
import warnings
32
from collections.abc import Callable
43
from dataclasses import fields
@@ -75,6 +74,7 @@ def collect_agent_data(
7574
"stroke": [], # Stroke color
7675
"strokeWidth": [],
7776
"filled": [],
77+
"tooltip": [],
7878
}
7979

8080
# Import here to avoid circular import issues
@@ -129,6 +129,7 @@ def collect_agent_data(
129129
linewidths=dict_data.pop(
130130
"linewidths", style_fields.get("linewidths")
131131
),
132+
tooltip=dict_data.pop("tooltip", None),
132133
)
133134
if dict_data:
134135
ignored_keys = list(dict_data.keys())
@@ -184,6 +185,7 @@ def collect_agent_data(
184185
# FIXME: Make filled user-controllable
185186
filled_value = True
186187
arguments["filled"].append(filled_value)
188+
arguments["tooltip"].append(aps.tooltip)
187189

188190
final_data = {}
189191
for k, v in arguments.items():
@@ -199,87 +201,84 @@ def collect_agent_data(
199201

200202
return final_data
201203

204+
205+
202206
def draw_agents(
203207
self, arguments, chart_width: int = 450, chart_height: int = 350, **kwargs
204208
):
205-
"""Draw agents using Altair backend.
206-
207-
Args:
208-
arguments: Dictionary containing agent data arrays.
209-
chart_width: Width of the chart.
210-
chart_height: Height of the chart.
211-
**kwargs: Additional keyword arguments for customization.
212-
Checkout respective `SpaceDrawer` class on details how to pass **kwargs.
213-
214-
Returns:
215-
alt.Chart: The Altair chart representing the agents, or None if no agents.
216-
"""
209+
"""Draw agents using Altair backend."""
217210
if arguments["loc"].size == 0:
218211
return None
219212

220-
# To get a continuous scale for color the domain should be between [0, 1]
221-
# that's why changing the the domain of strokeWidth beforehand.
222-
stroke_width = [data / 10 for data in arguments["strokeWidth"]]
223-
224-
# Agent data preparation
225-
df_data = {
226-
"x": arguments["loc"][:, 0],
227-
"y": arguments["loc"][:, 1],
228-
"size": arguments["size"],
229-
"shape": arguments["shape"],
230-
"opacity": arguments["opacity"],
231-
"strokeWidth": stroke_width,
232-
"original_color": arguments["color"],
233-
"is_filled": arguments["filled"],
234-
"original_stroke": arguments["stroke"],
235-
}
236-
df = pd.DataFrame(df_data)
237-
238-
# To ensure distinct shapes according to agent portrayal
239-
unique_shape_names_in_data = df["shape"].unique().tolist()
240-
241-
fill_colors = []
242-
stroke_colors = []
243-
for i in range(len(df)):
244-
filled = df["is_filled"][i]
245-
main_color = df["original_color"][i]
246-
stroke_spec = (
247-
df["original_stroke"][i]
248-
if isinstance(df["original_stroke"][i], str)
249-
else None
250-
)
251-
if filled:
252-
fill_colors.append(main_color)
253-
stroke_colors.append(stroke_spec)
213+
# Prepare a list of dictionaries, which is a robust way to create a DataFrame
214+
records = []
215+
for i in range(len(arguments["loc"])):
216+
record = {
217+
"x": arguments["loc"][i][0],
218+
"y": arguments["loc"][i][1],
219+
"size": arguments["size"][i],
220+
"shape": arguments["shape"][i],
221+
"opacity": arguments["opacity"][i],
222+
"strokeWidth": arguments["strokeWidth"][i] / 10, # Scale for continuous domain
223+
"original_color": arguments["color"][i],
224+
}
225+
# Add tooltip data if available
226+
tooltip = arguments["tooltip"][i]
227+
if tooltip:
228+
record.update(tooltip)
229+
230+
# Determine fill and stroke colors
231+
if arguments["filled"][i]:
232+
record["viz_fill_color"] = arguments["color"][i]
233+
record["viz_stroke_color"] = arguments["stroke"][i] if isinstance(arguments["stroke"][i], str) else None
254234
else:
255-
fill_colors.append(None)
256-
stroke_colors.append(main_color)
257-
df["viz_fill_color"] = fill_colors
258-
df["viz_stroke_color"] = stroke_colors
235+
record["viz_fill_color"] = None
236+
record["viz_stroke_color"] = arguments["color"][i]
237+
238+
records.append(record)
239+
240+
df = pd.DataFrame(records)
241+
242+
# Ensure all columns that should be numeric are, handling potential Nones
243+
numeric_cols = ['x', 'y', 'size', 'opacity', 'strokeWidth', 'original_color']
244+
for col in numeric_cols:
245+
if col in df.columns:
246+
df[col] = pd.to_numeric(df[col], errors='coerce')
247+
248+
249+
# Get tooltip keys from the first valid record
250+
tooltip_list = ["x", "y"]
251+
# This is the corrected line:
252+
if any(t is not None for t in arguments["tooltip"]):
253+
first_valid_tooltip = next((t for t in arguments["tooltip"] if t), None)
254+
if first_valid_tooltip:
255+
tooltip_list.extend(first_valid_tooltip.keys())
259256

260257
# Extract additional parameters from kwargs
261-
# FIXME: Add more parameters to kwargs
262258
title = kwargs.pop("title", "")
263259
xlabel = kwargs.pop("xlabel", "")
264260
ylabel = kwargs.pop("ylabel", "")
265-
266-
# Tooltip list for interactivity
267-
# FIXME: Add more fields to tooltip (preferably from agent_portrayal)
268-
tooltip_list = ["x", "y"]
261+
legend_title = kwargs.pop("legend_title", "Color")
269262

270263
# Handle custom colormapping
271264
cmap = kwargs.pop("cmap", "viridis")
272265
vmin = kwargs.pop("vmin", None)
273266
vmax = kwargs.pop("vmax", None)
274267

275-
color_is_numeric = np.issubdtype(df["original_color"].dtype, np.number)
268+
color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"])
276269
if color_is_numeric:
277270
color_min = vmin if vmin is not None else df["original_color"].min()
278271
color_max = vmax if vmax is not None else df["original_color"].max()
279272

280273
fill_encoding = alt.Fill(
281274
"original_color:Q",
282275
scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
276+
legend=alt.Legend(
277+
title=legend_title,
278+
orient="right",
279+
type="gradient",
280+
gradientLength=200,
281+
),
283282
)
284283
else:
285284
fill_encoding = alt.Fill(
@@ -290,6 +289,7 @@ def draw_agents(
290289

291290
# Determine space dimensions
292291
xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits()
292+
unique_shape_names_in_data = df["shape"].dropna().unique().tolist()
293293

294294
chart = (
295295
alt.Chart(df)
@@ -316,16 +316,10 @@ def draw_agents(
316316
),
317317
title="Shape",
318318
),
319-
opacity=alt.Opacity(
320-
"opacity:Q",
321-
title="Opacity",
322-
scale=alt.Scale(domain=[0, 1], range=[0, 1]),
323-
),
319+
opacity=alt.Opacity("opacity:Q", title="Opacity", scale=alt.Scale(domain=[0, 1], range=[0, 1])),
324320
fill=fill_encoding,
325321
stroke=alt.Stroke("viz_stroke_color:N", scale=None),
326-
strokeWidth=alt.StrokeWidth(
327-
"strokeWidth:Q", scale=alt.Scale(domain=[0, 1])
328-
),
322+
strokeWidth=alt.StrokeWidth("strokeWidth:Q", scale=alt.Scale(domain=[0, 1])),
329323
tooltip=tooltip_list,
330324
)
331325
.properties(title=title, width=chart_width, height=chart_height)
@@ -437,4 +431,4 @@ def draw_propertylayer(
437431
main_charts.append(current_chart)
438432

439433
base = alt.layer(*main_charts).resolve_scale(color="independent")
440-
return base
434+
return base

mesa/visualization/components/portrayal_components.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class AgentPortrayalStyle:
5555
alpha: float | None = 1.0
5656
edgecolors: str | tuple | None = None
5757
linewidths: float | int | None = 1.0
58+
tooltip: dict | None = None
5859

5960
def update(self, *updates_fields: tuple[str, Any]):
6061
"""Updates attributes from variable (field_name, new_value) tuple arguments.
@@ -91,7 +92,7 @@ class PropertyLayerStyle:
9192
(vmin, vmax), transparency (alpha) and colorbar visibility.
9293
9394
Note: vmin and vmax are the lower and upper bounds for the colorbar and the data is
94-
normalized between these values for color/colormap rendering. If they are not
95+
normalized between these values for color/colorbar rendering. If they are not
9596
declared the values are automatically determined from the data range.
9697
9798
Note: You can specify either a 'colormap' (for varying data) or a single
@@ -117,4 +118,4 @@ def __post_init__(self):
117118
if self.color is not None and self.colormap is not None:
118119
raise ValueError("Specify either 'color' or 'colormap', not both.")
119120
if self.color is None and self.colormap is None:
120-
raise ValueError("Specify one of 'color' or 'colormap'")
121+
raise ValueError("Specify one of 'color' or 'colormap'")

0 commit comments

Comments
 (0)