@@ -453,4 +453,264 @@ describe('SSEServerTransport', () => {
453453 expect . stringContaining ( `data: /messages?sessionId=${ transport . sessionId } ` ) ) ;
454454 } ) ;
455455 } ) ;
456- } ) ;
456+
457+ describe ( 'DNS rebinding protection' , ( ) => {
458+ beforeEach ( ( ) => {
459+ jest . clearAllMocks ( ) ;
460+ } ) ;
461+
462+ describe ( 'Host header validation' , ( ) => {
463+ it ( 'should accept requests with allowed host headers' , async ( ) => {
464+ const mockRes = createMockResponse ( ) ;
465+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
466+ allowedHosts : [ 'localhost:3000' , 'example.com' ] ,
467+ enableDnsRebindingProtection : true ,
468+ } ) ;
469+ await transport . start ( ) ;
470+
471+ const mockReq = createMockRequest ( {
472+ headers : {
473+ host : 'localhost:3000' ,
474+ 'content-type' : 'application/json' ,
475+ }
476+ } ) ;
477+ const mockHandleRes = createMockResponse ( ) ;
478+
479+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
480+
481+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
482+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
483+ } ) ;
484+
485+ it ( 'should reject requests with disallowed host headers' , async ( ) => {
486+ const mockRes = createMockResponse ( ) ;
487+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
488+ allowedHosts : [ 'localhost:3000' ] ,
489+ enableDnsRebindingProtection : true ,
490+ } ) ;
491+ await transport . start ( ) ;
492+
493+ const mockReq = createMockRequest ( {
494+ headers : {
495+ host : 'evil.com' ,
496+ 'content-type' : 'application/json' ,
497+ }
498+ } ) ;
499+ const mockHandleRes = createMockResponse ( ) ;
500+
501+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
502+
503+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
504+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Host header: evil.com' ) ;
505+ } ) ;
506+
507+ it ( 'should reject requests without host header when allowedHosts is configured' , async ( ) => {
508+ const mockRes = createMockResponse ( ) ;
509+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
510+ allowedHosts : [ 'localhost:3000' ] ,
511+ enableDnsRebindingProtection : true ,
512+ } ) ;
513+ await transport . start ( ) ;
514+
515+ const mockReq = createMockRequest ( {
516+ headers : {
517+ 'content-type' : 'application/json' ,
518+ }
519+ } ) ;
520+ const mockHandleRes = createMockResponse ( ) ;
521+
522+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
523+
524+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
525+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Host header: undefined' ) ;
526+ } ) ;
527+ } ) ;
528+
529+ describe ( 'Origin header validation' , ( ) => {
530+ it ( 'should accept requests with allowed origin headers' , async ( ) => {
531+ const mockRes = createMockResponse ( ) ;
532+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
533+ allowedOrigins : [ 'http://localhost:3000' , 'https://example.com' ] ,
534+ enableDnsRebindingProtection : true ,
535+ } ) ;
536+ await transport . start ( ) ;
537+
538+ const mockReq = createMockRequest ( {
539+ headers : {
540+ origin : 'http://localhost:3000' ,
541+ 'content-type' : 'application/json' ,
542+ }
543+ } ) ;
544+ const mockHandleRes = createMockResponse ( ) ;
545+
546+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
547+
548+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
549+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
550+ } ) ;
551+
552+ it ( 'should reject requests with disallowed origin headers' , async ( ) => {
553+ const mockRes = createMockResponse ( ) ;
554+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
555+ allowedOrigins : [ 'http://localhost:3000' ] ,
556+ enableDnsRebindingProtection : true ,
557+ } ) ;
558+ await transport . start ( ) ;
559+
560+ const mockReq = createMockRequest ( {
561+ headers : {
562+ origin : 'http://evil.com' ,
563+ 'content-type' : 'application/json' ,
564+ }
565+ } ) ;
566+ const mockHandleRes = createMockResponse ( ) ;
567+
568+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
569+
570+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
571+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Origin header: http://evil.com' ) ;
572+ } ) ;
573+ } ) ;
574+
575+ describe ( 'Content-Type validation' , ( ) => {
576+ it ( 'should accept requests with application/json content-type' , async ( ) => {
577+ const mockRes = createMockResponse ( ) ;
578+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
579+ await transport . start ( ) ;
580+
581+ const mockReq = createMockRequest ( {
582+ headers : {
583+ 'content-type' : 'application/json' ,
584+ }
585+ } ) ;
586+ const mockHandleRes = createMockResponse ( ) ;
587+
588+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
589+
590+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
591+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
592+ } ) ;
593+
594+ it ( 'should accept requests with application/json with charset' , async ( ) => {
595+ const mockRes = createMockResponse ( ) ;
596+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
597+ await transport . start ( ) ;
598+
599+ const mockReq = createMockRequest ( {
600+ headers : {
601+ 'content-type' : 'application/json; charset=utf-8' ,
602+ }
603+ } ) ;
604+ const mockHandleRes = createMockResponse ( ) ;
605+
606+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
607+
608+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
609+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
610+ } ) ;
611+
612+ it ( 'should reject requests with non-application/json content-type when protection is enabled' , async ( ) => {
613+ const mockRes = createMockResponse ( ) ;
614+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
615+ await transport . start ( ) ;
616+
617+ const mockReq = createMockRequest ( {
618+ headers : {
619+ 'content-type' : 'text/plain' ,
620+ }
621+ } ) ;
622+ const mockHandleRes = createMockResponse ( ) ;
623+
624+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
625+
626+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
627+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Error: Unsupported content-type: text/plain' ) ;
628+ } ) ;
629+ } ) ;
630+
631+ describe ( 'enableDnsRebindingProtection option' , ( ) => {
632+ it ( 'should skip all validations when enableDnsRebindingProtection is false' , async ( ) => {
633+ const mockRes = createMockResponse ( ) ;
634+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
635+ allowedHosts : [ 'localhost:3000' ] ,
636+ allowedOrigins : [ 'http://localhost:3000' ] ,
637+ enableDnsRebindingProtection : false ,
638+ } ) ;
639+ await transport . start ( ) ;
640+
641+ const mockReq = createMockRequest ( {
642+ headers : {
643+ host : 'evil.com' ,
644+ origin : 'http://evil.com' ,
645+ 'content-type' : 'text/plain' ,
646+ }
647+ } ) ;
648+ const mockHandleRes = createMockResponse ( ) ;
649+
650+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
651+
652+ // Should pass even with invalid headers because protection is disabled
653+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
654+ // The error should be from content-type parsing, not DNS rebinding protection
655+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Error: Unsupported content-type: text/plain' ) ;
656+ } ) ;
657+ } ) ;
658+
659+ describe ( 'Combined validations' , ( ) => {
660+ it ( 'should validate both host and origin when both are configured' , async ( ) => {
661+ const mockRes = createMockResponse ( ) ;
662+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
663+ allowedHosts : [ 'localhost:3000' ] ,
664+ allowedOrigins : [ 'http://localhost:3000' ] ,
665+ enableDnsRebindingProtection : true ,
666+ } ) ;
667+ await transport . start ( ) ;
668+
669+ // Valid host, invalid origin
670+ const mockReq1 = createMockRequest ( {
671+ headers : {
672+ host : 'localhost:3000' ,
673+ origin : 'http://evil.com' ,
674+ 'content-type' : 'application/json' ,
675+ }
676+ } ) ;
677+ const mockHandleRes1 = createMockResponse ( ) ;
678+
679+ await transport . handlePostMessage ( mockReq1 , mockHandleRes1 , { jsonrpc : '2.0' , method : 'test' } ) ;
680+
681+ expect ( mockHandleRes1 . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
682+ expect ( mockHandleRes1 . end ) . toHaveBeenCalledWith ( 'Invalid Origin header: http://evil.com' ) ;
683+
684+ // Invalid host, valid origin
685+ const mockReq2 = createMockRequest ( {
686+ headers : {
687+ host : 'evil.com' ,
688+ origin : 'http://localhost:3000' ,
689+ 'content-type' : 'application/json' ,
690+ }
691+ } ) ;
692+ const mockHandleRes2 = createMockResponse ( ) ;
693+
694+ await transport . handlePostMessage ( mockReq2 , mockHandleRes2 , { jsonrpc : '2.0' , method : 'test' } ) ;
695+
696+ expect ( mockHandleRes2 . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
697+ expect ( mockHandleRes2 . end ) . toHaveBeenCalledWith ( 'Invalid Host header: evil.com' ) ;
698+
699+ // Both valid
700+ const mockReq3 = createMockRequest ( {
701+ headers : {
702+ host : 'localhost:3000' ,
703+ origin : 'http://localhost:3000' ,
704+ 'content-type' : 'application/json' ,
705+ }
706+ } ) ;
707+ const mockHandleRes3 = createMockResponse ( ) ;
708+
709+ await transport . handlePostMessage ( mockReq3 , mockHandleRes3 , { jsonrpc : '2.0' , method : 'test' } ) ;
710+
711+ expect ( mockHandleRes3 . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
712+ expect ( mockHandleRes3 . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
713+ } ) ;
714+ } ) ;
715+ } ) ;
716+ } ) ;
0 commit comments