Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/content/using-tesseracts/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ As an alternative to the MLflow setup we provide, you can point your Tesseract t
$ tesseract serve --env=TESSERACT_MLFLOW_TRACKING_URI="..." metrics
````

Note that if your MLFlow server uses basic auth, you need to populate the `TESSERACT_MLFLOW_TRACKING_USERNAME` and
Note that if your MLflow server uses basic auth, you need to populate the `TESSERACT_MLFLOW_TRACKING_USERNAME` and
`TESSERACT_MLFLOW_TRACKING_PASSWORD` for the Tesseract to be able to authenticate to it.

```bash
Expand Down
2 changes: 0 additions & 2 deletions examples/metrics/tesseract_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
# Tesseract requirements file
# Generated by tesseract 0.9.2.dev16+g7ca45a2.d20250627 on 2025-06-27T11:44:45.333107

mlflow==3.1.1
21 changes: 18 additions & 3 deletions extra/mlflow/docker-compose-mlflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,30 @@ services:
image: ghcr.io/mlflow/mlflow:latest
restart: unless-stopped
user: 1000:1000
command: mlflow server --backend-store-uri sqlite:///mlflow-data/mlflow.db --default-artifact-root file:///mlflow-data/mlruns --host 0.0.0.0 --port 5000
command: >
mlflow server
--backend-store-uri sqlite:///mlflow-data/mlflow.db
--serve-artifacts
--artifacts-destination file:///mlflow-data/mlruns/mlartifacts
--host 0.0.0.0
--allowed-hosts "mlflow-server:5000,localhost:*"
--port 5000
volumes:
- mlflow-data:/mlflow-data
- mlflow-data:/mlflow-data:rw
ports:
- "5000:5000"
- "5000"
depends_on:
mlflow-init:
condition: service_completed_successfully
networks:
- mlflow-network

volumes:
mlflow-data:
name: mlflow-data

networks:
# Use a deterministic network name so we can attach Tesseract
# containers to it more easily
mlflow-network:
name: tesseract-mlflow-server
1 change: 0 additions & 1 deletion inject_runtime_pyproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"fastapi",
"httpx", # required by fastapi test client
"jsf",
"mlflow",
"numpy",
"pre-commit",
"pytest",
Expand Down
4,532 changes: 1,842 additions & 2,690 deletions production.uv.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tesseract_core/runtime/meta/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"pybase64<=1.4.3,>=1.4",
"numpy<=2.3.5,>=1.26",
"debugpy<=1.8.18,>=1.8.14",
"mlflow-skinny>=3.7.0",
]

[project.scripts]
Expand Down
113 changes: 48 additions & 65 deletions tesseract_core/runtime/mpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from io import UnsupportedOperation
from pathlib import Path
from typing import Any
from urllib.parse import quote, urlparse
from urllib.parse import ParseResult, quote, urlparse, urlunparse

import mlflow
import requests

from tesseract_core.runtime.config import get_config
Expand Down Expand Up @@ -129,98 +130,80 @@ class MLflowBackend(BaseBackend):

def __init__(self, base_dir: str | None = None) -> None:
super().__init__(base_dir)
os.environ["GIT_PYTHON_REFRESH"] = (
"quiet" # Suppress potential MLflow git warnings
)

try:
import mlflow
except ImportError as exc:
raise ImportError(
"MLflow is required for MLflowBackend but is not installed"
) from exc

self.mlflow = mlflow
tracking_uri = MLflowBackend._build_tracking_uri()
self._ensure_mlflow_reachable(tracking_uri)
mlflow.set_tracking_uri(tracking_uri)
self._ensure_mlflow_reachable(tracking_uri)

@staticmethod
def _build_tracking_uri() -> str:
"""Build the MLflow tracking URI with embedded credentials if provided."""
config = get_config()
tracking_uri = config.mlflow_tracking_uri

if not tracking_uri.startswith(("http://", "https://")):
# If it's a db file URI, convert to local path
tracking_uri = tracking_uri.replace("sqlite:///", "")

# Relative paths are resolved against the base output path
if not Path(tracking_uri).is_absolute():
tracking_uri = (Path(get_config().output_path) / tracking_uri).resolve()

tracking_uri = f"sqlite:///{tracking_uri}"
else:
username = config.mlflow_tracking_username
password = config.mlflow_tracking_password

if bool(username) != bool(password):
raise RuntimeError(
"If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, "
"both must be defined."
)

if username and password:
parsed = urlparse(tracking_uri)
# Reconstruct URI with embedded credentials
tracking_uri = (
f"{parsed.scheme}://{quote(username)}:{quote(password)}@"
f"{parsed.netloc}{parsed.path}"
)
if parsed.query:
tracking_uri += f"?{parsed.query}"
if parsed.fragment:
tracking_uri += f"#{parsed.fragment}"

return tracking_uri
parsed = urlparse(tracking_uri)
if not parsed.scheme:
tracking_uri = f"https://{tracking_uri}"

parsed = urlparse(tracking_uri)
if parsed.scheme not in ("http", "https"):
raise ValueError(
f"MLflow logging only supports accessing MLflow via HTTP/HTTPS (got URI scheme: {parsed.scheme})"
)

username = config.mlflow_tracking_username
password = config.mlflow_tracking_password

if bool(username) != bool(password):
raise RuntimeError(
"If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, "
"both must be defined."
)

if username and password:
# Reconstruct URI with embedded credentials
urlparts = parsed._asdict()
urlparts["netloc"] = f"{quote(username)}:{quote(password)}@{parsed.netloc}"
parsed = ParseResult(**urlparts)

return urlunparse(parsed)

def _ensure_mlflow_reachable(self, tracking_uri: str) -> None:
"""Check if the MLflow tracking server is reachable."""
if tracking_uri.startswith(("http://", "https://")):
try:
response = requests.get(tracking_uri, timeout=5)
response.raise_for_status()
except requests.RequestException as e:
# Don't expose credentials in error message - use the original URI
config = get_config()
display_uri = config.mlflow_tracking_uri
raise RuntimeError(
f"Failed to connect to MLflow tracking server at {display_uri}. "
"Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, "
"or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string."
"If your MLflow server has authentication enabled, please make sure that"
"TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD are set correctly."
) from e
try:
response = requests.get(tracking_uri, timeout=5)
response.raise_for_status()
except requests.RequestException as e:
# Don't expose credentials in error message - use the original URI
config = get_config()
display_uri = config.mlflow_tracking_uri
raise RuntimeError(
f"Failed to connect to MLflow tracking server at {display_uri}. "
"Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, "
"or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string."
"If your MLflow server has authentication enabled, please make sure that"
"TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD are set correctly."
) from e

def log_parameter(self, key: str, value: Any) -> None:
"""Log a parameter to MLflow."""
self.mlflow.log_param(key, value)
mlflow.log_param(key, value)

def log_metric(self, key: str, value: float, step: int | None = None) -> None:
"""Log a metric to MLflow."""
self.mlflow.log_metric(key, value, step=step)
mlflow.log_metric(key, value, step=step)

def log_artifact(self, local_path: str) -> None:
"""Log an artifact to MLflow."""
self.mlflow.log_artifact(local_path)
mlflow.log_artifact(local_path)

def start_run(self) -> None:
"""Start a new MLflow run."""
self.mlflow.start_run()
mlflow.start_run()

def end_run(self) -> None:
"""End the current MLflow run."""
self.mlflow.end_run()
mlflow.end_run()


def _create_backend(base_dir: str | None) -> BaseBackend:
Expand Down
103 changes: 103 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import random
import string
import subprocess
import time
from pathlib import Path
from shutil import copytree
from textwrap import indent
from traceback import format_exception
from typing import Any

import pytest
import requests

# NOTE: Do NOT import tesseract_core here, as it will cause typeguard to fail

Expand Down Expand Up @@ -491,3 +493,104 @@ def hacked_get(url, *args, **kwargs):
monkeypatch.setattr(engine.requests, "get", hacked_get)

yield mock_instance


@pytest.fixture(scope="module")
def mlflow_server():
"""MLflow server to use in tests."""
# Check if docker-compose is available
try:
result = subprocess.run(
["docker", "compose", "version"],
capture_output=True,
check=True,
)
except (subprocess.CalledProcessError, FileNotFoundError):
pytest.fail("docker-compose not available")

# Start MLflow server with unique project name
project_name = f"test_mlflow_{int(time.time())}"

compose_file = (
Path(__file__).parent.parent / "extra" / "mlflow" / "docker-compose-mlflow.yml"
)

try:
# Start the services
subprocess.run(
[
"docker",
"compose",
"-f",
str(compose_file),
"-p",
project_name,
"up",
"-d",
],
check=True,
capture_output=True,
)

res = subprocess.run(
[
"docker",
"compose",
"-f",
str(compose_file),
"-p",
project_name,
"ps",
"--format",
"json",
],
check=True,
capture_output=True,
text=True,
)
service_data = json.loads(res.stdout)
service_port = service_data["Publishers"][0]["PublishedPort"]

# Note: We don't track containers/volumes here because docker-compose down -v
# will handle cleanup automatically in the finally block

# Wait for MLflow to be ready (with timeout)
tracking_uri = f"http://localhost:{service_port}"
max_wait = 30 # seconds
start_time = time.time()

while time.time() - start_time < max_wait:
try:
response = requests.get(tracking_uri, timeout=2)
if response.status_code == 200:
break
except requests.RequestException:
pass
time.sleep(1)
else:
pytest.fail(f"MLflow server did not become ready within {max_wait}s")

yield tracking_uri

finally:
# Get logs for debugging
result = subprocess.run(
["docker", "compose", "-f", str(compose_file), "-p", project_name, "logs"],
capture_output=True,
text=True,
)
print(result.stdout)
# Stop and remove containers
subprocess.run(
[
"docker",
"compose",
"-f",
str(compose_file),
"-p",
project_name,
"down",
"-v",
],
capture_output=True,
)
4 changes: 3 additions & 1 deletion tests/endtoend_tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def build_tesseract(
print_debug_info(result)
assert result.exit_code == 0, result.exception

image_tag = json.loads(result.stdout.strip())[0]
# Parse the last line of stdout which contains the JSON array of image tags
stdout_lines = result.stdout.strip().split("\n")
image_tag = json.loads(stdout_lines[-1])[0]

# This raise an error if the image does not exist
client.images.get(image_tag)
Expand Down
Loading