@@ -3598,14 +3598,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35983598
35993599 private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
36003600 case tpe : MethodType =>
3601- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3601+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36023602 case tpe : PolyType =>
3603- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3603+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36043604 case tpe : RefinedType =>
3605- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3606- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3605+ tpe.derivedRefinedType(
3606+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3607+ tpe.refinedName,
3608+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3609+ )
36073610 case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3608- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3611+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36093612 case tpe =>
36103613 val paramNames = params.map(_.name)
36113614 val paramTpts = params.map(_.tpt)
@@ -3614,18 +3617,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36143617 typed(ctxFunction).tpe
36153618 }
36163619
3617- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3620+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3621+ case tpe : MethodType =>
3622+ tpe.paramNames -> tpe.paramInfos
3623+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3624+ extractTopMethodTermParams(tpe.refinedInfo)
3625+ case _ =>
3626+ Nil -> Nil
3627+ }
3628+
3629+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3630+ case tpe : MethodType =>
3631+ tpe.resultType
3632+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3633+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3634+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3635+ tpe.args.last
3636+ case _ =>
3637+ tpe
3638+ }
3639+
3640+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3641+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3642+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3643+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3644+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3645+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3646+ val nestedCtx = ctx.fresh.setNewTyperState()
3647+ typed(newDefDef)(using nestedCtx)
3648+ case _ => tree
3649+ }
3650+
3651+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
36183652 tree.getAttachment(desugar.PolyFunctionApply ) match
36193653 case Some (params) if params.nonEmpty =>
36203654 tree.removeAttachment(desugar.PolyFunctionApply )
36213655 val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36223656 TypeTree (tpe).withSpan(tree.span) -> tpe
3657+ // case Some(params) if params.isEmpty =>
3658+ // println(s"tree: $tree")
3659+ // healToPolyFunctionType(tree) -> pt
36233660 case _ => tree -> pt
36243661 }
36253662
36263663 /** Interpolate and simplify the type of the given tree. */
36273664 protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3628- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3665+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
36293666 if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36303667 if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
36313668 || tree1.isDef // ... unless tree is a definition
0 commit comments