@@ -54,14 +54,10 @@ async def async_add_documents(
5454    ) ->  None : ...
5555
5656    @abstractmethod  
57-     def  get_document (
58-         self , doc_id : str , raise_error : bool  =  True 
59-     ) ->  Optional [BaseNode ]: ...
57+     def  get_document (self , doc_id : str , raise_error : bool  =  True ) ->  Optional [BaseNode ]: ...
6058
6159    @abstractmethod  
62-     async  def  aget_document (
63-         self , doc_id : str , raise_error : bool  =  True 
64-     ) ->  Optional [BaseNode ]: ...
60+     async  def  aget_document (self , doc_id : str , raise_error : bool  =  True ) ->  Optional [BaseNode ]: ...
6561
6662    @abstractmethod  
6763    def  delete_document (self , doc_id : str , raise_error : bool  =  True ) ->  None :
@@ -130,9 +126,7 @@ async def adelete_ref_doc(self, ref_doc_id: str, raise_error: bool = True) -> No
130126        """Delete a ref_doc and all it's associated nodes.""" 
131127
132128    # ===== Nodes ===== 
133-     def  get_nodes (
134-         self , node_ids : List [str ], raise_error : bool  =  True 
135-     ) ->  List [BaseNode ]:
129+     def  get_nodes (self , node_ids : List [str ], raise_error : bool  =  True ) ->  List [BaseNode ]:
136130        """ 
137131        Get nodes from docstore. 
138132
@@ -141,15 +135,20 @@ def get_nodes(
141135            raise_error (bool): raise error if node_id not found 
142136
143137        """ 
144-         # if/else needed for type checking 
145-         if  raise_error :
146-             return  [node  for  node_id  in  node_ids  if  (node  :=  self .get_node (node_id , raise_error = True ))]
147-         else :
148-             return  [self .get_node (node_id ) for  node_id  in  node_ids ]
149- 
150-     async  def  aget_nodes (
151-         self , node_ids : List [str ], raise_error : bool  =  True 
152-     ) ->  List [BaseNode ]:
138+         nodes : list [BaseNode ] =  []
139+ 
140+         for  node_id  in  node_ids :
141+             # if needed for type checking 
142+             if  not  raise_error :
143+                 if  node  :=  self .get_node (node_id = node_id , raise_error = False ):
144+                     nodes .append (node )
145+                 continue 
146+ 
147+             nodes .append (self .get_node (node_id = node_id , raise_error = True ))
148+ 
149+         return  nodes 
150+ 
151+     async  def  aget_nodes (self , node_ids : List [str ], raise_error : bool  =  True ) ->  List [BaseNode ]:
153152        """ 
154153        Get nodes from docstore. 
155154
@@ -158,18 +157,18 @@ async def aget_nodes(
158157            raise_error (bool): raise error if node_id not found 
159158
160159        """ 
161-         # if/else needed for type checking 
162-          if   raise_error : 
163-              return  [ 
164-                  node 
165-                  for   node_id   in   node_ids 
166-                 if  ( node  :=  await  self .aget_node (node_id , raise_error = True )) 
167-             ] 
168-         else : 
169-              return  [ 
170-                  await  self .aget_node (node_id )
171-                  for   node_id   in   node_ids 
172-             ] 
160+         nodes :  list [ BaseNode ]  =  [] 
161+ 
162+         for   node_id   in   node_ids : 
163+             # if needed for type checking 
164+             if   not   raise_error : 
165+                 if  node  :=  await  self .aget_node (node_id = node_id , raise_error = False ): 
166+                      nodes . append ( node ) 
167+                  continue 
168+ 
169+             nodes . append ( await  self .aget_node (node_id = node_id ,  raise_error = True ) )
170+ 
171+         return   nodes 
173172
174173    @overload  
175174    def  get_node (self , node_id : str , raise_error : Literal [True ] =  True ) ->  BaseNode : ...
@@ -239,9 +238,7 @@ def get_node_dict(self, node_id_dict: Dict[int, str]) -> Dict[int, BaseNode]:
239238            node_id_dict (Dict[int, str]): mapping of index to node ids 
240239
241240        """ 
242-         return  {
243-             index : self .get_node (node_id ) for  index , node_id  in  node_id_dict .items ()
244-         }
241+         return  {index : self .get_node (node_id ) for  index , node_id  in  node_id_dict .items ()}
245242
246243    async  def  aget_node_dict (self , node_id_dict : Dict [int , str ]) ->  Dict [int , BaseNode ]:
247244        """ 
@@ -251,7 +248,4 @@ async def aget_node_dict(self, node_id_dict: Dict[int, str]) -> Dict[int, BaseNo
251248            node_id_dict (Dict[int, str]): mapping of index to node ids 
252249
253250        """ 
254-         return  {
255-             index : await  self .aget_node (node_id )
256-             for  index , node_id  in  node_id_dict .items ()
257-         }
251+         return  {index : await  self .aget_node (node_id ) for  index , node_id  in  node_id_dict .items ()}
0 commit comments