@@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
22import { AddressInfo } from "net" ;
33import { JSONRPCMessage } from "../types.js" ;
44import { SSEClientTransport } from "./sse.js" ;
5- import { auth , OAuthClientProvider } from "./auth.js" ;
5+ import { OAuthClientProvider , OAuthTokens } from "./auth.js" ;
66
77describe ( "SSEClientTransport" , ( ) => {
88 let server : Server ;
@@ -301,7 +301,7 @@ describe("SSEClientTransport", () => {
301301 mockAuthProvider = {
302302 get redirectUrl ( ) { return "http://localhost/callback" ; } ,
303303 get clientMetadata ( ) { return { redirect_uris : [ "http://localhost/callback" ] } ; } ,
304- clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" } ) ) ,
304+ clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" , client_secret : "test-client-secret" } ) ) ,
305305 tokens : jest . fn ( ) ,
306306 saveTokens : jest . fn ( ) ,
307307 redirectToAuthorization : jest . fn ( ) ,
@@ -466,5 +466,257 @@ describe("SSEClientTransport", () => {
466466 expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
467467 expect ( lastServerRequest . headers [ "x-custom-header" ] ) . toBe ( "custom-value" ) ;
468468 } ) ;
469+
470+ it ( "refreshes expired token during SSE connection" , async ( ) => {
471+ // Mock tokens() to return expired token until saveTokens is called
472+ let currentTokens : OAuthTokens = {
473+ access_token : "expired-token" ,
474+ token_type : "Bearer" ,
475+ refresh_token : "refresh-token"
476+ } ;
477+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
478+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
479+ currentTokens = tokens ;
480+ } ) ;
481+
482+ // Create server that returns 401 for expired token, then accepts new token
483+ await server . close ( ) ;
484+
485+ let connectionAttempts = 0 ;
486+ server = createServer ( ( req , res ) => {
487+ lastServerRequest = req ;
488+
489+ if ( req . url === "/token" && req . method === "POST" ) {
490+ // Handle token refresh request
491+ let body = "" ;
492+ req . on ( "data" , chunk => { body += chunk ; } ) ;
493+ req . on ( "end" , ( ) => {
494+ const params = new URLSearchParams ( body ) ;
495+ if ( params . get ( "grant_type" ) === "refresh_token" &&
496+ params . get ( "refresh_token" ) === "refresh-token" &&
497+ params . get ( "client_id" ) === "test-client-id" &&
498+ params . get ( "client_secret" ) === "test-client-secret" ) {
499+ res . writeHead ( 200 , { "Content-Type" : "application/json" } ) ;
500+ res . end ( JSON . stringify ( {
501+ access_token : "new-token" ,
502+ token_type : "Bearer" ,
503+ refresh_token : "new-refresh-token"
504+ } ) ) ;
505+ } else {
506+ res . writeHead ( 400 ) . end ( ) ;
507+ }
508+ } ) ;
509+ return ;
510+ }
511+
512+ if ( req . url !== "/" ) {
513+ res . writeHead ( 404 ) . end ( ) ;
514+ return ;
515+ }
516+
517+ const auth = req . headers . authorization ;
518+ if ( auth === "Bearer expired-token" ) {
519+ res . writeHead ( 401 ) . end ( ) ;
520+ return ;
521+ }
522+
523+ if ( auth === "Bearer new-token" ) {
524+ res . writeHead ( 200 , {
525+ "Content-Type" : "text/event-stream" ,
526+ "Cache-Control" : "no-cache" ,
527+ Connection : "keep-alive" ,
528+ } ) ;
529+ res . write ( "event: endpoint\n" ) ;
530+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
531+ connectionAttempts ++ ;
532+ return ;
533+ }
534+
535+ res . writeHead ( 401 ) . end ( ) ;
536+ } ) ;
537+
538+ await new Promise < void > ( resolve => {
539+ server . listen ( 0 , "127.0.0.1" , ( ) => {
540+ const addr = server . address ( ) as AddressInfo ;
541+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
542+ resolve ( ) ;
543+ } ) ;
544+ } ) ;
545+
546+ transport = new SSEClientTransport ( baseUrl , {
547+ authProvider : mockAuthProvider ,
548+ } ) ;
549+
550+ await transport . start ( ) ;
551+
552+ expect ( mockAuthProvider . saveTokens ) . toHaveBeenCalledWith ( {
553+ access_token : "new-token" ,
554+ token_type : "Bearer" ,
555+ refresh_token : "new-refresh-token"
556+ } ) ;
557+ expect ( connectionAttempts ) . toBe ( 1 ) ;
558+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer new-token" ) ;
559+ } ) ;
560+
561+ it ( "refreshes expired token during POST request" , async ( ) => {
562+ // Mock tokens() to return expired token until saveTokens is called
563+ let currentTokens : OAuthTokens = {
564+ access_token : "expired-token" ,
565+ token_type : "Bearer" ,
566+ refresh_token : "refresh-token"
567+ } ;
568+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
569+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
570+ currentTokens = tokens ;
571+ } ) ;
572+
573+ // Create server that accepts SSE but returns 401 on POST with expired token
574+ await server . close ( ) ;
575+
576+ let postAttempts = 0 ;
577+ server = createServer ( ( req , res ) => {
578+ lastServerRequest = req ;
579+
580+ if ( req . url === "/token" && req . method === "POST" ) {
581+ // Handle token refresh request
582+ let body = "" ;
583+ req . on ( "data" , chunk => { body += chunk ; } ) ;
584+ req . on ( "end" , ( ) => {
585+ const params = new URLSearchParams ( body ) ;
586+ if ( params . get ( "grant_type" ) === "refresh_token" &&
587+ params . get ( "refresh_token" ) === "refresh-token" &&
588+ params . get ( "client_id" ) === "test-client-id" &&
589+ params . get ( "client_secret" ) === "test-client-secret" ) {
590+ res . writeHead ( 200 , { "Content-Type" : "application/json" } ) ;
591+ res . end ( JSON . stringify ( {
592+ access_token : "new-token" ,
593+ token_type : "Bearer" ,
594+ refresh_token : "new-refresh-token"
595+ } ) ) ;
596+ } else {
597+ res . writeHead ( 400 ) . end ( ) ;
598+ }
599+ } ) ;
600+ return ;
601+ }
602+
603+ switch ( req . method ) {
604+ case "GET" :
605+ if ( req . url !== "/" ) {
606+ res . writeHead ( 404 ) . end ( ) ;
607+ return ;
608+ }
609+
610+ res . writeHead ( 200 , {
611+ "Content-Type" : "text/event-stream" ,
612+ "Cache-Control" : "no-cache" ,
613+ Connection : "keep-alive" ,
614+ } ) ;
615+ res . write ( "event: endpoint\n" ) ;
616+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
617+ break ;
618+
619+ case "POST" : {
620+ if ( req . url !== "/" ) {
621+ res . writeHead ( 404 ) . end ( ) ;
622+ return ;
623+ }
624+
625+ const auth = req . headers . authorization ;
626+ if ( auth === "Bearer expired-token" ) {
627+ res . writeHead ( 401 ) . end ( ) ;
628+ return ;
629+ }
630+
631+ if ( auth === "Bearer new-token" ) {
632+ res . writeHead ( 200 ) . end ( ) ;
633+ postAttempts ++ ;
634+ return ;
635+ }
636+
637+ res . writeHead ( 401 ) . end ( ) ;
638+ break ;
639+ }
640+ }
641+ } ) ;
642+
643+ await new Promise < void > ( resolve => {
644+ server . listen ( 0 , "127.0.0.1" , ( ) => {
645+ const addr = server . address ( ) as AddressInfo ;
646+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
647+ resolve ( ) ;
648+ } ) ;
649+ } ) ;
650+
651+ transport = new SSEClientTransport ( baseUrl , {
652+ authProvider : mockAuthProvider ,
653+ } ) ;
654+
655+ await transport . start ( ) ;
656+
657+ const message : JSONRPCMessage = {
658+ jsonrpc : "2.0" ,
659+ id : "1" ,
660+ method : "test" ,
661+ params : { } ,
662+ } ;
663+
664+ await transport . send ( message ) ;
665+
666+ expect ( mockAuthProvider . saveTokens ) . toHaveBeenCalledWith ( {
667+ access_token : "new-token" ,
668+ token_type : "Bearer" ,
669+ refresh_token : "new-refresh-token"
670+ } ) ;
671+ expect ( postAttempts ) . toBe ( 1 ) ;
672+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer new-token" ) ;
673+ } ) ;
674+
675+ it ( "redirects to authorization if refresh token flow fails" , async ( ) => {
676+ // Mock tokens() to return expired token until saveTokens is called
677+ let currentTokens : OAuthTokens = {
678+ access_token : "expired-token" ,
679+ token_type : "Bearer" ,
680+ refresh_token : "refresh-token"
681+ } ;
682+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
683+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
684+ currentTokens = tokens ;
685+ } ) ;
686+
687+ // Create server that returns 401 for all tokens
688+ await server . close ( ) ;
689+
690+ server = createServer ( ( req , res ) => {
691+ lastServerRequest = req ;
692+
693+ if ( req . url === "/token" && req . method === "POST" ) {
694+ // Handle token refresh request - always fail
695+ res . writeHead ( 400 ) . end ( ) ;
696+ return ;
697+ }
698+
699+ if ( req . url !== "/" ) {
700+ res . writeHead ( 404 ) . end ( ) ;
701+ return ;
702+ }
703+ res . writeHead ( 401 ) . end ( ) ;
704+ } ) ;
705+
706+ await new Promise < void > ( resolve => {
707+ server . listen ( 0 , "127.0.0.1" , ( ) => {
708+ const addr = server . address ( ) as AddressInfo ;
709+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
710+ resolve ( ) ;
711+ } ) ;
712+ } ) ;
713+
714+ transport = new SSEClientTransport ( baseUrl , {
715+ authProvider : mockAuthProvider ,
716+ } ) ;
717+
718+ await expect ( transport . start ( ) ) . rejects . toThrow ( "Unauthorized" ) ;
719+ expect ( mockAuthProvider . redirectToAuthorization ) . toHaveBeenCalled ( ) ;
720+ } ) ;
469721 } ) ;
470722} ) ;
0 commit comments