@@ -249,8 +249,7 @@ sealed abstract class CaptureSet extends Showable:
249249 if this .subCaptures(that, frozen = true ).isOK then that
250250 else if that.subCaptures(this , frozen = true ).isOK then this
251251 else if this .isConst && that.isConst then Const (this .elems ++ that.elems)
252- else Var (initialElems = this .elems ++ that.elems)
253- .addAsDependentTo(this ).addAsDependentTo(that)
252+ else Union (this , that)
254253
255254 /** The smallest superset (via <:<) of this capture set that also contains `ref`.
256255 */
@@ -263,7 +262,7 @@ sealed abstract class CaptureSet extends Showable:
263262 if this .subCaptures(that, frozen = true ).isOK then this
264263 else if that.subCaptures(this , frozen = true ).isOK then that
265264 else if this .isConst && that.isConst then Const (elemIntersection(this , that))
266- else Intersected (this , that)
265+ else Intersection (this , that)
267266
268267 /** The largest subset (via <:<) of this capture set that does not account for
269268 * any of the elements in the constant capture set `that`
@@ -816,7 +815,29 @@ object CaptureSet:
816815 class Diff (source : Var , other : Const )(using Context )
817816 extends Filtered (source, ! other.accountsFor(_))
818817
819- class Intersected (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
818+ class Union (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
819+ extends Var (initialElems = cs1.elems ++ cs2.elems):
820+ addAsDependentTo(cs1)
821+ addAsDependentTo(cs2)
822+
823+ override def tryInclude (elem : CaptureRef , origin : CaptureSet )(using Context , VarState ): CompareResult =
824+ if accountsFor(elem) then CompareResult .OK
825+ else
826+ val res = super .tryInclude(elem, origin)
827+ // If this is the union of a constant and a variable,
828+ // propagate `elem` to the variable part to avoid slack
829+ // between the operands and the union.
830+ if res.isOK && (origin ne cs1) && (origin ne cs2) then
831+ if cs1.isConst then cs2.tryInclude(elem, origin)
832+ else if cs2.isConst then cs1.tryInclude(elem, origin)
833+ else res
834+ else res
835+
836+ override def propagateSolved ()(using Context ) =
837+ if cs1.isConst && cs2.isConst && ! isConst then markSolved()
838+ end Union
839+
840+ class Intersection (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
820841 extends Var (initialElems = elemIntersection(cs1, cs2)):
821842 addAsDependentTo(cs1)
822843 addAsDependentTo(cs2)
@@ -841,7 +862,7 @@ object CaptureSet:
841862
842863 override def propagateSolved ()(using Context ) =
843864 if cs1.isConst && cs2.isConst && ! isConst then markSolved()
844- end Intersected
865+ end Intersection
845866
846867 def elemIntersection (cs1 : CaptureSet , cs2 : CaptureSet )(using Context ): Refs =
847868 cs1.elems.filter(cs2.mightAccountFor) ++ cs2.elems.filter(cs1.mightAccountFor)
0 commit comments