88 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			88 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import Any
 | |
| 
 | |
| from llama_index.schema import BaseNode, MetadataMode
 | |
| from llama_index.vector_stores import ChromaVectorStore
 | |
| from llama_index.vector_stores.chroma import chunk_list
 | |
| from llama_index.vector_stores.utils import node_to_metadata_dict
 | |
| 
 | |
| 
 | |
class BatchedChromaVectorStore(ChromaVectorStore):
    """ChromaVectorStore variant that inserts nodes in size-limited batches.

    ``add`` splits the incoming nodes into chunks no larger than the
    ``max_batch_size`` reported by the Chroma client and issues one
    ``collection.add`` call per chunk, so a single insertion never exceeds
    that limit. Everything else is inherited from ``ChromaVectorStore``.

    Args:
        chroma_client (from chromadb.api.API):
            API instance; its ``max_batch_size`` is read on every ``add``.
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
        host, port, ssl, headers, collection_kwargs:
            Forwarded unchanged to ``ChromaVectorStore.__init__``
            (``collection_kwargs`` falls back to an empty dict).

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        # The parent gets everything except the client, which is kept here
        # so that ``add`` can look up the batch-size limit.
        parent_kwargs: dict[str, Any] = {
            "chroma_collection": chroma_collection,
            "host": host,
            "port": port,
            "ssl": ssl,
            "headers": headers,
            "collection_kwargs": collection_kwargs or {},
        }
        super().__init__(**parent_kwargs)
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert ``nodes`` into the collection, one batch at a time.

        Args:
            nodes: nodes to store; each one is asked for its embedding.
            add_kwargs: accepted for interface compatibility, not used.

        Returns:
            The ``node_id`` of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection is falsy.
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")
        if not self._collection:
            raise ValueError("Collection not initialized")

        batch_limit = self.chroma_client.max_batch_size

        inserted_ids = []
        for batch in chunk_list(nodes, batch_limit):
            # One tuple per node: (embedding, metadata, id, document text).
            # Building all four fields together keeps the per-node call order.
            rows = [
                (
                    node.get_embedding(),
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    ),
                    node.node_id,
                    node.get_content(metadata_mode=MetadataMode.NONE),
                )
                for node in batch
            ]
            batch_ids = [row[2] for row in rows]

            self._collection.add(
                embeddings=[row[0] for row in rows],
                ids=batch_ids,
                metadatas=[row[1] for row in rows],
                documents=[row[3] for row in rows],
            )
            inserted_ids.extend(batch_ids)

        return inserted_ids
 |