Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions plotly/figure_factory/_distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,31 +387,60 @@ def make_normal(self):
mean = [None] * self.trace_number
sd = [None] * self.trace_number

# Bind the scipy functions to local names so the loop avoids repeated attribute lookups
norm_fit = scipy_stats.norm.fit
norm_pdf = scipy_stats.norm.pdf

# Read the needed attributes from self once, before the loop, instead of on every iteration
histnorm = self.histnorm
bin_size = self.bin_size
start = self.start
end = self.end
hist_data = self.hist_data
curve_x = self.curve_x
curve_y = self.curve_y

# Bind the ALTERNATIVE_HISTNORM constant to a local name for reuse inside the loop
alt_histnorm = ALTERNATIVE_HISTNORM

for index in range(self.trace_number):
mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index])
self.curve_x[index] = [
self.start[index] + x * (self.end[index] - self.start[index]) / 500
for x in range(500)
]
self.curve_y[index] = scipy_stats.norm.pdf(
self.curve_x[index], loc=mean[index], scale=sd[index]
)
data = hist_data[index]
s0 = start[index]
e0 = end[index]
step = (e0 - s0) / 500
mean_val, sd_val = norm_fit(data)
mean[index] = mean_val
sd[index] = sd_val

if self.histnorm == ALTERNATIVE_HISTNORM:
self.curve_y[index] *= self.bin_size[index]
# Use list comprehension directly for curve_x, local binding of s0 and step
x_vals = [s0 + x * step for x in range(500)]
curve_x[index] = x_vals

y = norm_pdf(x_vals, loc=mean_val, scale=sd_val) # y is np.ndarray

if histnorm == alt_histnorm:
y *= bin_size[index] # vectorized multiplication

curve_y[index] = y

colors = self.colors
group_labels = self.group_labels
show_hist = self.show_hist

# The loop below reads from the local bindings above; y values are stored as returned by norm_pdf (no list conversion)

for index in range(self.trace_number):
curve[index] = dict(
type="scatter",
x=self.curve_x[index],
y=self.curve_y[index],
x=curve_x[index],
y=curve_y[index],
xaxis="x1",
yaxis="y1",
mode="lines",
name=self.group_labels[index],
legendgroup=self.group_labels[index],
showlegend=False if self.show_hist else True,
marker=dict(color=self.colors[index % len(self.colors)]),
name=group_labels[index],
legendgroup=group_labels[index],
showlegend=False if show_hist else True,
marker=dict(color=colors[index % len(colors)]),
)
return curve

Expand Down