1
- # noqa: D100
2
1
import warnings
3
2
from collections .abc import Callable
4
3
from dataclasses import fields
@@ -75,6 +74,7 @@ def collect_agent_data(
75
74
"stroke" : [], # Stroke color
76
75
"strokeWidth" : [],
77
76
"filled" : [],
77
+ "tooltip" : [],
78
78
}
79
79
80
80
# Import here to avoid circular import issues
@@ -129,6 +129,7 @@ def collect_agent_data(
129
129
linewidths = dict_data .pop (
130
130
"linewidths" , style_fields .get ("linewidths" )
131
131
),
132
+ tooltip = dict_data .pop ("tooltip" , None ),
132
133
)
133
134
if dict_data :
134
135
ignored_keys = list (dict_data .keys ())
@@ -184,6 +185,7 @@ def collect_agent_data(
184
185
# FIXME: Make filled user-controllable
185
186
filled_value = True
186
187
arguments ["filled" ].append (filled_value )
188
+ arguments ["tooltip" ].append (aps .tooltip )
187
189
188
190
final_data = {}
189
191
for k , v in arguments .items ():
@@ -199,87 +201,84 @@ def collect_agent_data(
199
201
200
202
return final_data
201
203
204
+
205
+
202
206
def draw_agents (
203
207
self , arguments , chart_width : int = 450 , chart_height : int = 350 , ** kwargs
204
208
):
205
- """Draw agents using Altair backend.
206
-
207
- Args:
208
- arguments: Dictionary containing agent data arrays.
209
- chart_width: Width of the chart.
210
- chart_height: Height of the chart.
211
- **kwargs: Additional keyword arguments for customization.
212
- Checkout respective `SpaceDrawer` class on details how to pass **kwargs.
213
-
214
- Returns:
215
- alt.Chart: The Altair chart representing the agents, or None if no agents.
216
- """
209
+ """Draw agents using Altair backend."""
217
210
if arguments ["loc" ].size == 0 :
218
211
return None
219
212
220
- # To get a continuous scale for color the domain should be between [0, 1]
221
- # that's why changing the the domain of strokeWidth beforehand.
222
- stroke_width = [data / 10 for data in arguments ["strokeWidth" ]]
223
-
224
- # Agent data preparation
225
- df_data = {
226
- "x" : arguments ["loc" ][:, 0 ],
227
- "y" : arguments ["loc" ][:, 1 ],
228
- "size" : arguments ["size" ],
229
- "shape" : arguments ["shape" ],
230
- "opacity" : arguments ["opacity" ],
231
- "strokeWidth" : stroke_width ,
232
- "original_color" : arguments ["color" ],
233
- "is_filled" : arguments ["filled" ],
234
- "original_stroke" : arguments ["stroke" ],
235
- }
236
- df = pd .DataFrame (df_data )
237
-
238
- # To ensure distinct shapes according to agent portrayal
239
- unique_shape_names_in_data = df ["shape" ].unique ().tolist ()
240
-
241
- fill_colors = []
242
- stroke_colors = []
243
- for i in range (len (df )):
244
- filled = df ["is_filled" ][i ]
245
- main_color = df ["original_color" ][i ]
246
- stroke_spec = (
247
- df ["original_stroke" ][i ]
248
- if isinstance (df ["original_stroke" ][i ], str )
249
- else None
250
- )
251
- if filled :
252
- fill_colors .append (main_color )
253
- stroke_colors .append (stroke_spec )
213
+ # Prepare a list of dictionaries, which is a robust way to create a DataFrame
214
+ records = []
215
+ for i in range (len (arguments ["loc" ])):
216
+ record = {
217
+ "x" : arguments ["loc" ][i ][0 ],
218
+ "y" : arguments ["loc" ][i ][1 ],
219
+ "size" : arguments ["size" ][i ],
220
+ "shape" : arguments ["shape" ][i ],
221
+ "opacity" : arguments ["opacity" ][i ],
222
+ "strokeWidth" : arguments ["strokeWidth" ][i ] / 10 , # Scale for continuous domain
223
+ "original_color" : arguments ["color" ][i ],
224
+ }
225
+ # Add tooltip data if available
226
+ tooltip = arguments ["tooltip" ][i ]
227
+ if tooltip :
228
+ record .update (tooltip )
229
+
230
+ # Determine fill and stroke colors
231
+ if arguments ["filled" ][i ]:
232
+ record ["viz_fill_color" ] = arguments ["color" ][i ]
233
+ record ["viz_stroke_color" ] = arguments ["stroke" ][i ] if isinstance (arguments ["stroke" ][i ], str ) else None
254
234
else :
255
- fill_colors .append (None )
256
- stroke_colors .append (main_color )
257
- df ["viz_fill_color" ] = fill_colors
258
- df ["viz_stroke_color" ] = stroke_colors
235
+ record ["viz_fill_color" ] = None
236
+ record ["viz_stroke_color" ] = arguments ["color" ][i ]
237
+
238
+ records .append (record )
239
+
240
+ df = pd .DataFrame (records )
241
+
242
+ # Ensure all columns that should be numeric are, handling potential Nones
243
+ numeric_cols = ['x' , 'y' , 'size' , 'opacity' , 'strokeWidth' , 'original_color' ]
244
+ for col in numeric_cols :
245
+ if col in df .columns :
246
+ df [col ] = pd .to_numeric (df [col ], errors = 'coerce' )
247
+
248
+
249
+ # Get tooltip keys from the first valid record
250
+ tooltip_list = ["x" , "y" ]
251
+ # This is the corrected line:
252
+ if any (t is not None for t in arguments ["tooltip" ]):
253
+ first_valid_tooltip = next ((t for t in arguments ["tooltip" ] if t ), None )
254
+ if first_valid_tooltip :
255
+ tooltip_list .extend (first_valid_tooltip .keys ())
259
256
260
257
# Extract additional parameters from kwargs
261
- # FIXME: Add more parameters to kwargs
262
258
title = kwargs .pop ("title" , "" )
263
259
xlabel = kwargs .pop ("xlabel" , "" )
264
260
ylabel = kwargs .pop ("ylabel" , "" )
265
-
266
- # Tooltip list for interactivity
267
- # FIXME: Add more fields to tooltip (preferably from agent_portrayal)
268
- tooltip_list = ["x" , "y" ]
261
+ legend_title = kwargs .pop ("legend_title" , "Color" )
269
262
270
263
# Handle custom colormapping
271
264
cmap = kwargs .pop ("cmap" , "viridis" )
272
265
vmin = kwargs .pop ("vmin" , None )
273
266
vmax = kwargs .pop ("vmax" , None )
274
267
275
- color_is_numeric = np . issubdtype (df ["original_color" ]. dtype , np . number )
268
+ color_is_numeric = pd . api . types . is_numeric_dtype (df ["original_color" ])
276
269
if color_is_numeric :
277
270
color_min = vmin if vmin is not None else df ["original_color" ].min ()
278
271
color_max = vmax if vmax is not None else df ["original_color" ].max ()
279
272
280
273
fill_encoding = alt .Fill (
281
274
"original_color:Q" ,
282
275
scale = alt .Scale (scheme = cmap , domain = [color_min , color_max ]),
276
+ legend = alt .Legend (
277
+ title = legend_title ,
278
+ orient = "right" ,
279
+ type = "gradient" ,
280
+ gradientLength = 200 ,
281
+ ),
283
282
)
284
283
else :
285
284
fill_encoding = alt .Fill (
@@ -290,6 +289,7 @@ def draw_agents(
290
289
291
290
# Determine space dimensions
292
291
xmin , xmax , ymin , ymax = self .space_drawer .get_viz_limits ()
292
+ unique_shape_names_in_data = df ["shape" ].dropna ().unique ().tolist ()
293
293
294
294
chart = (
295
295
alt .Chart (df )
@@ -316,16 +316,10 @@ def draw_agents(
316
316
),
317
317
title = "Shape" ,
318
318
),
319
- opacity = alt .Opacity (
320
- "opacity:Q" ,
321
- title = "Opacity" ,
322
- scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ]),
323
- ),
319
+ opacity = alt .Opacity ("opacity:Q" , title = "Opacity" , scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ])),
324
320
fill = fill_encoding ,
325
321
stroke = alt .Stroke ("viz_stroke_color:N" , scale = None ),
326
- strokeWidth = alt .StrokeWidth (
327
- "strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])
328
- ),
322
+ strokeWidth = alt .StrokeWidth ("strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])),
329
323
tooltip = tooltip_list ,
330
324
)
331
325
.properties (title = title , width = chart_width , height = chart_height )
@@ -437,4 +431,4 @@ def draw_propertylayer(
437
431
main_charts .append (current_chart )
438
432
439
433
base = alt .layer (* main_charts ).resolve_scale (color = "independent" )
440
- return base
434
+ return base
0 commit comments