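"""ChatService: retrieval-augmented and plain chat built on LlamaIndex.

Combines the injected LLM, embedding model, vector store, and node store
into a VectorStoreIndex and exposes blocking (`chat`) and streaming
(`stream_chat`) entry points that can optionally ground answers in
retrieved context.
"""
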
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from injector import inject, singleton
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.chat_engine import ContextChatEngine
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
from llama_index.llms import ChatMessage
from llama_index.types import TokenGen

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter

if TYPE_CHECKING:
    from llama_index.chat_engine.types import (
        AgentChatResponse,
        StreamingAgentChatResponse,
    )


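# A single ChatService instance is shared process-wide; `injector` supplies
# the LLM, embedding, vector-store, and node-store components at construction.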
@singleton
class ChatService:
    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.llm_service = llm_component
        self.vector_store_component = vector_store_component
        # Storage context: the vector store plus the doc/index stores that
        # back the previously ingested nodes.
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        # Service context: the LLM and embedding model used at query time.
        self.service_context = ServiceContext.from_defaults(
            llm=llm_component.llm, embed_model=embedding_component.embedding_model
        )
        # Attach to the already-populated vector store; nothing is re-ingested.
        self.index = VectorStoreIndex.from_vector_store(
            vector_store_component.vector_store,
            storage_context=self.storage_context,
            service_context=self.service_context,
            show_progress=True,
        )

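    # Shared retrieval-augmented path used by both chat() and stream_chat().
    # The MetadataReplacementPostProcessor swaps each retrieved sentence for
    # the larger "window" stored in its metadata (sentence-window retrieval).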
    def _chat_with_context(
        self,
        message: str,
        context_filter: ContextFilter | None = None,
        chat_history: Sequence[ChatMessage] | None = None,
        streaming: bool = False,
    ) -> Any:
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=self.index, context_filter=context_filter
        )
        chat_engine = ContextChatEngine.from_defaults(
            retriever=vector_index_retriever,
            service_context=self.service_context,
            node_postprocessors=[
                MetadataReplacementPostProcessor(target_metadata_key="window"),
            ],
        )
        if streaming:
            result = chat_engine.stream_chat(message, chat_history)
        else:
            result = chat_engine.chat(message, chat_history)
        return result

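    # Streaming counterpart of chat(): returns a token generator instead of
    # waiting for the full completion.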
    def stream_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> TokenGen:
        if use_context:
            # The last message is the query; everything before it is history.
            last_message = messages[-1].content
            response: StreamingAgentChatResponse = self._chat_with_context(
                message=last_message if last_message is not None else "",
                chat_history=messages[:-1],
                context_filter=context_filter,
                streaming=True,
            )
            response_gen = response.response_gen
        else:
            stream = self.llm_service.llm.stream_chat(messages)
            response_gen = stream_chat_response_to_tokens(stream)
        return response_gen

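    # Blocking variant: returns the full completion as a single string.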
    def chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> str:
        if use_context:
            last_message = messages[-1].content
            wrapped_response: AgentChatResponse = self._chat_with_context(
                message=last_message if last_message is not None else "",
                chat_history=messages[:-1],
                context_filter=context_filter,
                streaming=False,
            )
            response = wrapped_response.response
        else:
            chat_response = self.llm_service.llm.chat(messages)
            response_content = chat_response.message.content
            response = response_content if response_content is not None else ""
        return response
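
# Example usage (hypothetical sketch; in private-gpt the instance is resolved
# through the injector container rather than constructed by hand):
#
#     from injector import Injector
#     from llama_index.llms import ChatMessage, MessageRole
#
#     injector = Injector([...])  # app bindings for the four components
#     chat_service = injector.get(ChatService)
#     answer = chat_service.chat(
#         messages=[ChatMessage(role=MessageRole.USER, content="Hi!")],
#         use_context=True,
#     )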