Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- feat(example): use MTMD batch encoding by @abetlen in #2301
- feat(example): support server video inputs and Gemma text tool calls by @abetlen in #2291
- feat: update llama.cpp to ggml-org/llama.cpp@f05cf4676
- fix(example): support multi-step Responses tool streaming by @abetlen in #2288
Expand Down
10 changes: 7 additions & 3 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha

## Multimodal `model.mtmd`

`model.mtmd` loads a llama.cpp multimodal projector and enables OpenAI-style image and audio content parts.
`model.mtmd` loads a llama.cpp multimodal projector and enables OpenAI-style image, audio, and video content parts.

```json
{
Expand All @@ -305,8 +305,10 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha
"path": ".cache/mtmd-embeddings",
"max_bytes": 1073741824
},
"batch_max_tokens": 1024,
"image_max_bytes": 20971520,
"audio_max_bytes": 104857600,
"video_max_bytes": 536870912,
"image_timeout_seconds": 10.0
}
}
Expand All @@ -317,11 +319,13 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha
| --- | --- |
| `mmproj_path` | Local multimodal projector path. |
| `mmproj_from_pretrained` | Hugging Face projector source. |
| `embedding_cache.path` | Directory for cached image and audio embeddings. |
| `embedding_cache.path` | Directory for cached image, audio, and video embeddings. |
| `embedding_cache.max_bytes` | Maximum embedding cache size. |
| `batch_max_tokens` | Maximum number of media output tokens per MTMD projector-side encode batch. |
| `image_max_bytes` | Maximum image payload size. |
| `audio_max_bytes` | Maximum audio payload size. |
| `image_timeout_seconds` | Timeout for remote image and audio URL fetches. |
| `video_max_bytes` | Maximum video payload size. |
| `image_timeout_seconds` | Timeout for remote image, audio, and video URL fetches. |

Send image inputs with OpenAI chat content parts.

Expand Down
246 changes: 174 additions & 72 deletions examples/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3223,6 +3223,7 @@ class MTMDOptions(BaseModel):
embedding_cache: Optional["ConfigFile.MTMDEmbeddingCacheOptions"] = None
allowed_media_domains: Optional[List[str]] = None
allowed_local_media_path: Optional[str] = None
batch_max_tokens: int = Field(default=1024, ge=1)
image_max_bytes: int = Field(default=20 * 1024 * 1024, ge=1)
audio_max_bytes: int = Field(default=100 * 1024 * 1024, ge=1)
video_max_bytes: int = Field(default=512 * 1024 * 1024, ge=1)
Expand Down Expand Up @@ -10410,6 +10411,21 @@ class MTMDLoadedMedia:


class MTMDProcessor:
@dataclass
class MediaChunk:
kind: Literal["image", "audio", "video"]
key: str
chunk: Any
n_tokens: int
decode_n_pos: int
non_causal: bool
embeddings: Optional[np.ndarray] = None

@dataclass
class ParsedChunk:
text_tokens: Optional[List[int]] = None
media: Optional["MTMDProcessor.MediaChunk"] = None

def __init__(
self,
*,
Expand All @@ -10422,6 +10438,7 @@ def __init__(
n_ubatch: int,
n_threads_batch: int,
mmproj_path: str,
batch_max_tokens: int,
embedding_cache: Optional[MTMDEmbeddingCache],
allowed_media_domains: Optional[List[str]],
allowed_local_media_path: Optional[str],
Expand All @@ -10437,6 +10454,7 @@ def __init__(
self.n_ubatch = n_ubatch
self.mmproj_path = mmproj_path
self.embedding_cache = embedding_cache
self.batch_max_tokens = batch_max_tokens
self.model_fingerprint = MTMDEmbeddingCache.fingerprint_file(model_path)
self.mmproj_fingerprint = MTMDEmbeddingCache.fingerprint_file(mmproj_path)
self.allowed_media_domains = (
Expand All @@ -10456,6 +10474,7 @@ def __init__(
self.lock = threading.Lock()
params = mtmd_cpp.mtmd_context_params_default()
params.n_threads = max(1, n_threads_batch)
params.batch_max_tokens = batch_max_tokens
self.ctx = mtmd_cpp.mtmd_init_from_file(
mmproj_path.encode("utf-8"),
llama_model,
Expand Down Expand Up @@ -10705,37 +10724,91 @@ def _media_identity_tokens(
tokens.append(-1 - (int.from_bytes(digest[:4], "little") & 0x3FFFFFFF))
return tokens

def _encode_media_chunk(
self,
*,
kind: Literal["image", "audio", "video"],
key: str,
chunk: Any,
) -> np.ndarray:
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
if self.embedding_cache is not None:
cached = self.embedding_cache.load(key)
if (
cached is not None
and cached.embeddings.shape == (n_tokens, self.n_embd_inp)
):
return cached.embeddings
result = int(mtmd_cpp.mtmd_encode_chunk(self.ctx, chunk))
if result != 0:
raise CompletionRequestValidationError(
f"failed to encode {kind} chunk: error code {result}"
)
output = mtmd_cpp.mtmd_get_output_embd(self.ctx)
if output is None:
raise CompletionRequestValidationError(f"MTMD {kind} encoder returned no embeddings")
def _embeddings_from_pointer(self, output: Any, n_tokens: int) -> np.ndarray:
flat = np.ctypeslib.as_array(output, shape=(n_tokens * self.n_embd_inp,))
embeddings = np.array(flat, dtype=np.float32, copy=True).reshape(
return np.array(flat, dtype=np.float32, copy=True).reshape(
n_tokens,
self.n_embd_inp,
)
if self.embedding_cache is not None:
self.embedding_cache.save(key, embeddings)
return embeddings

def _load_cached_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> bool:
if self.embedding_cache is None:
return False
cached = self.embedding_cache.load(media_chunk.key)
if cached is None or cached.embeddings.shape != (
media_chunk.n_tokens,
self.n_embd_inp,
):
return False
media_chunk.embeddings = cached.embeddings
return True

def _save_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> None:
if self.embedding_cache is None or media_chunk.embeddings is None:
return
self.embedding_cache.save(media_chunk.key, media_chunk.embeddings)

def _encode_media_batch(
self,
media_chunks: Sequence["MTMDProcessor.MediaChunk"],
start_index: int,
) -> int:
batch = mtmd_cpp.mtmd_batch_init(self.ctx)
if batch is None:
raise CompletionRequestValidationError("failed to create MTMD media batch")
try:
first = media_chunks[start_index]
result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, first.chunk))
if result != 0:
raise CompletionRequestValidationError(
f"failed to add {first.kind} chunk to MTMD batch: error code {result}"
)
group = [first]
next_index = start_index + 1
while next_index < len(media_chunks):
candidate = media_chunks[next_index]
result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, candidate.chunk))
if result == 0:
group.append(candidate)
next_index += 1
continue
if result in {2, 3}:
break
raise CompletionRequestValidationError(
f"failed to add {candidate.kind} chunk to MTMD batch: error code {result}"
)
result = int(mtmd_cpp.mtmd_batch_encode(batch))
if result != 0:
raise CompletionRequestValidationError(
f"failed to encode MTMD media batch: error code {result}"
)
for media_chunk in group:
output = mtmd_cpp.mtmd_batch_get_output_embd(batch, media_chunk.chunk)
if output is None:
raise CompletionRequestValidationError(
f"MTMD {media_chunk.kind} encoder returned no embeddings"
)
media_chunk.embeddings = self._embeddings_from_pointer(
output,
media_chunk.n_tokens,
)
self._save_media_chunk(media_chunk)
return len(group)
finally:
mtmd_cpp.mtmd_batch_free(batch)

def _encode_media_chunks(
self,
media_chunks: Sequence["MTMDProcessor.MediaChunk"],
) -> None:
uncached = [
media_chunk
for media_chunk in media_chunks
if not self._load_cached_media_chunk(media_chunk)
]
index = 0
while index < len(uncached):
index += self._encode_media_batch(uncached, index)

def _positions_for_chunk(self, chunk: Any, start_pos: int) -> np.ndarray:
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
Expand Down Expand Up @@ -10858,12 +10931,8 @@ def _build_prompt_plan_locked(
raise CompletionRequestValidationError(
f"failed to tokenize MTMD prompt: error code {result}"
)
segments: List[PromptSegment] = []
identity_tokens: List[int] = []
text_tokens: List[int] = []
text_token_index_by_pos: Dict[int, int] = {}
identity_pos = 0
decode_pos = 0
parsed_chunks: List[MTMDProcessor.ParsedChunk] = []
media_chunks: List[MTMDProcessor.MediaChunk] = []
video_index = 0
used_media_keys = set()
n_chunks = int(mtmd_cpp.mtmd_input_chunks_size(chunks))
Expand All @@ -10884,24 +10953,9 @@ def _build_prompt_plan_locked(
else []
)
if tokens:
start_pos = identity_pos
segments.append(
PromptSegment(
kind="text",
start_pos=start_pos,
n_pos=len(tokens),
identity_tokens=list(tokens),
decode_start_pos=decode_pos,
decode_n_pos=len(tokens),
text_tokens=list(tokens),
)
parsed_chunks.append(
MTMDProcessor.ParsedChunk(text_tokens=tokens)
)
for offset, token in enumerate(tokens):
text_token_index_by_pos[start_pos + offset] = len(text_tokens)
text_tokens.append(token)
identity_tokens.extend(tokens)
identity_pos += len(tokens)
decode_pos += len(tokens)
continue
if chunk_type == mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE:
chunk_kind: Literal["image", "audio"] = "image"
Expand Down Expand Up @@ -10951,37 +11005,84 @@ def _build_prompt_plan_locked(
decode_n_pos = int(mtmd_cpp.mtmd_input_chunk_get_n_pos(chunk))
if decode_n_pos <= 0:
raise CompletionRequestValidationError("MTMD media chunk has no decoder positions")
embeddings = self._encode_media_chunk(kind=kind, key=key, chunk=chunk)
n_tokens = int(embeddings.shape[0])
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
if n_tokens <= 0:
raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
raise CompletionRequestValidationError("MTMD media chunk has no embedding tokens")
non_causal = bool(mtmd_cpp.mtmd_decode_use_non_causal(self.ctx, chunk))
segment_identity = self._media_identity_tokens(kind, key, n_tokens)
positions = self._positions_for_chunk(chunk, decode_pos)
segment = PromptSegment(
media_chunk = MTMDProcessor.MediaChunk(
kind=kind,
start_pos=identity_pos,
n_pos=n_tokens,
identity_tokens=segment_identity,
decode_start_pos=decode_pos,
key=key,
chunk=chunk,
n_tokens=n_tokens,
decode_n_pos=decode_n_pos,
media=PromptSegment.Media(
embeddings=embeddings,
positions=positions,
non_causal=non_causal,
),
non_causal=non_causal,
)
if non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
parsed_chunks.append(MTMDProcessor.ParsedChunk(media=media_chunk))
media_chunks.append(media_chunk)
if used_media_keys != {media.key for media in loaded_media}:
raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
self._encode_media_chunks(media_chunks)
segments: List[PromptSegment] = []
identity_tokens: List[int] = []
text_tokens: List[int] = []
text_token_index_by_pos: Dict[int, int] = {}
identity_pos = 0
decode_pos = 0
for parsed_chunk in parsed_chunks:
if parsed_chunk.text_tokens is not None:
tokens = parsed_chunk.text_tokens
start_pos = identity_pos
segments.append(
PromptSegment(
kind="text",
start_pos=start_pos,
n_pos=len(tokens),
identity_tokens=list(tokens),
decode_start_pos=decode_pos,
decode_n_pos=len(tokens),
text_tokens=list(tokens),
)
)
for offset, token in enumerate(tokens):
text_token_index_by_pos[start_pos + offset] = len(text_tokens)
text_tokens.append(token)
identity_tokens.extend(tokens)
identity_pos += len(tokens)
decode_pos += len(tokens)
continue
media_chunk = parsed_chunk.media
if media_chunk is None or media_chunk.embeddings is None:
raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
embeddings = media_chunk.embeddings
if media_chunk.non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
raise CompletionRequestValidationError(
f"non-causal {kind} embedding chunk exceeds model batch limits; "
f"non-causal {media_chunk.kind} embedding chunk exceeds model batch limits; "
"increase n_batch and n_ubatch"
)
segments.append(segment)
segment_identity = self._media_identity_tokens(
media_chunk.kind,
media_chunk.key,
media_chunk.n_tokens,
)
positions = self._positions_for_chunk(media_chunk.chunk, decode_pos)
segments.append(
PromptSegment(
kind=media_chunk.kind,
start_pos=identity_pos,
n_pos=media_chunk.n_tokens,
identity_tokens=segment_identity,
decode_start_pos=decode_pos,
decode_n_pos=media_chunk.decode_n_pos,
media=PromptSegment.Media(
embeddings=embeddings,
positions=positions,
non_causal=media_chunk.non_causal,
),
)
)
identity_tokens.extend(segment_identity)
identity_pos += n_tokens
decode_pos += decode_n_pos
if used_media_keys != {media.key for media in loaded_media}:
raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
identity_pos += media_chunk.n_tokens
decode_pos += media_chunk.decode_n_pos
return PromptPlan(
text=prompt,
generation_prompt=generation_prompt,
Expand Down Expand Up @@ -16211,6 +16312,7 @@ def main() -> None:
n_ubatch=model.n_ubatch,
n_threads_batch=model.n_threads_batch,
mmproj_path=mmproj_path,
batch_max_tokens=config.model.mtmd.batch_max_tokens,
embedding_cache=embedding_cache,
allowed_media_domains=config.model.mtmd.allowed_media_domains,
allowed_local_media_path=config.model.mtmd.allowed_local_media_path,
Expand Down
Loading