Skip to content

Commit c22c70c

Browse files
authored
Update BaseDocumentStore to not return Nones in result (#19513)
1 parent 68f8ea3 commit c22c70c

File tree

1 file changed

+31
-37
lines changed
  • llama-index-core/llama_index/core/storage/docstore

1 file changed

+31
-37
lines changed

llama-index-core/llama_index/core/storage/docstore/types.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,10 @@ async def async_add_documents(
5454
) -> None: ...
5555

5656
@abstractmethod
57-
def get_document(
58-
self, doc_id: str, raise_error: bool = True
59-
) -> Optional[BaseNode]: ...
57+
def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[BaseNode]: ...
6058

6159
@abstractmethod
62-
async def aget_document(
63-
self, doc_id: str, raise_error: bool = True
64-
) -> Optional[BaseNode]: ...
60+
async def aget_document(self, doc_id: str, raise_error: bool = True) -> Optional[BaseNode]: ...
6561

6662
@abstractmethod
6763
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
@@ -130,9 +126,7 @@ async def adelete_ref_doc(self, ref_doc_id: str, raise_error: bool = True) -> No
130126
"""Delete a ref_doc and all it's associated nodes."""
131127

132128
# ===== Nodes =====
133-
def get_nodes(
134-
self, node_ids: List[str], raise_error: bool = True
135-
) -> List[BaseNode]:
129+
def get_nodes(self, node_ids: List[str], raise_error: bool = True) -> List[BaseNode]:
136130
"""
137131
Get nodes from docstore.
138132
@@ -141,15 +135,20 @@ def get_nodes(
141135
raise_error (bool): raise error if node_id not found
142136
143137
"""
144-
# if/else needed for type checking
145-
if raise_error:
146-
return [node for node_id in node_ids if (node := self.get_node(node_id, raise_error=True))]
147-
else:
148-
return [self.get_node(node_id) for node_id in node_ids]
149-
150-
async def aget_nodes(
151-
self, node_ids: List[str], raise_error: bool = True
152-
) -> List[BaseNode]:
138+
nodes: list[BaseNode] = []
139+
140+
for node_id in node_ids:
141+
# if needed for type checking
142+
if not raise_error:
143+
if node := self.get_node(node_id=node_id, raise_error=False):
144+
nodes.append(node)
145+
continue
146+
147+
nodes.append(self.get_node(node_id=node_id, raise_error=True))
148+
149+
return nodes
150+
151+
async def aget_nodes(self, node_ids: List[str], raise_error: bool = True) -> List[BaseNode]:
153152
"""
154153
Get nodes from docstore.
155154
@@ -158,18 +157,18 @@ async def aget_nodes(
158157
raise_error (bool): raise error if node_id not found
159158
160159
"""
161-
# if/else needed for type checking
162-
if raise_error:
163-
return [
164-
node
165-
for node_id in node_ids
166-
if (node := await self.aget_node(node_id, raise_error=True))
167-
]
168-
else:
169-
return [
170-
await self.aget_node(node_id)
171-
for node_id in node_ids
172-
]
160+
nodes: list[BaseNode] = []
161+
162+
for node_id in node_ids:
163+
# if needed for type checking
164+
if not raise_error:
165+
if node := await self.aget_node(node_id=node_id, raise_error=False):
166+
nodes.append(node)
167+
continue
168+
169+
nodes.append(await self.aget_node(node_id=node_id, raise_error=True))
170+
171+
return nodes
173172

174173
@overload
175174
def get_node(self, node_id: str, raise_error: Literal[True] = True) -> BaseNode: ...
@@ -239,9 +238,7 @@ def get_node_dict(self, node_id_dict: Dict[int, str]) -> Dict[int, BaseNode]:
239238
node_id_dict (Dict[int, str]): mapping of index to node ids
240239
241240
"""
242-
return {
243-
index: self.get_node(node_id) for index, node_id in node_id_dict.items()
244-
}
241+
return {index: self.get_node(node_id) for index, node_id in node_id_dict.items()}
245242

246243
async def aget_node_dict(self, node_id_dict: Dict[int, str]) -> Dict[int, BaseNode]:
247244
"""
@@ -251,7 +248,4 @@ async def aget_node_dict(self, node_id_dict: Dict[int, str]) -> Dict[int, BaseNo
251248
node_id_dict (Dict[int, str]): mapping of index to node ids
252249
253250
"""
254-
return {
255-
index: await self.aget_node(node_id)
256-
for index, node_id in node_id_dict.items()
257-
}
251+
return {index: await self.aget_node(node_id) for index, node_id in node_id_dict.items()}

0 commit comments

Comments
 (0)