diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 870af91e083c..8522fcd781b9 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1632,12 +1632,12 @@ object desugar { } } - def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match { + def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match case Parens(body1) => makePolyFunction(targs, body1) case Block(Nil, body1) => makePolyFunction(targs, body1) - case Function(vargs, res) => + case _ => assert(targs.nonEmpty) // TODO: Figure out if we need a `PolyFunctionWithMods` instead. val mods = body match { @@ -1646,33 +1646,37 @@ object desugar { } val polyFunctionTpt = ref(defn.PolyFunctionType) val applyTParams = targs.asInstanceOf[List[TypeDef]] - if (ctx.mode.is(Mode.Type)) { + if ctx.mode.is(Mode.Type) then // Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R // Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } - - val applyVParams = vargs.zipWithIndex.map { - case (p: ValDef, _) => p.withAddedFlags(mods.flags) - case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags) - } + val (res, applyVParamss) = body match + case Function(vargs, res) => + ( res, + vargs.zipWithIndex.map { + case (p: ValDef, _) => p.withAddedFlags(mods.flags) + case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags) + } :: Nil + ) + case _ => + (body, Nil) RefinedTypeTree(polyFunctionTpt, List( - DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree) + DefDef(nme.apply, applyTParams :: applyVParamss, res, EmptyTree) )) - } - else { + else // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body } - - val applyVParams = vargs.asInstanceOf[List[ValDef]] - .map(varg => varg.withAddedFlags(mods.flags | Param)) - New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, - List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res)) - )) - } - case _ => - // may happen for erroneous input. An error will already have been reported. - assert(ctx.reporter.errorsReported) - EmptyTree - } + val (res, applyVParamss) = body match + case Function(vargs, res) => + ( res, + vargs.asInstanceOf[List[ValDef]] + .map(varg => varg.withAddedFlags(mods.flags | Param)) + :: Nil + ) + case _ => + (body, Nil) + New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, + List(DefDef(nme.apply, applyTParams :: applyVParamss, TypeTree(), res)) + )) // begin desugar diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 3a077407d0b5..291419d5c400 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -549,6 +549,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst * - otherwise, if T is a type parameter coming from Java, []Object * - otherwise, Object * - For a term ref p.x, the type # x. + * - For a refined type scala.PolyFunction { def apply[...]: R }, scala.Function0 * - For a refined type scala.PolyFunction { def apply[...](x_1, ..., x_N): R }, scala.FunctionN * - For a typeref scala.Any, scala.AnyVal, scala.Singleton, scala.Tuple, or scala.*: : |java.lang.Object| * - For a typeref scala.Unit, |scala.runtime.BoxedUnit|. @@ -600,8 +601,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst assert(refinedInfo.isInstanceOf[PolyType]) val res = refinedInfo.resultType val paramss = res.paramNamess - assert(paramss.length == 1) - this(defn.FunctionType(paramss.head.length, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod)) + assert(paramss.length <= 1) + val arity = if paramss.isEmpty then 0 else paramss.head.length + this(defn.FunctionType(arity, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod)) case tp: TypeProxy => this(tp.underlying) case tp @ AndType(tp1, tp2) => diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 8f0a57dc2fc8..cf49704619de 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1434,14 +1434,7 @@ object Parsers { else if (in.token == ARROW) { val arrowOffset = in.skipToken() val body = toplevelTyp() - atSpan(start, arrowOffset) { - if (isFunction(body)) - PolyFunction(tparams, body) - else { - syntaxError("Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) - Ident(nme.ERROR.toTypeName) - } - } + atSpan(start, arrowOffset) { PolyFunction(tparams, body) } } else { accept(TLARROW); typ() } } @@ -1917,14 +1910,7 @@ object Parsers { val tparams = typeParamClause(ParamOwner.TypeParam) val arrowOffset = accept(ARROW) val body = expr(location) - atSpan(start, arrowOffset) { - if (isFunction(body)) - PolyFunction(tparams, body) - else { - syntaxError("Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset) - errorTermTree - } - } + atSpan(start, arrowOffset) { PolyFunction(tparams, body) } case _ => val saved = placeholderParams placeholderParams = Nil diff --git a/tests/neg/i2887b.scala b/tests/neg/i2887b.scala index 3984949bf580..fea973c9e7cf 100644 --- a/tests/neg/i2887b.scala +++ b/tests/neg/i2887b.scala @@ -1,5 +1,5 @@ -trait A { type S[X[_] <: [_] => Any, Y[_]] <: [_] => Any; type I[_] } // error // error -trait B { type S[X[_],Y[_]]; type I[_] <: [_] => Any } // error +trait A { type S[X[_] <: [_] => Any, Y[_]] <: [_] => Any; type I[_] } +trait B { type S[X[_],Y[_]]; type I[_] <: [_] => Any } trait C { type M <: B } trait D { type M >: A } diff --git a/tests/neg/polymorphic-functions.scala b/tests/pos/polymorphic-functions.scala similarity index 100% rename from tests/neg/polymorphic-functions.scala rename to tests/pos/polymorphic-functions.scala