1+ import * as ts from "typescript" ;
2+ import { NodeBuilderFlags } from "typescript" ;
3+ import { map } from "../compiler/lang-utils" ;
4+ import { SymbolTracker } from "../compiler/types" ;
5+
6+ const declarationEmitNodeBuilderFlags =
7+ NodeBuilderFlags . MultilineObjectLiterals |
8+ NodeBuilderFlags . WriteClassExpressionAsTypeLiteral |
9+ NodeBuilderFlags . UseTypeOfFunction |
10+ NodeBuilderFlags . UseStructuralFallback |
11+ NodeBuilderFlags . AllowEmptyTuple |
12+ NodeBuilderFlags . GenerateNamesForShadowedTypeParams |
13+ NodeBuilderFlags . NoTruncation ;
14+
15+
16+ // Define a transformer function
17+ export function addTypeAnnotationTransformer ( program : ts . Program , moduleResolutionHost ?: ts . ModuleResolutionHost ) {
18+ function tryGetReturnType (
19+ typeChecker : ts . TypeChecker ,
20+ node : ts . SignatureDeclaration
21+ ) : ts . Type | undefined {
22+ const signature = typeChecker . getSignatureFromDeclaration ( node ) ;
23+ if ( signature ) {
24+ return typeChecker . getReturnTypeOfSignature ( signature ) ;
25+ }
26+ }
27+
28+ function isVarConst ( node : ts . VariableDeclaration | ts . VariableDeclarationList ) : boolean {
29+ return ! ! ( ts . getCombinedNodeFlags ( node ) & ts . NodeFlags . Const ) ;
30+ }
31+
32+ function isDeclarationReadonly ( declaration : ts . Declaration ) : boolean {
33+ return ! ! ( ts . getCombinedModifierFlags ( declaration ) & ts . ModifierFlags . Readonly && ! ts . isParameterPropertyDeclaration ( declaration , declaration . parent ) ) ;
34+ }
35+
36+ function isLiteralConstDeclaration ( node : ts . VariableDeclaration | ts . PropertyDeclaration | ts . PropertySignature | ts . ParameterDeclaration ) : boolean {
37+ if ( isDeclarationReadonly ( node ) || ts . isVariableDeclaration ( node ) && isVarConst ( node ) ) {
38+ // TODO: Make sure this is a valid approximation for literal types
39+ return ! node . type && 'initializer' in node && ! ! node . initializer && ts . isLiteralExpression ( node . initializer ) ;
40+ // Original TS version
41+ // return isFreshLiteralType(getTypeOfSymbol(getSymbolOfNode(node)));
42+ }
43+ return false ;
44+ }
45+
46+ const typeChecker = program . getTypeChecker ( ) ;
47+
48+ return ( context : ts . TransformationContext ) => {
49+ let hasError = false ;
50+ let reportError = ( ) => {
51+ hasError = true ;
52+ }
53+ const symbolTracker : SymbolTracker | undefined = ! moduleResolutionHost ? undefined : {
54+ trackSymbol ( ) { return false ; } ,
55+ reportInaccessibleThisError : reportError ,
56+ reportInaccessibleUniqueSymbolError : reportError ,
57+ reportCyclicStructureError : reportError ,
58+ reportPrivateInBaseOfClassExpression : reportError ,
59+ reportLikelyUnsafeImportRequiredError : reportError ,
60+ reportTruncationError : reportError ,
61+ moduleResolverHost : moduleResolutionHost as any ,
62+ trackReferencedAmbientModule ( ) { } ,
63+ trackExternalModuleSymbolOfImportTypeNode ( ) { } ,
64+ reportNonlocalAugmentation ( ) { } ,
65+ reportNonSerializableProperty ( ) { } ,
66+ reportImportTypeNodeResolutionModeOverride ( ) { } ,
67+ } ;
68+
69+ function typeToTypeNode ( type : ts . Type , enclosingDeclaration : ts . Node ) {
70+ const typeNode = typeChecker . typeToTypeNode (
71+ type ,
72+ enclosingDeclaration ,
73+ declarationEmitNodeBuilderFlags ,
74+ // @ts -expect-error Use undocumented parameters
75+ symbolTracker ,
76+ )
77+ if ( hasError ) {
78+ hasError = false ;
79+ return undefined ;
80+ }
81+
82+ return typeNode ;
83+ }
84+ // Return a visitor function
85+ return ( rootNode : ts . Node ) => {
86+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > ) : ts . NodeArray < T >
87+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > | undefined ) : ts . NodeArray < T > | undefined
88+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > | undefined ) {
89+ if ( nodeArray === undefined ) return undefined ;
90+ return ts . factory . createNodeArray (
91+ nodeArray . map ( param => {
92+ return visit ( param ) as ts . ParameterDeclaration ;
93+ } )
94+ )
95+ }
96+
97+ // Define a visitor function
98+ function visit ( node : ts . Node ) : ts . Node | ts . Node [ ] {
99+ if ( ts . isParameter ( node ) && ! node . type ) {
100+ const type = typeChecker . getTypeAtLocation ( node ) ;
101+ if ( type ) {
102+ const typeNode = typeToTypeNode ( type , node ) ;
103+ return ts . factory . updateParameterDeclaration (
104+ node ,
105+ node . modifiers ,
106+ node . dotDotDotToken ,
107+ node . name ,
108+ node . questionToken ,
109+ typeNode ,
110+ node . initializer
111+ )
112+ }
113+ }
114+ // Check if node is a variable declaration
115+ if ( ts . isVariableDeclaration ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
116+ const type = typeChecker . getTypeAtLocation ( node ) ;
117+ const typeNode = typeToTypeNode ( type , node )
118+ return ts . factory . updateVariableDeclaration (
119+ node ,
120+ node . name ,
121+ undefined ,
122+ typeNode ,
123+ node . initializer
124+ ) ;
125+ }
126+
127+ if ( ts . isFunctionDeclaration ( node ) && ! node . type ) {
128+ const type = tryGetReturnType ( typeChecker , node ) ;
129+ if ( type ) {
130+
131+ const typeNode = typeToTypeNode ( type , node ) ;
132+ return ts . factory . updateFunctionDeclaration (
133+ node ,
134+ node . modifiers ,
135+ node . asteriskToken ,
136+ node . name ,
137+ updateTypesInNodeArray ( node . typeParameters ) ,
138+ updateTypesInNodeArray ( node . parameters ) ,
139+ typeNode ,
140+ node . body
141+ )
142+ }
143+ }
144+ if ( ts . isPropertySignature ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
145+ const type = typeChecker . getTypeAtLocation ( node ) ;
146+ const typeNode = typeToTypeNode ( type , node ) ;
147+ return ts . factory . updatePropertySignature (
148+ node ,
149+ node . modifiers ,
150+ node . name ,
151+ node . questionToken ,
152+ typeNode ,
153+ ) ;
154+ }
155+ if ( ts . isPropertyDeclaration ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
156+ const type = typeChecker . getTypeAtLocation ( node ) ;
157+ const typeNode = typeToTypeNode ( type , node ) ;
158+ return ts . factory . updatePropertyDeclaration (
159+ node ,
160+ node . modifiers ,
161+ node . name ,
162+ node . questionToken ?? node . exclamationToken ,
163+ typeNode ,
164+ node . initializer
165+ ) ;
166+ }
167+ if ( ts . isMethodSignature ( node ) && ! node . type ) {
168+ const type = tryGetReturnType ( typeChecker , node ) ;
169+ if ( type ) {
170+
171+ const typeNode = typeToTypeNode ( type , node ) ;
172+ return ts . factory . updateMethodSignature (
173+ node ,
174+ node . modifiers ,
175+ node . name ,
176+ node . questionToken ,
177+ updateTypesInNodeArray ( node . typeParameters ) ,
178+ updateTypesInNodeArray ( node . parameters ) ,
179+ typeNode ,
180+ ) ;
181+ }
182+ }
183+ if ( ts . isCallSignatureDeclaration ( node ) ) {
184+ const type = tryGetReturnType ( typeChecker , node ) ;
185+ if ( type ) {
186+ const typeNode = typeToTypeNode ( type , node ) ;
187+ return ts . factory . updateCallSignature (
188+ node ,
189+ updateTypesInNodeArray ( node . typeParameters ) ,
190+ updateTypesInNodeArray ( node . parameters ) ,
191+ typeNode ,
192+ )
193+ }
194+ }
195+ if ( ts . isMethodDeclaration ( node ) && ! node . type ) {
196+ const type = tryGetReturnType ( typeChecker , node ) ;
197+ if ( type ) {
198+
199+ const typeNode = typeToTypeNode ( type , node ) ;
200+ return ts . factory . updateMethodDeclaration (
201+ node ,
202+ node . modifiers ,
203+ node . asteriskToken ,
204+ node . name ,
205+ node . questionToken ,
206+ updateTypesInNodeArray ( node . typeParameters ) ,
207+ updateTypesInNodeArray ( node . parameters ) ,
208+ typeNode ,
209+ node . body ,
210+ ) ;
211+ }
212+ }
213+ if ( ts . isGetAccessorDeclaration ( node ) && ! node . type ) {
214+ const type = tryGetReturnType ( typeChecker , node ) ;
215+ if ( type ) {
216+ const typeNode = typeToTypeNode ( type , node ) ;
217+ return ts . factory . updateGetAccessorDeclaration (
218+ node ,
219+ node . modifiers ,
220+ node . name ,
221+ updateTypesInNodeArray ( node . parameters ) ,
222+ typeNode ,
223+ node . body ,
224+ ) ;
225+ }
226+ }
227+ if ( ts . isSetAccessorDeclaration ( node ) && ! node . parameters [ 0 ] ?. type ) {
228+ return ts . factory . updateSetAccessorDeclaration (
229+ node ,
230+ node . modifiers ,
231+ node . name ,
232+ updateTypesInNodeArray ( node . parameters ) ,
233+ node . body ,
234+ ) ;
235+ }
236+ if ( ts . isConstructorDeclaration ( node ) ) {
237+ return ts . factory . updateConstructorDeclaration (
238+ node ,
239+ node . modifiers ,
240+ updateTypesInNodeArray ( node . parameters ) ,
241+ node . body ,
242+ )
243+ }
244+ if ( ts . isConstructSignatureDeclaration ( node ) ) {
245+ const type = tryGetReturnType ( typeChecker , node ) ;
246+ if ( type ) {
247+ const typeNode = typeToTypeNode ( type , node ) ;
248+ return ts . factory . updateConstructSignature (
249+ node ,
250+ updateTypesInNodeArray ( node . typeParameters ) ,
251+ updateTypesInNodeArray ( node . parameters ) ,
252+ typeNode ,
253+ )
254+ }
255+ }
256+ if ( ts . isExportAssignment ( node ) && node . expression . kind !== ts . SyntaxKind . Identifier ) {
257+ const type = typeChecker . getTypeAtLocation ( node . expression ) ;
258+ if ( type ) {
259+ const typeNode = typeToTypeNode ( type , node ) ;
260+ const newId = ts . factory . createIdentifier ( "_default" ) ;
261+ const varDecl = ts . factory . createVariableDeclaration ( newId , /*exclamationToken*/ undefined , typeNode , /*initializer*/ undefined ) ;
262+ const statement = ts . factory . createVariableStatement (
263+ [ ] ,
264+ ts . factory . createVariableDeclarationList ( [ varDecl ] , ts . NodeFlags . Const )
265+ ) ;
266+ return [ statement , ts . factory . updateExportAssignment ( node , node . modifiers , newId ) ] ;
267+ }
268+ }
269+ // Otherwise, visit each child node recursively
270+ return ts . visitEachChild ( node , visit , context ) ;
271+ }
272+ // Start visiting from root node
273+ return ts . visitNode ( rootNode , visit ) ! ;
274+ } ;
275+ } ;
276+ }
0 commit comments