36 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import Literal
 | |
| 
 | |
| from fastapi import APIRouter, Depends, Request
 | |
| from pydantic import BaseModel
 | |
| 
 | |
| from private_gpt.server.embeddings.embeddings_service import (
 | |
|     Embedding,
 | |
|     EmbeddingsService,
 | |
| )
 | |
| from private_gpt.server.utils.auth import authenticated
 | |
| 
 | |
# Router for the /v1/embeddings endpoint; every route on it requires a
# successfully authenticated request (enforced via the `authenticated` dependency).
embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
| 
 | |
| 
 | |
class EmbeddingsBody(BaseModel):
    """Request payload for the embeddings endpoint."""

    # Either a single text or a batch of texts to embed.
    input: str | list[str]
| 
 | |
| 
 | |
class EmbeddingsResponse(BaseModel):
    """Response for the embeddings endpoint, mirroring the OpenAI-style shape."""

    # Fixed discriminator; always the literal string "list".
    object: Literal["list"]
    # Model identifier; fixed to "private-gpt".
    model: Literal["private-gpt"]
    # One Embedding entry per input text, in input order — presumably; order is
    # determined by EmbeddingsService.texts_embeddings (verify there).
    data: list[Embedding]
| 
 | |
| 
 | |
@embeddings_router.post("/embeddings", tags=["Embeddings"])
def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
    """Get a vector representation of a given input.

    That vector representation can be easily consumed
    by machine learning models and algorithms.
    """
    # Resolve the embeddings service from the app's dependency injector.
    service = request.state.injector.get(EmbeddingsService)

    # The API accepts either a bare string or a batch; normalize to a list
    # so the service always receives the same shape.
    if isinstance(body.input, str):
        texts = [body.input]
    else:
        texts = body.input

    return EmbeddingsResponse(
        object="list",
        model="private-gpt",
        data=service.texts_embeddings(texts),
    )