@@ -45,6 +45,7 @@ def __init__(
4545 # composite retrieval params
4646 mode : Optional [CompositeRetrievalMode ] = None ,
4747 rerank_top_n : Optional [int ] = None ,
48+ persisted : Optional [bool ] = True ,
4849 ** kwargs : Any ,
4950 ) -> None :
5051 """Initialize the Composite Retriever."""
@@ -58,13 +59,15 @@ def __init__(
5859 self ._client , project_name , project_id , organization_id
5960 )
6061
61- self .retriever = resolve_retriever (
62- self ._client , self .project , name , retriever_id
63- )
6462 self .name = name
6563 self .project_name = self .project .name
64+ self ._persisted = persisted
65+
66+ self .retriever = resolve_retriever (
67+ self ._client , self .project , name , retriever_id , persisted
68+ )
6669
67- if self .retriever is None :
70+ if self .retriever is None and persisted :
6871 if create_if_not_exists :
6972 self .retriever = self ._client .retrievers .upsert_retriever (
7073 project_id = self .project .id ,
@@ -91,9 +94,13 @@ def retriever_pipelines(self) -> List[RetrieverPipeline]:
9194 def update_retriever_pipelines (
9295 self , pipelines : List [RetrieverPipeline ]
9396 ) -> Retriever :
94- self .retriever = self ._client .retrievers .update_retriever (
95- self .retriever .id , pipelines = pipelines
96- )
97+ if self ._persisted :
98+ self .retriever = self ._client .retrievers .update_retriever (
99+ self .retriever .id , pipelines = pipelines
100+ )
101+ else :
102+ # Update in-memory retriever for non-persisted case using copy
103+ self .retriever = self .retriever .copy (update = {"pipelines" : pipelines })
97104 return self .retriever
98105
99106 def add_index (
@@ -138,9 +145,13 @@ def remove_index(self, name: str) -> bool:
138145 async def aupdate_retriever_pipelines (
139146 self , pipelines : List [RetrieverPipeline ]
140147 ) -> Retriever :
141- self .retriever = await self ._aclient .retrievers .update_retriever (
142- self .retriever .id , pipelines = pipelines
143- )
148+ if self ._persisted :
149+ self .retriever = await self ._aclient .retrievers .update_retriever (
150+ self .retriever .id , pipelines = pipelines
151+ )
152+ else :
153+ # Update in-memory retriever for non-persisted case using copy
154+ self .retriever = self .retriever .copy (update = {"pipelines" : pipelines })
144155 return self .retriever
145156
146157 async def async_add_index (
@@ -202,12 +213,21 @@ def _retrieve(
202213 ) -> List [NodeWithScore ]:
203214 mode = mode if mode is not None else self ._mode
204215 rerank_top_n = rerank_top_n if rerank_top_n is not None else self ._rerank_top_n
205- result = self ._client .retrievers .retrieve (
206- self .retriever .id ,
207- mode = mode ,
208- rerank_top_n = rerank_top_n ,
209- query = query_bundle .query_str ,
210- )
216+ if self ._persisted :
217+ result = self ._client .retrievers .retrieve (
218+ self .retriever .id ,
219+ mode = mode ,
220+ rerank_top_n = rerank_top_n ,
221+ query = query_bundle .query_str ,
222+ )
223+ else :
224+ result = self ._client .retrievers .direct_retrieve (
225+ project_id = self .project .id ,
226+ mode = mode ,
227+ rerank_top_n = rerank_top_n ,
228+ query = query_bundle .query_str ,
229+ pipelines = self .retriever .pipelines ,
230+ )
211231 node_w_scores = [
212232 self ._result_nodes_to_node_with_score (node ) for node in result .nodes
213233 ]
@@ -226,12 +246,21 @@ async def _aretrieve(
226246 ) -> List [NodeWithScore ]:
227247 mode = mode if mode is not None else self ._mode
228248 rerank_top_n = rerank_top_n if rerank_top_n is not None else self ._rerank_top_n
229- result = await self ._aclient .retrievers .retrieve (
230- self .retriever .id ,
231- mode = mode ,
232- rerank_top_n = rerank_top_n ,
233- query = query_bundle .query_str ,
234- )
249+ if self ._persisted :
250+ result = await self ._aclient .retrievers .retrieve (
251+ self .retriever .id ,
252+ mode = mode ,
253+ rerank_top_n = rerank_top_n ,
254+ query = query_bundle .query_str ,
255+ )
256+ else :
257+ result = await self ._aclient .retrievers .direct_retrieve (
258+ project_id = self .project .id ,
259+ mode = mode ,
260+ rerank_top_n = rerank_top_n ,
261+ query = query_bundle .query_str ,
262+ pipelines = self .retriever .pipelines ,
263+ )
235264 node_w_scores = [
236265 self ._result_nodes_to_node_with_score (node ) for node in result .nodes
237266 ]
0 commit comments