11import multiprocessing as mp
22from typing import List , Tuple
33
4- from cassandra .cluster import Cluster , ExecutionProfile , EXEC_PROFILE_DEFAULT
5- from cassandra .policies import DCAwareRoundRobinPolicy , TokenAwarePolicy , ExponentialReconnectionPolicy
64from cassandra import ConsistencyLevel , ProtocolVersion
5+ from cassandra .cluster import EXEC_PROFILE_DEFAULT , Cluster , ExecutionProfile
6+ from cassandra .policies import (
7+ DCAwareRoundRobinPolicy ,
8+ ExponentialReconnectionPolicy ,
9+ TokenAwarePolicy ,
10+ )
711
812from dataset_reader .base_reader import Query
913from engine .base_client .distances import Distance
@@ -24,20 +28,22 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
2428 profile = ExecutionProfile (
2529 load_balancing_policy = TokenAwarePolicy (DCAwareRoundRobinPolicy ()),
2630 consistency_level = ConsistencyLevel .LOCAL_ONE , # Use LOCAL_ONE for faster reads
27- request_timeout = 60
31+ request_timeout = 60 ,
2832 )
29-
33+
3034 # Initialize Cassandra cluster connection
3135 cls .cluster = Cluster (
32- contact_points = [host ],
36+ contact_points = [host ],
3337 execution_profiles = {EXEC_PROFILE_DEFAULT : profile },
34- reconnection_policy = ExponentialReconnectionPolicy (base_delay = 1 , max_delay = 60 ),
38+ reconnection_policy = ExponentialReconnectionPolicy (
39+ base_delay = 1 , max_delay = 60
40+ ),
3541 protocol_version = ProtocolVersion .V4 ,
36- ** connection_params
42+ ** connection_params ,
3743 )
3844 cls .session = cls .cluster .connect (CASSANDRA_KEYSPACE )
3945 cls .search_params = search_params
40-
46+
4147 # Update prepared statements with current search parameters
4248 cls .update_prepared_statements (distance )
4349
@@ -50,7 +56,7 @@ def update_prepared_statements(cls, distance):
5056 """Create prepared statements for vector searches"""
5157 # Prepare a vector similarity search query
5258 limit = cls .search_params .get ("top" , 10 )
53-
59+
5460 if distance == Distance .COSINE :
5561 SIMILARITY_FUNC = "similarity_cosine"
5662 elif distance == Distance .L2 :
@@ -61,48 +67,49 @@ def update_prepared_statements(cls, distance):
6167 raise ValueError (f"Unsupported distance metric: { distance } " )
6268
6369 cls .ann_search_stmt = cls .session .prepare (
64- f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
65- FROM { CASSANDRA_TABLE }
70+ f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
71+ FROM { CASSANDRA_TABLE }
6672 ORDER BY embedding ANN OF ?
6773 LIMIT { limit } """
6874 )
69-
75+
7076 # Prepare a statement for filtered vector search
71- cls .filtered_search_query_template = (
72- f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
73- FROM { CASSANDRA_TABLE }
77+ cls .filtered_search_query_template = f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
78+ FROM { CASSANDRA_TABLE }
7479 WHERE {{conditions}}
7580 ORDER BY embedding ANN OF ?
7681 LIMIT { limit } """
77- )
7882
7983 @classmethod
8084 def search_one (cls , query : Query , top : int ) -> List [Tuple [int , float ]]:
8185 """Execute a vector similarity search with optional filters"""
8286 # Convert query vector to a format Cassandra can use
83- query_vector = query .vector .tolist () if hasattr (query .vector , 'tolist' ) else query .vector
84-
87+ query_vector = (
88+ query .vector .tolist () if hasattr (query .vector , "tolist" ) else query .vector
89+ )
90+
8591 # Generate filter conditions if metadata conditions exist
8692 filter_conditions = cls .parser .parse (query .meta_conditions )
87-
93+
8894 try :
8995 if filter_conditions :
9096 # Use the filtered search query
91- query_with_conditions = cls .filtered_search_query_template .format (conditions = filter_conditions )
97+ query_with_conditions = cls .filtered_search_query_template .format (
98+ conditions = filter_conditions
99+ )
92100 results = cls .session .execute (
93101 cls .session .prepare (query_with_conditions ),
94- (query_vector , query_vector )
102+ (query_vector , query_vector ),
95103 )
96104 else :
97105 # Use the basic ANN search query
98106 results = cls .session .execute (
99- cls .ann_search_stmt ,
100- (query_vector , query_vector )
107+ cls .ann_search_stmt , (query_vector , query_vector )
101108 )
102-
109+
103110 # Extract and return results
104111 return [(row .id , row .distance ) for row in results ]
105-
112+
106113 except Exception as ex :
107114 print (f"Error during Cassandra vector search: { ex } " )
108115 raise ex
@@ -113,4 +120,4 @@ def delete_client(cls):
113120 if cls .session :
114121 cls .session .shutdown ()
115122 if cls .cluster :
116- cls .cluster .shutdown ()
123+ cls .cluster .shutdown ()
0 commit comments