Skip to content

Commit fdf7d36

Browse files
authored
KTOR-4499 Fix 405 by adding tailcard OPTIONS route where the plugin is installed (#4943)
1 parent 93d1fde commit fdf7d36

File tree

3 files changed

+339
-1
lines changed
  • ktor-server
    • ktor-server-core/common/src/io/ktor/server/application
    • ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/routing
    • ktor-server-tests/common/test/io/ktor/tests/server/plugins

3 files changed

+339
-1
lines changed

ktor-server/ktor-server-core/common/src/io/ktor/server/application/ApplicationPlugin.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ private fun <B : Any, F : Any> RoutingNode.installIntoRoute(
168168
val installed = plugin.install(fakePipeline, configure)
169169
pluginRegistry.put(plugin.key, installed)
170170

171+
for (child in fakePipeline.children) { // We need to transfer nodes added into the fake pipeline
172+
copyChildrenRecursively(child)
173+
}
174+
171175
mergePhases(fakePipeline)
172176
receivePipeline.mergePhases(fakePipeline.receivePipeline)
173177
sendPipeline.mergePhases(fakePipeline.sendPipeline)
@@ -179,6 +183,18 @@ private fun <B : Any, F : Any> RoutingNode.installIntoRoute(
179183
return installed
180184
}
181185

186+
private fun Route.copyChildrenRecursively(child: RoutingNode) {
187+
val copiedChild = createChild(child.selector)
188+
189+
for (handler in child.handlers) {
190+
copiedChild.handle(handler)
191+
}
192+
193+
for (ch in child.children) {
194+
copiedChild.copyChildrenRecursively(ch)
195+
}
196+
}
197+
182198
private fun <B : Any, F : Any, TSubject, TContext, P : Pipeline<TSubject, TContext>> P.addAllInterceptors(
183199
fakePipeline: P,
184200
plugin: BaseRouteScopedPlugin<B, F>,

ktor-server/ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/routing/CORS.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package io.ktor.server.plugins.cors.routing
66

77
import io.ktor.server.application.*
88
import io.ktor.server.plugins.cors.*
9+
import io.ktor.server.routing.options
910

1011
/**
1112
* A plugin that allows you to configure handling cross-origin requests.
@@ -24,5 +25,9 @@ import io.ktor.server.plugins.cors.*
2425
* [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.server.plugins.cors.routing.CORS)
2526
*/
2627
public val CORS: RouteScopedPlugin<CORSConfig> = createRouteScopedPlugin("CORS", ::CORSConfig) {
28+
route?.options("{cors-options-wildcard...}") {
29+
// Handled by an interceptor of the Call phase added in the plugin
30+
}
31+
2732
buildPlugin()
2833
}

ktor-server/ktor-server-tests/common/test/io/ktor/tests/server/plugins/CORSTest.kt

Lines changed: 318 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import io.ktor.client.statement.*
99
import io.ktor.http.*
1010
import io.ktor.server.application.*
1111
import io.ktor.server.application.hooks.*
12-
import io.ktor.server.plugins.cors.routing.*
12+
import io.ktor.server.plugins.cors.routing.CORS
13+
import io.ktor.server.request.httpMethod
1314
import io.ktor.server.response.*
1415
import io.ktor.server.routing.*
1516
import io.ktor.server.testing.*
@@ -125,6 +126,14 @@ class CORSTest {
125126
}
126127
}
127128

129+
client.options("/1") {
130+
header(HttpHeaders.Origin, "http://my-host")
131+
header(HttpHeaders.AccessControlRequestMethod, "GET")
132+
}.let { response ->
133+
assertEquals(HttpStatusCode.OK, response.status)
134+
assertEquals("http://my-host", response.headers[HttpHeaders.AccessControlAllowOrigin])
135+
}
136+
128137
client.get("/1") {
129138
header(HttpHeaders.Origin, "http://my-host")
130139
}.let { response ->
@@ -1275,4 +1284,312 @@ class CORSTest {
12751284
}.status
12761285
)
12771286
}
1287+
1288+
@Test
1289+
fun preflightRoutingCorsAfter() = testApplication {
1290+
routing {
1291+
get("/test") {
1292+
call.respond("OK")
1293+
}
1294+
1295+
install(CORS) {
1296+
allowHost("example.com")
1297+
}
1298+
}
1299+
1300+
val response = client.options("/test") {
1301+
header(HttpHeaders.Origin, "https://example.com")
1302+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1303+
}
1304+
1305+
assertEquals(response.status, HttpStatusCode.OK)
1306+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1307+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1308+
}
1309+
1310+
@Test
1311+
fun preflightRoutingNested() = testApplication {
1312+
routing {
1313+
install(CORS) {
1314+
allowHost("example.com")
1315+
}
1316+
1317+
route("/sub") {
1318+
install(CORS) {
1319+
allowHost("another.com")
1320+
}
1321+
1322+
get {
1323+
call.respond("OK")
1324+
}
1325+
}
1326+
1327+
get {
1328+
call.respond("OK")
1329+
}
1330+
}
1331+
1332+
client.options("/sub") {
1333+
header(HttpHeaders.Origin, "https://another.com")
1334+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1335+
}.let { response ->
1336+
assertEquals(response.status, HttpStatusCode.OK)
1337+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1338+
assertEquals("https://another.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1339+
}
1340+
1341+
client.options("/sub") {
1342+
header(HttpHeaders.Origin, "https://example.com")
1343+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1344+
}.let { response ->
1345+
assertEquals(response.status, HttpStatusCode.Forbidden)
1346+
}
1347+
1348+
client.options("/") {
1349+
header(HttpHeaders.Origin, "https://example.com")
1350+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1351+
}.let { response ->
1352+
assertEquals(response.status, HttpStatusCode.OK)
1353+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1354+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1355+
}
1356+
1357+
client.options("/") {
1358+
header(HttpHeaders.Origin, "https://another.com")
1359+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1360+
}.let { response ->
1361+
assertEquals(response.status, HttpStatusCode.Forbidden)
1362+
}
1363+
}
1364+
1365+
@Test
1366+
fun preflightNoRouting() = testApplication {
1367+
application {
1368+
intercept(ApplicationCallPipeline.Call) {
1369+
if (context.request.httpMethod == HttpMethod.Get) {
1370+
call.respondText("OK")
1371+
}
1372+
}
1373+
}
1374+
1375+
install(CORS) {
1376+
allowHost("example.com")
1377+
}
1378+
1379+
client.options("/") {
1380+
header(HttpHeaders.Origin, "https://example.com")
1381+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1382+
}.let { response ->
1383+
assertEquals(response.status, HttpStatusCode.OK)
1384+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1385+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1386+
}
1387+
1388+
client.get("/") {
1389+
header(HttpHeaders.Origin, "https://example.com")
1390+
}.let { response ->
1391+
assertEquals(response.status, HttpStatusCode.OK)
1392+
assertEquals(response.bodyAsText(), "OK")
1393+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1394+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1395+
}
1396+
}
1397+
1398+
@Test
1399+
fun preflightHasOptionsRoute() = testApplication {
1400+
routing {
1401+
route("/test") {
1402+
get {
1403+
call.respond("OK")
1404+
}
1405+
1406+
options {
1407+
call.respond("OPTIONS")
1408+
}
1409+
}
1410+
}
1411+
1412+
install(CORS) {
1413+
allowHost("example.com")
1414+
}
1415+
1416+
client.options("/test") {
1417+
header(HttpHeaders.Origin, "https://example.com")
1418+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1419+
}.let { response ->
1420+
assertEquals(response.status, HttpStatusCode.OK)
1421+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1422+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1423+
}
1424+
}
1425+
1426+
@Test
1427+
fun preflightRoutingHasOptionsRouteAfter() = testApplication {
1428+
routing {
1429+
route("/test") {
1430+
install(CORS) {
1431+
allowHost("example.com")
1432+
}
1433+
1434+
get {
1435+
call.respond("OK")
1436+
}
1437+
1438+
options {
1439+
call.respond("OPTIONS")
1440+
}
1441+
}
1442+
}
1443+
1444+
client.options("/test") {
1445+
header(HttpHeaders.Origin, "https://example.com")
1446+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1447+
}.let { response ->
1448+
assertEquals(response.status, HttpStatusCode.OK)
1449+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1450+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1451+
}
1452+
}
1453+
1454+
@Test
1455+
fun preflightRoutingHasOptionsRouteBefore() = testApplication {
1456+
routing {
1457+
route("/test") {
1458+
options {
1459+
call.respond("OPTIONS")
1460+
}
1461+
1462+
install(CORS) {
1463+
allowHost("example.com")
1464+
}
1465+
1466+
get {
1467+
call.respond("OK")
1468+
}
1469+
}
1470+
}
1471+
1472+
client.options("/test") {
1473+
header(HttpHeaders.Origin, "https://example.com")
1474+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1475+
}.let { response ->
1476+
assertEquals(response.status, HttpStatusCode.OK)
1477+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1478+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1479+
}
1480+
}
1481+
1482+
@Test
1483+
fun preflightAllowedMethodsConfinedToRoutes() = testApplication {
1484+
routing {
1485+
route("/test") {
1486+
install(CORS) {
1487+
allowHost("example.com")
1488+
allowMethod(HttpMethod.Put)
1489+
}
1490+
}
1491+
1492+
route("/other") {
1493+
install(CORS) {
1494+
allowHost("example.com")
1495+
allowMethod(HttpMethod.Patch)
1496+
}
1497+
}
1498+
}
1499+
1500+
client.options("/test") {
1501+
header(HttpHeaders.Origin, "https://example.com")
1502+
header(HttpHeaders.AccessControlRequestMethod, "PUT")
1503+
}.let { response ->
1504+
assertEquals(response.status, HttpStatusCode.OK)
1505+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1506+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1507+
}
1508+
1509+
client.options("/test") {
1510+
header(HttpHeaders.Origin, "https://example.com")
1511+
header(HttpHeaders.AccessControlRequestMethod, "PATCH")
1512+
}.let { response ->
1513+
assertEquals(response.status, HttpStatusCode.Forbidden)
1514+
}
1515+
1516+
client.options("/other") {
1517+
header(HttpHeaders.Origin, "https://example.com")
1518+
header(HttpHeaders.AccessControlRequestMethod, "PATCH")
1519+
}.let { response ->
1520+
assertEquals(response.status, HttpStatusCode.OK)
1521+
assertEquals(HttpHeaders.Origin, response.headers[HttpHeaders.Vary])
1522+
assertEquals("https://example.com", response.headers[HttpHeaders.AccessControlAllowOrigin])
1523+
}
1524+
1525+
client.options("/other") {
1526+
header(HttpHeaders.Origin, "https://example.com")
1527+
header(HttpHeaders.AccessControlRequestMethod, "PUT")
1528+
}.let { response ->
1529+
assertEquals(response.status, HttpStatusCode.Forbidden)
1530+
}
1531+
}
1532+
1533+
@Test
1534+
fun preflightPluginInParentHandlesChildRoutes() = testApplication {
1535+
routing {
1536+
route("/outer") {
1537+
install(CORS) {
1538+
anyHost()
1539+
}
1540+
1541+
route("/inner1") {
1542+
route("/inner2") {
1543+
get {
1544+
call.respond("OK")
1545+
}
1546+
}
1547+
}
1548+
}
1549+
}
1550+
1551+
client.options("/outer/inner1/inner2") {
1552+
header(HttpHeaders.Origin, "https://example.com")
1553+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1554+
}.let { response ->
1555+
assertEquals(response.status, HttpStatusCode.OK)
1556+
assertEquals("*", response.headers[HttpHeaders.AccessControlAllowOrigin])
1557+
}
1558+
}
1559+
1560+
@Test
1561+
fun preflightNoEffectWithoutOriginHeader() = testApplication {
1562+
routing {
1563+
install(CORS) {
1564+
anyHost()
1565+
}
1566+
}
1567+
1568+
client.options("/") {
1569+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1570+
}.let { response ->
1571+
assertEquals(response.status, HttpStatusCode.NotFound)
1572+
assertNull(response.headers[HttpHeaders.AccessControlAllowOrigin])
1573+
}
1574+
}
1575+
1576+
@Test
1577+
fun routeInsideCorsPluginConfig() = testApplication {
1578+
routing {
1579+
install(CORS) {
1580+
route("/abc") {
1581+
get {
1582+
call.respond("OK")
1583+
}
1584+
}
1585+
}
1586+
}
1587+
1588+
client.options("/abc") {
1589+
header(HttpHeaders.Origin, "https://example.com")
1590+
header(HttpHeaders.AccessControlRequestMethod, "GET")
1591+
}.let { response ->
1592+
assertEquals(response.status, HttpStatusCode.Forbidden)
1593+
}
1594+
}
12781595
}

0 commit comments

Comments
 (0)