"""
Example is based on:
- https://huggingface.co/Snowflake/snowflake-arctic-embed-xs
- https://docs.ray.io/en/latest/serve/getting_started.html
"""
import json
import sys
from time import time
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from langchain_text_splitters import CharacterTextSplitter
from ray import serve
from starlette.requests import Request
from starlette.responses import JSONResponse
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
HUB_MODEL_NAME = "Snowflake/snowflake-arctic-embed-xs"
# Truncate to 150k characters to avoid timeout errors from the model,
# and cap output at 10 chunks to protect the Fusion document count
CHAR_TRUNCATION = 150_000
MAX_CHUNKS = 10
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 1})
class Deployment:
def __init__(self):
from loguru import logger
# Initializing logger
self.logger = logger
self.logger.remove()
self.logger.add(sys.stdout, level="INFO", serialize=False, colorize=True)
# Initializing model
self.logger.info("Loading model...")
self.tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_NAME)
self.model = AutoModel.from_pretrained(HUB_MODEL_NAME)
self.model.eval()
        self.vector_size = 384  # Embedding dimension of snowflake-arctic-embed-xs
self.logger.info("Model initialization finished!")
# snowflake-arctic-embed-xs specific prefix (NOT for other models)
self.query_prefix = "Represent this sentence for searching relevant passages:"
self.passage_prefix = ""
    def main(self, input_dict: Dict[str, Any]) -> Any:
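        """Chunk `text`, embed each chunk, and return {"chunkedData": <JSON string>}.

        Recognized input keys: `text` (required), `dataType` ("passage" or
        "query"), `quantize` (int8-quantize the vectors), and
        `include_text_chunks` (echo the chunk strings in the output).
        Error paths return a starlette `JSONResponse` instead.
        """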
try:
timings_dict = {}
_start_time = time()
# Extracting text from input
text = input_dict.get("text", None)
include_text_chunks = self.obj_to_bool(
input_dict.get("include_text_chunks", False)
)
if text is None: # Check if text is None and return
self.logger.error("No text provided in the input dictionary.")
return JSONResponse(
{"error": "No `text`input key provided"}, status_code=400
)
            elif not text.strip():  # Check if the text is empty and return
self.logger.warning("Empty text provided. Returning empty output.")
return self.create_output_dict(
spans=[],
vectors=[],
chunks=[],
include_text_chunks=include_text_chunks,
)
elif len(text) > CHAR_TRUNCATION: # Check if the text is too long
text = text[:CHAR_TRUNCATION]
self.logger.warning(
f"Input text truncated to {CHAR_TRUNCATION} characters"
f" to avoid Fusion stage timeout errors."
)
# Log the length of the text
self.logger.debug(f"Text len: {len(text)}")
# Check DataType
dataType = input_dict.get("dataType", "passage")
# Check if quantization is set
quantize_bool = self.obj_to_bool(input_dict.get("quantize", False))
            # Chunk the text
chunks = self.chunk_text(text)
initial_chunk_count = len(chunks)
            # Cap the number of chunks to avoid exploding the Fusion document count
if initial_chunk_count > MAX_CHUNKS:
chunks = chunks[:MAX_CHUNKS]
self.logger.warning(
f"Input text chunked into {initial_chunk_count} chunks. "
f"Only the first {MAX_CHUNKS} chunks will be processed."
)
timings_dict["chunking_time"] = time()
spans = self.get_chunk_spans(text=text, chunks=chunks)
timings_dict["span_extraction_time"] = time()
# encoding the chunks
vectors = self.encode_chunks(chunks, dataType=dataType)
timings_dict["encoding_time"] = time()
# Check if quantization is set
if quantize_bool:
vectors = self.quantize_vectors(embeddings=vectors)
timings_dict["quantization_time"] = time()
# Create output dictionary
output_dict = self.create_output_dict(
spans=spans,
vectors=vectors,
chunks=chunks,
include_text_chunks=include_text_chunks,
)
self.log_response_timings(
action_name="chunk_via_ray",
start_time=_start_time,
timings_dict=timings_dict,
)
return output_dict
except Exception as e:
self.logger.error(f"An error occurred: {e}")
return JSONResponse({"error": str(e)}, status_code=500)
def encode_chunks(self, chunks: List[str], dataType: str) -> Tensor:
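        """Tokenize and embed chunks, applying the model-card prefix for queries."""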
# Add prefix to the chunks based on the dataType
if dataType == "passage":
chunks = [f"{self.passage_prefix} {text}".strip() for text in chunks]
elif dataType == "query":
chunks = [f"{self.query_prefix} {text}".strip() for text in chunks]
# Tokenization
tokenized_texts = self.tokenizer(
chunks,
max_length=512,
padding=True,
truncation=True,
return_tensors="pt",
)
        # Encoding (model-specific; check the model documentation)
with torch.inference_mode():
# Vectorization
model_output = self.model(**tokenized_texts)
embeddings = self.cls_pooling(model_output.last_hidden_state)
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings
def create_output_dict(
self,
spans: List[List[int]],
vectors: Tensor,
        chunks: List[str],
include_text_chunks: bool = False,
) -> Dict[str, Any]:
vectors = self.format_vectors(vectors=vectors, chunks_len=len(chunks))
output_dict = {
"spans": spans,
"vectors": vectors,
}
if include_text_chunks:
output_dict["chunks"] = chunks
return {
"chunkedData": json.dumps(output_dict),
}
@staticmethod
def format_vectors(vectors: Tensor, chunks_len: int) -> List[Dict[str, Any]]:
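        """Convert an (n, dim) embedding tensor (or empty list) into {"vector": [...]} dicts."""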
        # Guard the empty-text path, where `vectors` is an empty list and
        # `.squeeze()` would fail below
        if chunks_len == 0:
            return []
        vectors = vectors if chunks_len > 1 else [vectors]
# Convert the tensor to a list of dictionaries
formatted_vectors = []
for vector in vectors:
formatted_vectors.append({"vector": vector.squeeze().tolist()})
return formatted_vectors
@staticmethod
def cls_pooling(encoded: torch.Tensor) -> torch.Tensor:
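        """CLS pooling: use the hidden state of the first ([CLS]) token as the embedding."""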
return encoded[:, 0, :]
@staticmethod
def chunk_text(text: str) -> List[str]:
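        """Split on paragraph breaks when present, else single newlines, into ~512-char chunks."""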
separator = "\n\n" if text.count("\n\n") > 0 else "\n"
# Initialize the text splitter
text_splitter = CharacterTextSplitter(
separator=separator,
chunk_size=512,
chunk_overlap=0,
length_function=len,
)
# Split the text into chunks
chunks = text_splitter.split_text(text)
return chunks
@staticmethod
def get_chunk_spans(text: str, chunks: List[str]) -> List[List[int]]:
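        """Return [start, end) character offsets of each chunk within `text`."""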
        # Search from a moving offset so repeated chunks map to successive
        # occurrences instead of all resolving to the first match
        spans = []
        offset = 0
        for chunk in chunks:
            start = text.find(chunk, offset)
            if start == -1:  # Fall back to searching the whole text
                start = text.find(chunk)
            end = start + len(chunk)
            spans.append([start, end])
            offset = end
        return spans
@staticmethod
def quantize_vectors(embeddings: Tensor) -> Tensor:
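        """Min-max scale each vector to [0, 1], then map it onto the signed int8 range."""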
min_val = torch.min(embeddings, dim=1, keepdim=True).values
max_val = torch.max(embeddings, dim=1, keepdim=True).values
scale = (max_val - min_val).clamp(min=1e-8)
normalized = (embeddings - min_val) / scale
        # Map [0, 1] onto [-128, 127] so values fit a signed byte
quantized = normalized * 255 - 128
quantized = torch.round(quantized).clamp(-128, 127).to(torch.int8)
return quantized
@staticmethod
def obj_to_bool(s: Any) -> bool:
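        """Coerce bools, ints, and the string "true" (case-insensitive) to bool."""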
if isinstance(s, bool):
return s
elif isinstance(s, int):
return s != 0
elif isinstance(s, str):
            return s.strip().lower() == "true"
else:
return False
def log_response_timings(
self,
action_name: str,
start_time: float,
timings_dict: Optional[Dict[str, float]] = None,
) -> None:
timings_str = (
f"Time taken to {action_name} input: {(time() - start_time) * 1000:.1f}ms"
)
if timings_dict is not None:
timings_dict_str = {}
previous_time = start_time
for k, v in timings_dict.items():
timings_dict_str[k] = f"{(v - previous_time) * 1000:.1f}ms"
previous_time = v
timings_str += f" {timings_dict_str}"
self.logger.info(timings_str)
    async def __call__(self, http_request: Request) -> Any:
        try:
            input_dict: Dict[str, Any] = await http_request.json()
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Retry with a lenient decode so malformed bytes are replaced
            # rather than bubbling up as an unhandled 500
            body_bytes = await http_request.body()
            try:
                decoded = body_bytes.decode("utf-8", errors="replace")
                input_dict = json.loads(decoded)
            except json.JSONDecodeError:
                return JSONResponse({"error": "Invalid JSON"}, status_code=400)
return self.main(input_dict=input_dict)
app = Deployment.bind()
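
# --- Local smoke test (illustrative sketch; assumes a local Ray cluster, the
# Ray Serve default HTTP address http://localhost:8000/, and that the
# `requests` package is installed) ---
if __name__ == "__main__":
    import requests

    serve.run(app)
    payload = {
        "text": "First paragraph.\n\nSecond paragraph.",
        "dataType": "passage",
        "quantize": False,
        "include_text_chunks": True,
    }
    response = requests.post("http://localhost:8000/", json=payload)
    # `chunkedData` is a JSON string holding `spans`, `vectors`, and
    # (when requested) `chunks`
    print(json.loads(response.json()["chunkedData"])["spans"])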