Skip to content

Commit b829405

Browse files
feat: add non persisted composite retrieval (#18908)
1 parent aa63555 commit b829405

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/composite_retriever.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
# composite retrieval params
4646
mode: Optional[CompositeRetrievalMode] = None,
4747
rerank_top_n: Optional[int] = None,
48+
persisted: Optional[bool] = True,
4849
**kwargs: Any,
4950
) -> None:
5051
"""Initialize the Composite Retriever."""
@@ -58,13 +59,15 @@ def __init__(
5859
self._client, project_name, project_id, organization_id
5960
)
6061

61-
self.retriever = resolve_retriever(
62-
self._client, self.project, name, retriever_id
63-
)
6462
self.name = name
6563
self.project_name = self.project.name
64+
self._persisted = persisted
65+
66+
self.retriever = resolve_retriever(
67+
self._client, self.project, name, retriever_id, persisted
68+
)
6669

67-
if self.retriever is None:
70+
if self.retriever is None and persisted:
6871
if create_if_not_exists:
6972
self.retriever = self._client.retrievers.upsert_retriever(
7073
project_id=self.project.id,
@@ -91,9 +94,13 @@ def retriever_pipelines(self) -> List[RetrieverPipeline]:
9194
def update_retriever_pipelines(
9295
self, pipelines: List[RetrieverPipeline]
9396
) -> Retriever:
94-
self.retriever = self._client.retrievers.update_retriever(
95-
self.retriever.id, pipelines=pipelines
96-
)
97+
if self._persisted:
98+
self.retriever = self._client.retrievers.update_retriever(
99+
self.retriever.id, pipelines=pipelines
100+
)
101+
else:
102+
# Update in-memory retriever for non-persisted case using copy
103+
self.retriever = self.retriever.copy(update={"pipelines": pipelines})
97104
return self.retriever
98105

99106
def add_index(
@@ -138,9 +145,13 @@ def remove_index(self, name: str) -> bool:
138145
async def aupdate_retriever_pipelines(
139146
self, pipelines: List[RetrieverPipeline]
140147
) -> Retriever:
141-
self.retriever = await self._aclient.retrievers.update_retriever(
142-
self.retriever.id, pipelines=pipelines
143-
)
148+
if self._persisted:
149+
self.retriever = await self._aclient.retrievers.update_retriever(
150+
self.retriever.id, pipelines=pipelines
151+
)
152+
else:
153+
# Update in-memory retriever for non-persisted case using copy
154+
self.retriever = self.retriever.copy(update={"pipelines": pipelines})
144155
return self.retriever
145156

146157
async def async_add_index(
@@ -202,12 +213,21 @@ def _retrieve(
202213
) -> List[NodeWithScore]:
203214
mode = mode if mode is not None else self._mode
204215
rerank_top_n = rerank_top_n if rerank_top_n is not None else self._rerank_top_n
205-
result = self._client.retrievers.retrieve(
206-
self.retriever.id,
207-
mode=mode,
208-
rerank_top_n=rerank_top_n,
209-
query=query_bundle.query_str,
210-
)
216+
if self._persisted:
217+
result = self._client.retrievers.retrieve(
218+
self.retriever.id,
219+
mode=mode,
220+
rerank_top_n=rerank_top_n,
221+
query=query_bundle.query_str,
222+
)
223+
else:
224+
result = self._client.retrievers.direct_retrieve(
225+
project_id=self.project.id,
226+
mode=mode,
227+
rerank_top_n=rerank_top_n,
228+
query=query_bundle.query_str,
229+
pipelines=self.retriever.pipelines,
230+
)
211231
node_w_scores = [
212232
self._result_nodes_to_node_with_score(node) for node in result.nodes
213233
]
@@ -226,12 +246,21 @@ async def _aretrieve(
226246
) -> List[NodeWithScore]:
227247
mode = mode if mode is not None else self._mode
228248
rerank_top_n = rerank_top_n if rerank_top_n is not None else self._rerank_top_n
229-
result = await self._aclient.retrievers.retrieve(
230-
self.retriever.id,
231-
mode=mode,
232-
rerank_top_n=rerank_top_n,
233-
query=query_bundle.query_str,
234-
)
249+
if self._persisted:
250+
result = await self._aclient.retrievers.retrieve(
251+
self.retriever.id,
252+
mode=mode,
253+
rerank_top_n=rerank_top_n,
254+
query=query_bundle.query_str,
255+
)
256+
else:
257+
result = await self._aclient.retrievers.direct_retrieve(
258+
project_id=self.project.id,
259+
mode=mode,
260+
rerank_top_n=rerank_top_n,
261+
query=query_bundle.query_str,
262+
pipelines=self.retriever.pipelines,
263+
)
235264
node_w_scores = [
236265
self._result_nodes_to_node_with_score(node) for node in result.nodes
237266
]

llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dev = [
2828

2929
[project]
3030
name = "llama-index-indices-managed-llama-cloud"
31-
version = "0.7.1"
31+
version = "0.7.2"
3232
description = "llama-index indices llama-cloud integration"
3333
authors = [{name = "Logan Markewich", email = "[email protected]"}]
3434
requires-python = ">=3.9,<4.0"

0 commit comments

Comments
 (0)