424 changes: 282 additions & 142 deletions core/database/postgres_database.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion core/embedding/colpali_api_embedding_model.py
@@ -264,7 +264,7 @@ async def _call_api_endpoint(self, endpoint: str, inputs: List[str], input_type:
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
payload = {"input_type": input_type, "inputs": inputs}
timeout = Timeout(read=6000.0, connect=6000.0, write=6000.0, pool=6000.0)
timeout = Timeout(read=600.0, connect=30.0, write=600.0, pool=60.0)

async with AsyncClient(timeout=timeout) as client:
resp = await client.post(endpoint, json=payload, headers=headers)
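
Note on the timeout change: the old values waited 100 minutes on every phase, including connection setup; the new ones keep a long read/write budget for large embedding batches but fail fast on connect and pool acquisition. A minimal, self-contained sketch of the same httpx pattern (the endpoint URL, payload, and key here are placeholders, not Morphik's real API):

```python
import asyncio

from httpx import AsyncClient, Timeout


async def call_embedding_endpoint() -> None:
    # Generous read/write budget for big batches; strict connect/pool limits
    # so an unreachable server surfaces quickly instead of hanging.
    timeout = Timeout(connect=30.0, read=600.0, write=600.0, pool=60.0)
    async with AsyncClient(timeout=timeout) as client:
        resp = await client.post(
            "https://embeddings.example.com/v1/embed",  # hypothetical endpoint
            json={"input_type": "text", "inputs": ["hello world"]},
            headers={"Authorization": "Bearer <api-key>"},
        )
        resp.raise_for_status()
        print(resp.json())


if __name__ == "__main__":
    asyncio.run(call_embedding_endpoint())
```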
185 changes: 141 additions & 44 deletions core/services/document_service.py
@@ -4,7 +4,7 @@
import time
from datetime import UTC, datetime
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Set, Tuple, Type, Union, overload

import fitz # PyMuPDF for PDF/presentation processing
from fastapi import HTTPException
@@ -50,6 +50,8 @@ class DocumentService:
Note: Ingestion operations have been moved to IngestionService.
"""

_PADDING_METADATA_KEY = "__morphik_padding"

def __init__(
self,
database: PostgresDatabase,
@@ -135,6 +137,48 @@ def _depth(path: str) -> int:
filters["folder_path"] = exact_paths if len(exact_paths) > 1 else exact_paths[0]
return filters

@overload
async def retrieve_chunks(
self,
query: Optional[str],
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
k: int = 5,
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[Union[str, List[str]]] = None,
folder_depth: Optional[int] = None,
end_user_id: Optional[str] = None,
perf_tracker: Optional[Any] = None,
padding: int = 0,
output_format: str = "base64",
query_image: Optional[bytes] = None,
*,
return_preloaded_docs: Literal[True],
) -> Tuple[List[ChunkResult], Dict[str, Document]]: ...

@overload
async def retrieve_chunks(
self,
query: Optional[str],
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
k: int = 5,
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[Union[str, List[str]]] = None,
folder_depth: Optional[int] = None,
end_user_id: Optional[str] = None,
perf_tracker: Optional[Any] = None,
padding: int = 0,
output_format: str = "base64",
query_image: Optional[bytes] = None,
*,
return_preloaded_docs: Literal[False] = False,
) -> List[ChunkResult]: ...

async def retrieve_chunks(
self,
query: Optional[str],
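
The two `@overload` stubs above exist so that type checkers can tell the return type apart from the keyword-only `return_preloaded_docs` flag: `Literal[True]` maps to a `(chunks, doc_map)` tuple, `Literal[False]` to the plain list. A self-contained sketch of the pattern with illustrative names:

```python
from typing import Dict, List, Literal, Tuple, Union, overload


@overload
def fetch(ids: List[str], *, return_cache: Literal[True]) -> Tuple[List[str], Dict[str, str]]: ...


@overload
def fetch(ids: List[str], *, return_cache: Literal[False] = False) -> List[str]: ...


def fetch(ids: List[str], *, return_cache: bool = False) -> Union[List[str], Tuple[List[str], Dict[str, str]]]:
    # Stand-in for a database lookup.
    cache = {i: f"doc-{i}" for i in ids}
    results = list(cache.values())
    if return_cache:
        return results, cache
    return results


docs = fetch(["a", "b"])                             # checker sees List[str]
docs, cache = fetch(["a", "b"], return_cache=True)   # checker sees the tuple
```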
@@ -151,7 +195,8 @@ async def retrieve_chunks(
padding: int = 0, # Number of additional chunks to retrieve before and after matched chunks
output_format: str = "base64",
query_image: Optional[bytes] = None, # Base64-decoded image bytes for visual search
) -> List[ChunkResult]:
return_preloaded_docs: bool = False,
) -> Union[List[ChunkResult], Tuple[List[ChunkResult], Dict[str, Document]]]:
"""Retrieve relevant chunks.

Either query (text) or query_image (image bytes) must be provided.
@@ -469,7 +514,17 @@ async def timed_auth():
else:
result_creation_start = time.time()

results = await self._create_chunk_results(auth, chunks, output_format=output_format_value)
chunk_result_payload = await self._create_chunk_results(
auth,
chunks,
output_format=output_format_value,
return_doc_map=return_preloaded_docs,
)
retrieved_doc_map: Dict[str, Document] = {}
if return_preloaded_docs:
results, retrieved_doc_map = chunk_result_payload
else:
results = chunk_result_payload

if not perf_tracker:
phase_times["result_creation"] = time.time() - result_creation_start
@@ -496,6 +551,8 @@ async def timed_auth():
logger.info(f"Returning {len(results)} chunk results")
logger.info("==========================================================")

if return_preloaded_docs:
return results, retrieved_doc_map
return results

async def _apply_padding_to_chunks(
@@ -619,7 +676,11 @@ def _is_image_chunk(chunk: DocumentChunk) -> bool:
key = (chunk.document_id, chunk.chunk_number)
if key in seen:
continue
is_padding_chunk = key not in original_scores
chunk.score = original_scores.get(key, 0.0)
chunk_metadata = dict(chunk.metadata or {})
chunk_metadata[self._PADDING_METADATA_KEY] = is_padding_chunk
chunk.metadata = chunk_metadata
deduped.append(chunk)
seen.add(key)

@@ -728,8 +789,11 @@ async def retrieve_chunks_grouped(

Returns both flat results (for backward compatibility) and grouped results (for UI).
"""
# Get original chunks before padding (as ChunkResult objects)
original_chunk_results = await self.retrieve_chunks(
requested_padding = padding if padding > 0 and use_colpali else 0

# Single retrieval call: when padding is requested, retrieve once with padding
# and rely on `is_padding` markers assigned in `_apply_padding_to_chunks`.
final_chunk_results = await self.retrieve_chunks(
query,
auth,
filters,
@@ -741,31 +805,15 @@
folder_depth,
end_user_id,
perf_tracker,
padding=0, # No padding for original
padding=requested_padding,
output_format=output_format,
query_image=query_image,
)

# Get final chunks with padding (as ChunkResult objects)
if padding > 0 and use_colpali:
final_chunk_results = await self.retrieve_chunks(
query,
auth,
filters,
k,
min_score,
use_reranking,
use_colpali,
folder_name,
folder_depth,
end_user_id,
perf_tracker,
padding,
output_format=output_format,
query_image=query_image,
)
if requested_padding > 0:
original_chunk_results = [result for result in final_chunk_results if not result.is_padding]
else:
final_chunk_results = original_chunk_results
original_chunk_results = final_chunk_results

# Create grouped response directly from ChunkResult objects
return await self._create_grouped_chunk_response_from_results(
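
Because each `ChunkResult` now carries `is_padding`, `retrieve_chunks_grouped` makes a single retrieval call with padding and recovers the original matches by filtering, instead of issuing a second padding-free retrieval as before. In sketch form (assuming only that results expose an `is_padding` attribute):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Result:
    chunk_number: int
    is_padding: bool = False


def original_matches(padded_results: List[Result]) -> List[Result]:
    """Drop the padding neighbours to recover the chunks that actually matched."""
    return [r for r in padded_results if not r.is_padding]


padded = [Result(1, is_padding=True), Result(2), Result(3, is_padding=True)]
print([r.chunk_number for r in original_matches(padded)])  # [2]
```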
@@ -787,11 +835,21 @@ async def retrieve_docs(
) -> List[DocumentResult]:
"""Retrieve relevant documents."""
# Get chunks first
chunks = await self.retrieve_chunks(
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, folder_depth, end_user_id
chunks, preloaded_docs = await self.retrieve_chunks(
query,
auth,
filters,
k,
min_score,
use_reranking,
use_colpali,
folder_name,
folder_depth,
end_user_id,
return_preloaded_docs=True,
)
# Convert to document results
results = await self._create_document_results(auth, chunks)
results = await self._create_document_results(auth, chunks, preloaded_docs=preloaded_docs)
documents = list(results.values())
logger.info(f"Returning {len(documents)} document results")
return documents
@@ -1022,7 +1080,7 @@ async def query(
else:
chunk_retrieval_start = time.time()

chunks = await self.retrieve_chunks(
chunk_results, preloaded_docs = await self.retrieve_chunks(
query,
auth,
filters,
@@ -1035,6 +1093,7 @@
end_user_id,
perf_tracker,
padding,
return_preloaded_docs=True,
)

if not perf_tracker:
@@ -1046,7 +1105,7 @@
else:
doc_results_start = time.time()

documents = await self._create_document_results(auth, chunks)
documents = await self._create_document_results(auth, chunk_results, preloaded_docs=preloaded_docs)

if not perf_tracker:
phase_times["document_results_creation"] = time.time() - doc_results_start
@@ -1057,13 +1116,13 @@
else:
augmentation_start = time.time()

chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]
chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunk_results]

# Collect chunk metadata for inline citations if enabled
chunk_metadata = None
if inline_citations:
chunk_metadata = []
for chunk in chunks:
for chunk in chunk_results:
# Get the document for this chunk
doc = documents.get(chunk.document_id, {})
filename = (
@@ -1101,7 +1160,7 @@

sources = [
ChunkSource(document_id=chunk.document_id, chunk_number=chunk.chunk_number, score=chunk.score)
for chunk in chunks
for chunk in chunk_results
]

if not perf_tracker:
Expand Down Expand Up @@ -1174,17 +1233,42 @@ async def query(

return response

@overload
async def _create_chunk_results(
self,
auth: AuthContext,
chunks: List[DocumentChunk],
preloaded_docs: Optional[Dict[str, Document]] = None,
output_format: str = "base64",
) -> List[ChunkResult]:
*,
return_doc_map: Literal[True],
) -> Tuple[List[ChunkResult], Dict[str, Document]]: ...

@overload
async def _create_chunk_results(
self,
auth: AuthContext,
chunks: List[DocumentChunk],
preloaded_docs: Optional[Dict[str, Document]] = None,
output_format: str = "base64",
*,
return_doc_map: Literal[False] = False,
) -> List[ChunkResult]: ...

async def _create_chunk_results(
self,
auth: AuthContext,
chunks: List[DocumentChunk],
preloaded_docs: Optional[Dict[str, Document]] = None,
output_format: str = "base64",
return_doc_map: bool = False,
) -> Union[List[ChunkResult], Tuple[List[ChunkResult], Dict[str, Document]]]:
"""Create ChunkResult objects with document metadata."""
results = []
if not chunks:
logger.info("No chunks provided, returning empty results")
if return_doc_map:
return results, {}
return results

# Collect all unique document IDs from chunks
@@ -1413,19 +1497,22 @@ async def _convert_image_to_text(content_str: str) -> str:
logger.warning(f"Document {chunk.document_id} not found")
continue

chunk_metadata = dict(chunk.metadata or {})
is_padding = bool(chunk_metadata.pop(self._PADDING_METADATA_KEY, False))

# Start with document metadata, then merge in chunk-specific metadata
metadata = doc.metadata.copy()
# Add all chunk metadata (this includes our XML metadata like unit, xml_id, breadcrumbs, etc.)
metadata.update(chunk.metadata)
metadata.update(chunk_metadata)
# Ensure is_image is set (fallback to False if not present)
metadata["is_image"] = chunk.metadata.get("is_image", False)
metadata["is_image"] = chunk_metadata.get("is_image", False)
# Default values
content_value = chunk.content
download_url: Optional[str] = None

# If requested, convert image chunks to presigned URLs or text
is_img = bool(metadata.get("is_image"))
mime = chunk.metadata.get("mime_type") if isinstance(chunk.metadata, dict) else None
mime = chunk_metadata.get("mime_type") if isinstance(chunk_metadata, dict) else None
# Try to infer from content if metadata was not properly populated
if not is_img and (output_format or "base64") in ("url", "text"):
inferred_mime = _infer_image_mime_from_content(chunk.content)
@@ -1623,13 +1710,21 @@ def _is_binary_image(b: bytes) -> bool:
content_type=doc.content_type,
filename=doc.filename,
download_url=download_url,
is_padding=is_padding,
)
)

logger.info(f"Created {len(results)} chunk results")
if return_doc_map:
return results, doc_map
return results

async def _create_document_results(self, auth: AuthContext, chunks: List[ChunkResult]) -> Dict[str, DocumentResult]:
async def _create_document_results(
self,
auth: AuthContext,
chunks: List[ChunkResult],
preloaded_docs: Optional[Dict[str, Document]] = None,
) -> Dict[str, DocumentResult]:
"""Group chunks by document and create DocumentResult objects."""
if not chunks:
logger.info("No chunks provided, returning empty results")
Expand All @@ -1645,12 +1740,14 @@ async def _create_document_results(self, auth: AuthContext, chunks: List[ChunkRe
# Get unique document IDs
unique_doc_ids = list(doc_chunks.keys())

# Fetch all documents in a single batch query
docs = await self.batch_retrieve_documents(unique_doc_ids, auth)

# Create a lookup dictionary of documents by ID
doc_map = {doc.external_id: doc for doc in docs}
logger.debug(f"Retrieved metadata for {len(doc_map)} unique documents in a single batch")
doc_map: Dict[str, Document] = dict(preloaded_docs) if preloaded_docs else {}
missing_doc_ids = [doc_id for doc_id in unique_doc_ids if doc_id not in doc_map]
if missing_doc_ids:
docs = await self.batch_retrieve_documents(missing_doc_ids, auth)
doc_map.update({doc.external_id: doc for doc in docs})
logger.debug(f"Retrieved metadata for {len(docs)} additional documents in a single batch")
else:
logger.debug(f"Using preloaded metadata for {len(doc_map)} unique documents")

# Create document results using the lookup dictionaries
results = {}
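
`_create_document_results` now reuses the `Document` objects that `retrieve_chunks` already fetched and only issues a batch query for IDs that are still missing, rather than re-fetching every document a second time per request. A sketch of the cache-then-fetch-missing pattern (the `batch_fetch` callable is a stand-in for the real database call):

```python
from typing import Awaitable, Callable, Dict, List, Optional


async def load_documents(
    doc_ids: List[str],
    preloaded: Optional[Dict[str, dict]] = None,
    batch_fetch: Optional[Callable[[List[str]], Awaitable[List[dict]]]] = None,
) -> Dict[str, dict]:
    """Reuse preloaded documents; hit the database only for missing IDs."""
    doc_map: Dict[str, dict] = dict(preloaded) if preloaded else {}
    missing = [doc_id for doc_id in doc_ids if doc_id not in doc_map]
    if missing and batch_fetch is not None:
        for doc in await batch_fetch(missing):
            doc_map[doc["external_id"]] = doc
    return doc_map
```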