@@ -9,7 +9,8 @@ import io.ktor.client.statement.*
99import  io.ktor.http.* 
1010import  io.ktor.server.application.* 
1111import  io.ktor.server.application.hooks.* 
12- import  io.ktor.server.plugins.cors.routing.* 
12+ import  io.ktor.server.plugins.cors.routing.CORS 
13+ import  io.ktor.server.request.httpMethod 
1314import  io.ktor.server.response.* 
1415import  io.ktor.server.routing.* 
1516import  io.ktor.server.testing.* 
@@ -125,6 +126,14 @@ class CORSTest {
125126            }
126127        }
127128
129+         client.options(" /1" 
130+             header(HttpHeaders .Origin , " http://my-host" 
131+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
132+         }.let  { response -> 
133+             assertEquals(HttpStatusCode .OK , response.status)
134+             assertEquals(" http://my-host" HttpHeaders .AccessControlAllowOrigin ])
135+         }
136+ 
128137        client.get(" /1" 
129138            header(HttpHeaders .Origin , " http://my-host" 
130139        }.let  { response -> 
@@ -1275,4 +1284,312 @@ class CORSTest {
12751284            }.status
12761285        )
12771286    }
1287+ 
1288+     @Test
1289+     fun  preflightRoutingCorsAfter () =  testApplication {
1290+         routing {
1291+             get(" /test" 
1292+                 call.respond(" OK" 
1293+             }
1294+ 
1295+             install(CORS ) {
1296+                 allowHost(" example.com" 
1297+             }
1298+         }
1299+ 
1300+         val  response =  client.options(" /test" 
1301+             header(HttpHeaders .Origin , " https://example.com" 
1302+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1303+         }
1304+ 
1305+         assertEquals(response.status, HttpStatusCode .OK )
1306+         assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1307+         assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1308+     }
1309+ 
1310+     @Test
1311+     fun  preflightRoutingNested () =  testApplication {
1312+         routing {
1313+             install(CORS ) {
1314+                 allowHost(" example.com" 
1315+             }
1316+ 
1317+             route(" /sub" 
1318+                 install(CORS ) {
1319+                     allowHost(" another.com" 
1320+                 }
1321+ 
1322+                 get {
1323+                     call.respond(" OK" 
1324+                 }
1325+             }
1326+ 
1327+             get {
1328+                 call.respond(" OK" 
1329+             }
1330+         }
1331+ 
1332+         client.options(" /sub" 
1333+             header(HttpHeaders .Origin , " https://another.com" 
1334+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1335+         }.let  { response -> 
1336+             assertEquals(response.status, HttpStatusCode .OK )
1337+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1338+             assertEquals(" https://another.com" HttpHeaders .AccessControlAllowOrigin ])
1339+         }
1340+ 
1341+         client.options(" /sub" 
1342+             header(HttpHeaders .Origin , " https://example.com" 
1343+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1344+         }.let  { response -> 
1345+             assertEquals(response.status, HttpStatusCode .Forbidden )
1346+         }
1347+ 
1348+         client.options(" /" 
1349+             header(HttpHeaders .Origin , " https://example.com" 
1350+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1351+         }.let  { response -> 
1352+             assertEquals(response.status, HttpStatusCode .OK )
1353+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1354+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1355+         }
1356+ 
1357+         client.options(" /" 
1358+             header(HttpHeaders .Origin , " https://another.com" 
1359+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1360+         }.let  { response -> 
1361+             assertEquals(response.status, HttpStatusCode .Forbidden )
1362+         }
1363+     }
1364+ 
1365+     @Test
1366+     fun  preflightNoRouting () =  testApplication {
1367+         application {
1368+             intercept(ApplicationCallPipeline .Call ) {
1369+                 if  (context.request.httpMethod ==  HttpMethod .Get ) {
1370+                     call.respondText(" OK" 
1371+                 }
1372+             }
1373+         }
1374+ 
1375+         install(CORS ) {
1376+             allowHost(" example.com" 
1377+         }
1378+ 
1379+         client.options(" /" 
1380+             header(HttpHeaders .Origin , " https://example.com" 
1381+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1382+         }.let  { response -> 
1383+             assertEquals(response.status, HttpStatusCode .OK )
1384+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1385+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1386+         }
1387+ 
1388+         client.get(" /" 
1389+             header(HttpHeaders .Origin , " https://example.com" 
1390+         }.let  { response -> 
1391+             assertEquals(response.status, HttpStatusCode .OK )
1392+             assertEquals(response.bodyAsText(), " OK" 
1393+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1394+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1395+         }
1396+     }
1397+ 
1398+     @Test
1399+     fun  preflightHasOptionsRoute () =  testApplication {
1400+         routing {
1401+             route(" /test" 
1402+                 get {
1403+                     call.respond(" OK" 
1404+                 }
1405+ 
1406+                 options {
1407+                     call.respond(" OPTIONS" 
1408+                 }
1409+             }
1410+         }
1411+ 
1412+         install(CORS ) {
1413+             allowHost(" example.com" 
1414+         }
1415+ 
1416+         client.options(" /test" 
1417+             header(HttpHeaders .Origin , " https://example.com" 
1418+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1419+         }.let  { response -> 
1420+             assertEquals(response.status, HttpStatusCode .OK )
1421+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1422+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1423+         }
1424+     }
1425+ 
1426+     @Test
1427+     fun  preflightRoutingHasOptionsRouteAfter () =  testApplication {
1428+         routing {
1429+             route(" /test" 
1430+                 install(CORS ) {
1431+                     allowHost(" example.com" 
1432+                 }
1433+ 
1434+                 get {
1435+                     call.respond(" OK" 
1436+                 }
1437+ 
1438+                 options {
1439+                     call.respond(" OPTIONS" 
1440+                 }
1441+             }
1442+         }
1443+ 
1444+         client.options(" /test" 
1445+             header(HttpHeaders .Origin , " https://example.com" 
1446+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1447+         }.let  { response -> 
1448+             assertEquals(response.status, HttpStatusCode .OK )
1449+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1450+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1451+         }
1452+     }
1453+ 
1454+     @Test
1455+     fun  preflightRoutingHasOptionsRouteBefore () =  testApplication {
1456+         routing {
1457+             route(" /test" 
1458+                 options {
1459+                     call.respond(" OPTIONS" 
1460+                 }
1461+ 
1462+                 install(CORS ) {
1463+                     allowHost(" example.com" 
1464+                 }
1465+ 
1466+                 get {
1467+                     call.respond(" OK" 
1468+                 }
1469+             }
1470+         }
1471+ 
1472+         client.options(" /test" 
1473+             header(HttpHeaders .Origin , " https://example.com" 
1474+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1475+         }.let  { response -> 
1476+             assertEquals(response.status, HttpStatusCode .OK )
1477+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1478+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1479+         }
1480+     }
1481+ 
1482+     @Test
1483+     fun  preflightAllowedMethodsConfinedToRoutes () =  testApplication {
1484+         routing {
1485+             route(" /test" 
1486+                 install(CORS ) {
1487+                     allowHost(" example.com" 
1488+                     allowMethod(HttpMethod .Put )
1489+                 }
1490+             }
1491+ 
1492+             route(" /other" 
1493+                 install(CORS ) {
1494+                     allowHost(" example.com" 
1495+                     allowMethod(HttpMethod .Patch )
1496+                 }
1497+             }
1498+         }
1499+ 
1500+         client.options(" /test" 
1501+             header(HttpHeaders .Origin , " https://example.com" 
1502+             header(HttpHeaders .AccessControlRequestMethod , " PUT" 
1503+         }.let  { response -> 
1504+             assertEquals(response.status, HttpStatusCode .OK )
1505+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1506+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1507+         }
1508+ 
1509+         client.options(" /test" 
1510+             header(HttpHeaders .Origin , " https://example.com" 
1511+             header(HttpHeaders .AccessControlRequestMethod , " PATCH" 
1512+         }.let  { response -> 
1513+             assertEquals(response.status, HttpStatusCode .Forbidden )
1514+         }
1515+ 
1516+         client.options(" /other" 
1517+             header(HttpHeaders .Origin , " https://example.com" 
1518+             header(HttpHeaders .AccessControlRequestMethod , " PATCH" 
1519+         }.let  { response -> 
1520+             assertEquals(response.status, HttpStatusCode .OK )
1521+             assertEquals(HttpHeaders .Origin , response.headers[HttpHeaders .Vary ])
1522+             assertEquals(" https://example.com" HttpHeaders .AccessControlAllowOrigin ])
1523+         }
1524+ 
1525+         client.options(" /other" 
1526+             header(HttpHeaders .Origin , " https://example.com" 
1527+             header(HttpHeaders .AccessControlRequestMethod , " PUT" 
1528+         }.let  { response -> 
1529+             assertEquals(response.status, HttpStatusCode .Forbidden )
1530+         }
1531+     }
1532+ 
1533+     @Test
1534+     fun  preflightPluginInParentHandlesChildRoutes () =  testApplication {
1535+         routing {
1536+             route(" /outer" 
1537+                 install(CORS ) {
1538+                     anyHost()
1539+                 }
1540+ 
1541+                 route(" /inner1" 
1542+                     route(" /inner2" 
1543+                         get {
1544+                             call.respond(" OK" 
1545+                         }
1546+                     }
1547+                 }
1548+             }
1549+         }
1550+ 
1551+         client.options(" /outer/inner1/inner2" 
1552+             header(HttpHeaders .Origin , " https://example.com" 
1553+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1554+         }.let  { response -> 
1555+             assertEquals(response.status, HttpStatusCode .OK )
1556+             assertEquals(" *" HttpHeaders .AccessControlAllowOrigin ])
1557+         }
1558+     }
1559+ 
1560+     @Test
1561+     fun  preflightNoEffectWithoutOriginHeader () =  testApplication {
1562+         routing {
1563+             install(CORS ) {
1564+                 anyHost()
1565+             }
1566+         }
1567+ 
1568+         client.options(" /" 
1569+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1570+         }.let  { response -> 
1571+             assertEquals(response.status, HttpStatusCode .NotFound )
1572+             assertNull(response.headers[HttpHeaders .AccessControlAllowOrigin ])
1573+         }
1574+     }
1575+ 
1576+     @Test
1577+     fun  routeInsideCorsPluginConfig () =  testApplication {
1578+         routing {
1579+             install(CORS ) {
1580+                 route(" /abc" 
1581+                     get {
1582+                         call.respond(" OK" 
1583+                     }
1584+                 }
1585+             }
1586+         }
1587+ 
1588+         client.options(" /abc" 
1589+             header(HttpHeaders .Origin , " https://example.com" 
1590+             header(HttpHeaders .AccessControlRequestMethod , " GET" 
1591+         }.let  { response -> 
1592+             assertEquals(response.status, HttpStatusCode .Forbidden )
1593+         }
1594+     }
12781595}
0 commit comments