Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class OpensearchVectorClient:
space_type (Optional[str]): space type for distance metric calculation. Defaults to: l2
os_client (Optional[OSClient]): Custom synchronous client (see OpenSearch from opensearch-py)
os_async_client (Optional[OSClient]): Custom asynchronous client (see AsyncOpenSearch from opensearch-py)
excluded_source_fields (Optional[List[str]]): Optional list of document "source" fields to exclude from OpenSearch responses.
**kwargs: Optional arguments passed to the OpenSearch client from opensearch-py.

"""
Expand All @@ -77,6 +78,7 @@ def __init__(
search_pipeline: Optional[str] = None,
os_client: Optional[OSClient] = None,
os_async_client: Optional[OSClient] = None,
excluded_source_fields: Optional[List[str]] = None,
**kwargs: Any,
):
"""Init params."""
Expand All @@ -99,6 +101,7 @@ def __init__(
self._index = index
self._text_field = text_field
self._max_chunk_bytes = max_chunk_bytes
self._excluded_source_fields = excluded_source_fields

self._search_pipeline = search_pipeline
http_auth = kwargs.get("http_auth")
Expand Down Expand Up @@ -328,6 +331,7 @@ def _default_approximate_search_query(
k: int = 4,
filters: Optional[Union[Dict, List]] = None,
vector_field: str = "embedding",
excluded_source_fields: Optional[List[str]] = None,
) -> Dict:
"""For Approximate k-NN Search, this is the default query."""
query = {
Expand All @@ -345,6 +349,8 @@ def _default_approximate_search_query(
if filters:
# filter key must be added only when filtering to avoid "filter doesn't support values of type: START_ARRAY" exception
query["query"]["knn"][vector_field]["filter"] = filters
if excluded_source_fields:
query["_source"] = {"exclude": excluded_source_fields}
return query

def _is_text_field(self, value: Any) -> bool:
Expand Down Expand Up @@ -447,6 +453,7 @@ def _knn_search_query(
k: int,
filters: Optional[MetadataFilters] = None,
search_method="approximate",
excluded_source_fields: Optional[List[str]] = None,
) -> Dict:
"""
Perform a k-Nearest Neighbors (kNN) search.
Expand All @@ -465,6 +472,7 @@ def _knn_search_query(
filters (Optional[MetadataFilters]): Optional filters to apply for the search.
Supports filter-context queries documented at
https://opensearch.org/docs/latest/query-dsl/query-filter-context/
excluded_source_fields (Optional[List[str]]): Optional list of document "source" fields to exclude from the response.

Returns:
Dict: Up to k documents closest to query_embedding.
Expand All @@ -477,6 +485,7 @@ def _knn_search_query(
query_embedding,
k,
vector_field=embedding_field,
excluded_source_fields=excluded_source_fields,
)
elif (
search_method == "approximate"
Expand All @@ -493,6 +502,7 @@ def _knn_search_query(
k,
filters={"bool": {"filter": filters}},
vector_field=embedding_field,
excluded_source_fields=excluded_source_fields,
)
else:
if self.is_aoss:
Expand All @@ -504,6 +514,7 @@ def _knn_search_query(
space_type=self.space_type,
pre_filter={"bool": {"filter": filters}},
vector_field=embedding_field,
excluded_source_fields=excluded_source_fields,
)
else:
# https://opensearch.org/docs/latest/search-plugins/knn/painless-functions/
Expand All @@ -513,6 +524,7 @@ def _knn_search_query(
space_type="l2Squared",
pre_filter={"bool": {"filter": filters}},
vector_field=embedding_field,
excluded_source_fields=excluded_source_fields,
)
return search_query

Expand All @@ -524,23 +536,28 @@ def _hybrid_search_query(
query_embedding: List[float],
k: int,
filters: Optional[MetadataFilters] = None,
excluded_source_fields: Optional[List[str]] = None,
) -> Dict:
knn_query = self._knn_search_query(embedding_field, query_embedding, k, filters)
lexical_query = self._lexical_search_query(text_field, query_str, k, filters)

return {
query = {
"size": k,
"query": {
"hybrid": {"queries": [lexical_query["query"], knn_query["query"]]}
},
}
if excluded_source_fields:
query["_source"] = {"exclude": excluded_source_fields}
return query

def _lexical_search_query(
self,
text_field: str,
query_str: str,
k: int,
filters: Optional[MetadataFilters] = None,
excluded_source_fields: Optional[List[str]] = None,
) -> Dict:
lexical_query = {
"bool": {"must": {"match": {text_field: {"query": query_str}}}}
Expand All @@ -550,10 +567,13 @@ def _lexical_search_query(
if len(parsed_filters) > 0:
lexical_query["bool"]["filter"] = parsed_filters

return {
query = {
"size": k,
"query": lexical_query,
}
if excluded_source_fields:
query["_source"] = {"exclude": excluded_source_fields}
return query

def __get_painless_scripting_source(
self, space_type: str, vector_field: str = "embedding"
Expand Down Expand Up @@ -599,6 +619,7 @@ def _default_scoring_script_query(
space_type: str = "l2Squared",
pre_filter: Optional[Union[Dict, List]] = None,
vector_field: str = "embedding",
excluded_source_fields: Optional[List[str]] = None,
) -> Dict:
"""
For Scoring Script Search, this is the default query. Has to account for Opensearch Service
Expand All @@ -620,7 +641,7 @@ def _default_scoring_script_query(
script = self._get_painless_scoring_script(
space_type, vector_field, query_vector
)
return {
query = {
"size": k,
"query": {
"script_score": {
Expand All @@ -629,6 +650,9 @@ def _default_scoring_script_query(
}
},
}
if excluded_source_fields:
query["_source"] = {"exclude": excluded_source_fields}
return query

def _is_aoss_enabled(self, http_auth: Any) -> bool:
"""Check if the service is http_auth is set as `aoss`."""
Expand Down Expand Up @@ -817,18 +841,27 @@ def query(
query_embedding,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = {
"search_pipeline": self._search_pipeline,
}
elif query_mode == VectorStoreQueryMode.TEXT_SEARCH:
search_query = self._lexical_search_query(
self._text_field, query_str, k, filters=filters
self._text_field,
query_str,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = None
else:
search_query = self._knn_search_query(
self._embedding_field, query_embedding, k, filters=filters
self._embedding_field,
query_embedding,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = None

Expand Down Expand Up @@ -856,18 +889,27 @@ async def aquery(
query_embedding,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = {
"search_pipeline": self._search_pipeline,
}
elif query_mode == VectorStoreQueryMode.TEXT_SEARCH:
search_query = self._lexical_search_query(
self._text_field, query_str, k, filters=filters
self._text_field,
query_str,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = None
else:
search_query = self._knn_search_query(
self._embedding_field, query_embedding, k, filters=filters
self._embedding_field,
query_embedding,
k,
filters=filters,
excluded_source_fields=self._excluded_source_fields,
)
params = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dev = [

[project]
name = "llama-index-vector-stores-opensearch"
version = "0.5.5"
version = "0.5.6"
description = "llama-index vector_stores opensearch integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ version: "3"

services:
opensearch:
image: opensearchproject/opensearch:latest
image: opensearchproject/opensearch:2.19.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using :latest (which is OpenSearch version 3) causes the following exception on the OpenSearch side when running the tests, thus causing almost all of them to fail:

org.opensearch.index.mapper.MapperParsingException: Failed to parse mapping [_doc]: nmslib engine is deprecated in OpenSearch  and cannot be used for new index creation in OpenSearch from  3.0.0.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like something that needs to be fixed in the future 😞

environment:
- discovery.type=single-node
- plugins.security.disabled=true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MetadataFilter,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
)

##
Expand All @@ -26,7 +27,7 @@
# docker-compose up
#
# Run tests
# pytest test_opensearch_client.py
# uv run -- pytest test_opensearch_client.py

logging.basicConfig(level=logging.DEBUG)
evt_loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -883,3 +884,87 @@ def test_efficient_filtering_used_when_enabled(os_stores: List[OpensearchVectorS
embedding_field="embedding", query_embedding=[1], k=20, filters=filters
)
assert patched_default_approximate_search_query.called


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
@pytest.mark.parametrize(
    "query_mode",
    [
        VectorStoreQueryMode.DEFAULT,
        VectorStoreQueryMode.TEXT_SEARCH,
        VectorStoreQueryMode.HYBRID,
    ],
)
def test_excluded_source_fields(
    os_stores: List[OpensearchVectorStore],
    node_embeddings: List[TextNode],
    query_mode: VectorStoreQueryMode,
) -> None:
    """
    Check that excluded source fields are sent in the search body and honored.

    The test runs once per parametrized query mode and has two phases:

    1. With the client's ``search`` patched, issue a query and assert that the
       captured request body carries ``_source.exclude`` equal to the
       configured list.
    2. Replay the captured request through the un-patched ``search`` and
       assert that none of the excluded fields appear in any hit's
       ``_source``.

    Args:
        os_stores: Vector stores under test; only the first one is used.
        node_embeddings: Nodes indexed before querying; the node at index 3
            supplies the query embedding and query string.
        query_mode: Query mode exercised by this parametrized run.

    """
    os_store = os_stores[0]
    os_store.add(node_embeddings)
    os_store.client._search_pipeline = "search_pipeline"  # value doesn't matter

    # set excluded source fields
    excluded_fields = ["embedding"]
    os_store.client._excluded_source_fields = excluded_fields

    with mock.patch.object(os_store.client._os_client, "search") as patched_search:
        exp_node = node_embeddings[3]
        query = VectorStoreQuery(
            query_embedding=exp_node.embedding,
            similarity_top_k=1,
            mode=query_mode,
            query_str=exp_node.text,
        )
        os_store.query(query)
        assert patched_search.called

    kwargs = patched_search.call_args.kwargs
    body = kwargs["body"]
    assert "_source" in body
    assert "exclude" in body["_source"]
    assert body["_source"]["exclude"] == excluded_fields

    # Replay the captured request with the patch undone. Build a filtered copy
    # instead of popping from the mock's recorded call, so the recorded call
    # args stay intact. `params` is not needed, even when testing hybrid.
    replay_kwargs = {key: value for key, value in kwargs.items() if key != "params"}
    res = os_store.client._os_client.search(params=None, **replay_kwargs)
    hits = res["hits"]["hits"]
    assert len(hits) > 0
    for hit in hits:
        source = hit["_source"]
        for ex in excluded_fields:
            assert ex not in source


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
@pytest.mark.parametrize(
    "query_mode",
    [
        VectorStoreQueryMode.DEFAULT,
        VectorStoreQueryMode.TEXT_SEARCH,
        VectorStoreQueryMode.HYBRID,
    ],
)
def test_no_excluded_source_fields(
    os_stores: List[OpensearchVectorStore],
    node_embeddings: List[TextNode],
    query_mode: VectorStoreQueryMode,
) -> None:
    """
    Check that no ``_source`` key is sent when no fields are excluded.

    With the client's excluded source fields set to ``None``, a query is
    issued against a patched ``search`` and the captured request body is
    asserted to contain no ``_source`` entry. Runs once per parametrized
    query mode.

    Args:
        os_stores: Vector stores under test; only the first one is used.
        node_embeddings: Nodes indexed before querying; the node at index 3
            supplies the query embedding and query string.
        query_mode: Query mode exercised by this parametrized run.

    """
    store = os_stores[0]
    store.add(node_embeddings)
    store.client._search_pipeline = "search_pipeline"  # value doesn't matter

    # explicitly show `None` for excluded source fields (default)
    store.client._excluded_source_fields = None

    with mock.patch.object(store.client._os_client, "search") as search_spy:
        target_node = node_embeddings[3]
        store.query(
            VectorStoreQuery(
                query_embedding=target_node.embedding,
                similarity_top_k=1,
                mode=query_mode,
                query_str=target_node.text,
            )
        )
        assert search_spy.called

    sent_body = search_spy.call_args.kwargs["body"]
    assert "_source" not in sent_body
Loading