From f257d16ec4cc734bbc0570b1a82a7238b01d5889 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dion=20H=C3=A4fner?=
Date: Fri, 19 Dec 2025 15:08:08 +0100
Subject: [PATCH] add dependency-free mlflow backend

---
 docs/content/using-tesseracts/advanced.md   |   2 +-
 examples/metrics/tesseract_requirements.txt |   2 -
 extra/mlflow/docker-compose-mlflow.yml      |  21 +-
 tesseract_core/runtime/mlflow_client.py     | 383 ++++++++++++++++++++
 tesseract_core/runtime/mpa.py               | 113 +++---
 tests/conftest.py                           | 103 ++++++
 tests/endtoend_tests/test_endtoend.py       |  87 +++--
 tests/endtoend_tests/test_mlflow_client.py  | 382 +++++++++++++++++++
 tests/runtime_tests/test_mpa.py             | 184 +++++-----
 9 files changed, 1067 insertions(+), 210 deletions(-)
 create mode 100644 tesseract_core/runtime/mlflow_client.py
 create mode 100644 tests/endtoend_tests/test_mlflow_client.py

diff --git a/docs/content/using-tesseracts/advanced.md b/docs/content/using-tesseracts/advanced.md
index 2708a1c5..5340b2e9 100644
--- a/docs/content/using-tesseracts/advanced.md
+++ b/docs/content/using-tesseracts/advanced.md
@@ -66,7 +66,7 @@ As an alternative to the MLflow setup we provide, you can point your Tesseract t
 $ tesseract serve --env=TESSERACT_MLFLOW_TRACKING_URI="..." metrics
 ````
 
-Note that if your MLFlow server uses basic auth, you need to populate the `TESSERACT_MLFLOW_TRACKING_USERNAME` and
+Note that if your MLflow server uses basic auth, you need to populate the `TESSERACT_MLFLOW_TRACKING_USERNAME` and
 `TESSERACT_MLFLOW_TRACKING_PASSWORD` for the Tesseract to be able to authenticate to it.
 
 ```bash
diff --git a/examples/metrics/tesseract_requirements.txt b/examples/metrics/tesseract_requirements.txt
index 117814ce..1f797d40 100644
--- a/examples/metrics/tesseract_requirements.txt
+++ b/examples/metrics/tesseract_requirements.txt
@@ -1,4 +1,2 @@
 # Tesseract requirements file
 # Generated by tesseract 0.9.2.dev16+g7ca45a2.d20250627 on 2025-06-27T11:44:45.333107
-
-mlflow==3.1.1
diff --git a/extra/mlflow/docker-compose-mlflow.yml b/extra/mlflow/docker-compose-mlflow.yml
index 34d6e73b..606ca78e 100644
--- a/extra/mlflow/docker-compose-mlflow.yml
+++ b/extra/mlflow/docker-compose-mlflow.yml
@@ -17,15 +17,30 @@ services:
     image: ghcr.io/mlflow/mlflow:latest
     restart: unless-stopped
     user: 1000:1000
-    command: mlflow server --backend-store-uri sqlite:///mlflow-data/mlflow.db --default-artifact-root file:///mlflow-data/mlruns --host 0.0.0.0 --port 5000
+    command: >
+      mlflow server
+      --backend-store-uri sqlite:///mlflow-data/mlflow.db
+      --serve-artifacts
+      --artifacts-destination file:///mlflow-data/mlruns/mlartifacts
+      --host 0.0.0.0
+      --allowed-hosts "mlflow-server:5000,localhost:*"
+      --port 5000
     volumes:
-      - mlflow-data:/mlflow-data
+      - mlflow-data:/mlflow-data:rw
     ports:
-      - "5000:5000"
+      - "5000"
     depends_on:
       mlflow-init:
         condition: service_completed_successfully
+    networks:
+      - mlflow-network
 
 volumes:
   mlflow-data:
     name: mlflow-data
+
+networks:
+  # Use a deterministic network name so we can attach Tesseract
+  # containers to it more easily
+  mlflow-network:
+    name: tesseract-mlflow-server
diff --git a/tesseract_core/runtime/mlflow_client.py b/tesseract_core/runtime/mlflow_client.py
new file mode 100644
index 00000000..f50e49f3
--- /dev/null
+++ b/tesseract_core/runtime/mlflow_client.py
@@ -0,0 +1,383 @@
+# Copyright 2025 Pasteur Labs. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""MLflow HTTP client for interacting with MLflow Tracking Server without the mlflow package.
+
+This module implements direct REST API access to MLflow, eliminating the need for the heavy
+mlflow package dependency (which includes pandas, numpy, scipy, etc.). Only requires requests.
+
+Key MLflow-specific behaviors:
+- Timestamps must be Unix milliseconds (not seconds)
+- Parameter values are limited to 6000 bytes, keys to 255 bytes
+- Tags are passed as list of dicts: [{"key": "k", "value": "v"}], not plain dicts
+- Artifact uploads require server started with --serve-artifacts flag
+- Authentication can be embedded in URI: http://user:pass@host:port
+
+Reference: https://mlflow.org/docs/latest/api_reference/rest-api.html
+"""
+
+import mimetypes
+import posixpath
+import time
+from pathlib import Path
+from typing import Any
+from urllib.parse import urljoin
+
+import requests
+
+
+class MLflowHTTPClient:
+    """HTTP client for MLflow Tracking Server REST API.
+
+    Maintains session state for efficient connection pooling and tracks the current run_id.
+    Most methods default to using the current run_id if not explicitly provided.
+
+    Context manager usage automatically ends the run with status FAILED on exception,
+    FINISHED on normal exit.
+    """
+
+    def __init__(
+        self,
+        tracking_uri: str,
+        experiment_id: str = "0",
+        timeout: int = 30,
+    ) -> None:
+        """Initialize MLflow HTTP client.
+
+        Args:
+            tracking_uri: MLflow server URI. Can include credentials: http://user:pass@host:port
+            experiment_id: Experiment ID (default "0"). Note: This is a string, not an int.
+            timeout: Request timeout in seconds
+        """
+        self.tracking_uri = tracking_uri.rstrip("/")
+        self.experiment_id = experiment_id
+        self.timeout = timeout
+        self.session = requests.Session()
+        self.run_id: str | None = None
+
+    def _make_request(
+        self,
+        method: str,
+        endpoint: str,
+        json_data: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> requests.Response:
+        """Make an HTTP request to the MLflow API.
+
+        Automatically prepends /api/ to endpoint and extracts MLflow error messages from responses.
+
+        Args:
+            method: HTTP method (GET, POST, PUT, etc.)
+            endpoint: API endpoint path like "2.0/mlflow/runs/create" (without /api/ prefix)
+            json_data: JSON data to send in request body
+            **kwargs: Additional arguments for requests (e.g., params for GET)
+        """
+        url = urljoin(self.tracking_uri, f"/api/{endpoint}")
+
+        response = self.session.request(
+            method=method,
+            url=url,
+            json=json_data,
+            timeout=kwargs.pop("timeout", self.timeout),
+            **kwargs,
+        )
+
+        try:
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            # Add response body to error message for debugging
+            error_msg = f"MLflow API request failed: {e}"
+            try:
+                error_detail = response.json()
+                if "message" in error_detail:
+                    error_msg += f"\nMLflow error: {error_detail['message']}"
+            except Exception:
+                error_msg += f"\nResponse: {response.text}"
+            raise requests.HTTPError(error_msg, response=response) from e
+
+        return response
+
+    def create_run(
+        self,
+        run_name: str | None = None,
+        tags: dict[str, str] | None = None,
+    ) -> str:
+        """Create a new run and set it as the current run_id.
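+
+        Example (minimal sketch; the server URI and values are illustrative):
+
+            client = MLflowHTTPClient(tracking_uri="http://localhost:5000")
+            run_id = client.create_run(run_name="demo", tags={"stage": "dev"})
+            client.log_metric("loss", 0.1, step=0)
+            client.end_run()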
+
+        Note: Tags dict is converted to MLflow's list format: [{"key": k, "value": v}]
+
+        Args:
+            run_name: Optional name for the run
+            tags: Optional dictionary of tags to attach to the run
+
+        Returns:
+            Run ID (also stored in self.run_id)
+        """
+        data: dict[str, Any] = {
+            "experiment_id": self.experiment_id,
+            "start_time": int(time.time() * 1000),  # Unix timestamp in milliseconds
+        }
+
+        if run_name:
+            data["run_name"] = run_name
+
+        if tags:
+            data["tags"] = [{"key": k, "value": v} for k, v in tags.items()]
+
+        response = self._make_request("POST", "2.0/mlflow/runs/create", json_data=data)
+        result = response.json()
+        self.run_id = result["run"]["info"]["run_id"]
+        assert self.run_id is not None
+        return self.run_id
+
+    def update_run(
+        self,
+        run_id: str,
+        status: str = "FINISHED",
+        end_time: int | None = None,
+    ) -> None:
+        """Update a run's status and end time.
+
+        Args:
+            run_id: ID of the run to update
+            status: RUNNING | SCHEDULED | FINISHED | FAILED | KILLED
+            end_time: Unix timestamp in milliseconds (not seconds). Defaults to now.
+        """
+        if end_time is None:
+            end_time = int(time.time() * 1000)
+
+        data = {
+            "run_id": run_id,
+            "status": status,
+            "end_time": end_time,
+        }
+
+        self._make_request("POST", "2.0/mlflow/runs/update", json_data=data)
+
+    def end_run(
+        self,
+        run_id: str | None = None,
+        status: str = "FINISHED",
+    ) -> None:
+        """End a run and clear self.run_id if it was the current run.
+
+        Args:
+            run_id: Defaults to self.run_id. Raises RuntimeError if both are None.
+            status: FINISHED | FAILED | KILLED
+        """
+        if run_id is None:
+            if self.run_id is None:
+                raise RuntimeError("No active run to end")
+            run_id = self.run_id
+
+        self.update_run(run_id, status=status)
+
+        if run_id == self.run_id:
+            self.run_id = None
+
+    def log_param(
+        self,
+        key: str,
+        value: Any,
+        run_id: str | None = None,
+    ) -> None:
+        """Log a parameter. Value is always converted to string.
+
+        MLflow limits: key max 255 bytes, value max 6000 bytes (enforced server-side).
+
+        Args:
+            key: Parameter name (max 255 bytes)
+            value: Parameter value (will be converted to string, max 6000 bytes)
+            run_id: Defaults to self.run_id
+        """
+        if run_id is None:
+            if self.run_id is None:
+                raise RuntimeError("No active run. Call create_run() first.")
+            run_id = self.run_id
+
+        data = {
+            "run_id": run_id,
+            "key": key,
+            "value": str(value),
+        }
+
+        self._make_request("POST", "2.0/mlflow/runs/log-parameter", json_data=data)
+
+    def log_metric(
+        self,
+        key: str,
+        value: float,
+        step: int | None = None,
+        timestamp: int | None = None,
+        run_id: str | None = None,
+    ) -> None:
+        """Log a metric. Value is cast to float.
+
+        Args:
+            key: Metric name
+            value: Metric value (will be cast to float)
+            step: Optional step/iteration number for time-series metrics
+            timestamp: Unix milliseconds (not seconds). Defaults to now.
+            run_id: Defaults to self.run_id
+        """
+        if run_id is None:
+            if self.run_id is None:
+                raise RuntimeError("No active run. Call create_run() first.")
+            run_id = self.run_id
+
+        if timestamp is None:
+            timestamp = int(time.time() * 1000)
+
+        data = {
+            "run_id": run_id,
+            "key": key,
+            "value": float(value),
+            "timestamp": timestamp,
+        }
+
+        if step is not None:
+            data["step"] = step
+
+        self._make_request("POST", "2.0/mlflow/runs/log-metric", json_data=data)
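+
+    # Illustrative log_batch payload shapes (hypothetical values):
+    #   metrics=[{"key": "loss", "value": 0.1, "timestamp": 1700000000000, "step": 0}]
+    #   params=[{"key": "optimizer", "value": "adam"}]
+    #   tags=[{"key": "stage", "value": "dev"}]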
+    def log_batch(
+        self,
+        metrics: list[dict[str, Any]] | None = None,
+        params: list[dict[str, Any]] | None = None,
+        tags: list[dict[str, str]] | None = None,
+        run_id: str | None = None,
+    ) -> None:
+        """Log multiple items in a single request. More efficient than individual calls.
+
+        Args:
+            metrics: Each dict needs: key, value, timestamp. Optional: step
+            params: Each dict needs: key, value
+            tags: Each dict needs: key, value
+            run_id: Defaults to self.run_id
+        """
+        if run_id is None:
+            if self.run_id is None:
+                raise RuntimeError("No active run. Call create_run() first.")
+            run_id = self.run_id
+
+        data: dict[str, Any] = {"run_id": run_id}
+
+        if metrics:
+            data["metrics"] = metrics
+        if params:
+            data["params"] = params
+        if tags:
+            data["tags"] = tags
+
+        self._make_request("POST", "2.0/mlflow/runs/log-batch", json_data=data)
+
+    def log_artifact(
+        self,
+        local_path: str | Path,
+        artifact_path: str | None = None,
+        run_id: str | None = None,
+    ) -> None:
+        """Upload an artifact file using a PUT request to the mlflow-artifacts endpoint.
+
+        IMPORTANT: Requires an MLflow server started with the --serve-artifacts flag.
+
+        This method implements the same approach as MLflow's http_artifact_repo.py.
+
+        Args:
+            local_path: Path to the local file to upload
+            artifact_path: Optional subdirectory within run's artifact root
+            run_id: Defaults to self.run_id
+
+        Raises:
+            requests.HTTPError: If upload fails
+        """
+        if run_id is None:
+            if self.run_id is None:
+                raise RuntimeError("No active run. Call create_run() first.")
+            run_id = self.run_id
+
+        local_path = Path(local_path)
+        if not local_path.exists():
+            raise FileNotFoundError(f"Artifact file not found: {local_path}")
+
+        # Construct the artifact upload URL following MLflow's convention
+        # The endpoint format matches the http_artifact_repo.py implementation
+        file_name = local_path.name
+        paths = (artifact_path, file_name) if artifact_path else (file_name,)
+        endpoint = posixpath.join("/", *paths)
+
+        # Base URL for artifacts with run_id
+        base_url = f"{self.tracking_uri}/api/2.0/mlflow-artifacts/artifacts/{self.experiment_id}/{run_id}/artifacts"
+        url = f"{base_url}{endpoint}"
+
+        # Guess MIME type for Content-Type header (matches MLflow's _guess_mime_type)
+        mime_type, _ = mimetypes.guess_type(file_name)
+        if mime_type is None:
+            mime_type = "application/octet-stream"
+
+        # Upload using PUT with Content-Type header (matches MLflow SDK)
+        with open(local_path, "rb") as f:
+            response = self.session.put(
+                url,
+                data=f,
+                headers={"Content-Type": mime_type},
+                timeout=self.timeout,
+            )
+
+        try:
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            error_msg = f"Failed to upload artifact: {e}"
+            # If proxied artifact access is not available, provide a helpful message
+            if response.status_code == 503:
+                error_msg += (
+                    "\n\nNote: Artifact upload requires the MLflow server to be started "
+                    "with the --serve-artifacts flag for proxied artifact access."
+                )
+            try:
+                error_detail = response.json()
+                if "message" in error_detail:
+                    error_msg += f"\nMLflow error: {error_detail['message']}"
+            except Exception:
+                error_msg += f"\nResponse: {response.text}"
+            raise requests.HTTPError(error_msg, response=response) from e
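+
+    # Upload URL sketch (hypothetical values): with experiment_id="0", run_id="abc123",
+    # artifact_path="plots", and a local file "loss.png", log_artifact() PUTs to:
+    #   {tracking_uri}/api/2.0/mlflow-artifacts/artifacts/0/abc123/artifacts/plots/loss.png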
+ ) + try: + error_detail = response.json() + if "message" in error_detail: + error_msg += f"\nMLflow error: {error_detail['message']}" + except Exception: + error_msg += f"\nResponse: {response.text}" + raise requests.HTTPError(error_msg, response=response) from e + + def get_run(self, run_id: str | None = None) -> dict[str, Any]: + """Get run metadata including params, metrics, tags, and status. + + Args: + run_id: Defaults to self.run_id + + Returns: + Dict with structure: {"run": {"info": {...}, "data": {...}}} + """ + if run_id is None: + if self.run_id is None: + raise RuntimeError("No active run") + run_id = self.run_id + + response = self._make_request( + "GET", + "2.0/mlflow/runs/get", + params={"run_id": run_id}, + ) + return response.json() + + def close(self) -> None: + """Close the HTTP session and release connection pool resources.""" + self.session.close() + + def __enter__(self) -> "MLflowHTTPClient": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """End active run (FAILED on exception, FINISHED otherwise) and close session.""" + if self.run_id is not None: + status = "FAILED" if exc_type is not None else "FINISHED" + try: + self.end_run(status=status) + except Exception: + pass # Don't raise during cleanup + self.close() diff --git a/tesseract_core/runtime/mpa.py b/tesseract_core/runtime/mpa.py index ac6003cb..9b0db19f 100644 --- a/tesseract_core/runtime/mpa.py +++ b/tesseract_core/runtime/mpa.py @@ -16,7 +16,7 @@ from io import UnsupportedOperation from pathlib import Path from typing import Any -from urllib.parse import quote, urlparse +from urllib.parse import ParseResult, quote, urlparse, urlunparse import requests @@ -129,21 +129,11 @@ class MLflowBackend(BaseBackend): def __init__(self, base_dir: str | None = None) -> None: super().__init__(base_dir) - os.environ["GIT_PYTHON_REFRESH"] = ( - "quiet" # Suppress potential MLflow git warnings - ) - - try: - import mlflow - except ImportError as exc: - raise ImportError( - "MLflow is required for MLflowBackend but is not installed" - ) from exc + from tesseract_core.runtime.mlflow_client import MLflowHTTPClient - self.mlflow = mlflow tracking_uri = MLflowBackend._build_tracking_uri() self._ensure_mlflow_reachable(tracking_uri) - mlflow.set_tracking_uri(tracking_uri) + self.client = MLflowHTTPClient(tracking_uri=tracking_uri) @staticmethod def _build_tracking_uri() -> str: @@ -151,76 +141,69 @@ def _build_tracking_uri() -> str: config = get_config() tracking_uri = config.mlflow_tracking_uri - if not tracking_uri.startswith(("http://", "https://")): - # If it's a db file URI, convert to local path - tracking_uri = tracking_uri.replace("sqlite:///", "") - - # Relative paths are resolved against the base output path - if not Path(tracking_uri).is_absolute(): - tracking_uri = (Path(get_config().output_path) / tracking_uri).resolve() - - tracking_uri = f"sqlite:///{tracking_uri}" - else: - username = config.mlflow_tracking_username - password = config.mlflow_tracking_password - - if bool(username) != bool(password): - raise RuntimeError( - "If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, " - "both must be defined." 
- ) - - if username and password: - parsed = urlparse(tracking_uri) - # Reconstruct URI with embedded credentials - tracking_uri = ( - f"{parsed.scheme}://{quote(username)}:{quote(password)}@" - f"{parsed.netloc}{parsed.path}" - ) - if parsed.query: - tracking_uri += f"?{parsed.query}" - if parsed.fragment: - tracking_uri += f"#{parsed.fragment}" - - return tracking_uri + parsed = urlparse(tracking_uri) + if not parsed.scheme: + tracking_uri = f"https://{tracking_uri}" + + parsed = urlparse(tracking_uri) + if parsed.scheme not in ("http", "https"): + raise ValueError( + f"MLflow logging only supports accessing MLflow via HTTP/HTTPS (got URI scheme: {parsed.scheme})" + ) + + username = config.mlflow_tracking_username + password = config.mlflow_tracking_password + + if bool(username) != bool(password): + raise RuntimeError( + "If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, " + "both must be defined." + ) + + if username and password: + # Reconstruct URI with embedded credentials + urlparts = parsed._asdict() + urlparts["netloc"] = f"{quote(username)}:{quote(password)}@{parsed.netloc}" + parsed = ParseResult(**urlparts) + + return urlunparse(parsed) def _ensure_mlflow_reachable(self, tracking_uri: str) -> None: """Check if the MLflow tracking server is reachable.""" - if tracking_uri.startswith(("http://", "https://")): - try: - response = requests.get(tracking_uri, timeout=5) - response.raise_for_status() - except requests.RequestException as e: - # Don't expose credentials in error message - use the original URI - config = get_config() - display_uri = config.mlflow_tracking_uri - raise RuntimeError( - f"Failed to connect to MLflow tracking server at {display_uri}. " - "Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, " - "or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string." - "If your MLflow server has authentication enabled, please make sure that" - "TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD are set correctly." - ) from e + try: + response = requests.get(tracking_uri, timeout=5) + response.raise_for_status() + except requests.RequestException as e: + # Don't expose credentials in error message - use the original URI + config = get_config() + display_uri = config.mlflow_tracking_uri + raise RuntimeError( + f"Failed to connect to MLflow tracking server at {display_uri}. " + "Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, " + "or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string." + "If your MLflow server has authentication enabled, please make sure that" + "TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD are set correctly." 
 
     def log_parameter(self, key: str, value: Any) -> None:
         """Log a parameter to MLflow."""
-        self.mlflow.log_param(key, value)
+        self.client.log_param(key, value)
 
     def log_metric(self, key: str, value: float, step: int | None = None) -> None:
         """Log a metric to MLflow."""
-        self.mlflow.log_metric(key, value, step=step)
+        self.client.log_metric(key, value, step=step)
 
     def log_artifact(self, local_path: str) -> None:
         """Log an artifact to MLflow."""
-        self.mlflow.log_artifact(local_path)
+        self.client.log_artifact(local_path)
 
     def start_run(self) -> None:
         """Start a new MLflow run."""
-        self.mlflow.start_run()
+        self.client.create_run()
 
     def end_run(self) -> None:
         """End the current MLflow run."""
-        self.mlflow.end_run()
+        self.client.end_run()
 
 
 def _create_backend(base_dir: str | None) -> BaseBackend:
diff --git a/tests/conftest.py b/tests/conftest.py
index b4c68769..d3818c14 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,6 +6,7 @@
 import random
 import string
 import subprocess
+import time
 from pathlib import Path
 from shutil import copytree
 from textwrap import indent
@@ -13,6 +14,7 @@
 from typing import Any
 
 import pytest
+import requests
 
 # NOTE: Do NOT import tesseract_core here, as it will cause typeguard to fail
 
@@ -491,3 +493,104 @@ def hacked_get(url, *args, **kwargs):
     monkeypatch.setattr(engine.requests, "get", hacked_get)
 
     yield mock_instance
+
+
+@pytest.fixture(scope="module")
+def mlflow_server():
+    """MLflow server to use in tests."""
+    # Check if docker-compose is available
+    try:
+        subprocess.run(
+            ["docker", "compose", "version"],
+            capture_output=True,
+            check=True,
+        )
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        pytest.fail("docker-compose not available")
+
+    # Start MLflow server with unique project name
+    project_name = f"test_mlflow_{int(time.time())}"
+
+    compose_file = (
+        Path(__file__).parent.parent / "extra" / "mlflow" / "docker-compose-mlflow.yml"
+    )
+
+    try:
+        # Start the services
+        subprocess.run(
+            [
+                "docker",
+                "compose",
+                "-f",
+                str(compose_file),
+                "-p",
+                project_name,
+                "up",
+                "-d",
+            ],
+            check=True,
+            capture_output=True,
+        )
+
+        res = subprocess.run(
+            [
+                "docker",
+                "compose",
+                "-f",
+                str(compose_file),
+                "-p",
+                project_name,
+                "ps",
+                "--format",
+                "json",
+            ],
+            check=True,
+            capture_output=True,
+            text=True,
+        )
+        service_data = json.loads(res.stdout)
+        service_port = service_data["Publishers"][0]["PublishedPort"]
+
+        # Note: We don't track containers/volumes here because docker-compose down -v
+        # will handle cleanup automatically in the finally block
+
+        # Wait for MLflow to be ready (with timeout)
+        tracking_uri = f"http://localhost:{service_port}"
+        max_wait = 30  # seconds
+        start_time = time.time()
+
+        while time.time() - start_time < max_wait:
+            try:
+                response = requests.get(tracking_uri, timeout=2)
+                if response.status_code == 200:
+                    break
+            except requests.RequestException:
+                pass
+            time.sleep(1)
+        else:
+            pytest.fail(f"MLflow server did not become ready within {max_wait}s")
+
+        yield tracking_uri
+
+    finally:
+        # Get logs for debugging
+        result = subprocess.run(
+            ["docker", "compose", "-f", str(compose_file), "-p", project_name, "logs"],
+            capture_output=True,
+            text=True,
+        )
+        print(result.stdout)
+        # Stop and remove containers
+        subprocess.run(
+            [
+                "docker",
+                "compose",
+                "-f",
+                str(compose_file),
+                "-p",
+                project_name,
+                "down",
+                "-v",
+            ],
+            capture_output=True,
+        )
diff --git a/tests/endtoend_tests/test_endtoend.py b/tests/endtoend_tests/test_endtoend.py
index a0afa118..77efbd77 100644
--- a/tests/endtoend_tests/test_endtoend.py
+++ b/tests/endtoend_tests/test_endtoend.py
@@ -6,12 +6,12 @@
 import json
 import os
 import shutil
-import sqlite3
 import subprocess
 import uuid
 from pathlib import Path
 from textwrap import dedent
 
+import mlflow
 import numpy as np
 import pytest
 import requests
@@ -1089,26 +1089,27 @@ def test_mpa_file_backend(tmpdir, mpa_test_image):
     assert artifact_data == "Test artifact content"
 
 
-@pytest.mark.parametrize("user", [None, "root", "12579:12579"])
-def test_mpa_mlflow_backend(mpa_test_image, tmpdir, user):
-    """Test the MPA (Metrics, Parameters, and Artifacts) submodule with MLflow backend."""
-    if user not in (None, "root"):
-        Path(tmpdir).chmod(0o777)
+def test_mpa_mlflow_backend(mlflow_server, mpa_test_image):
+    """Test the MPA (Metrics, Parameters, and Artifacts) submodule with MLflow backend, using a local MLflow server."""
+    # Hardcode some values specific to the docker-compose config in extra/mlflow/docker-compose-mlflow.yml
 
-    # Point MLflow to a local directory
+    # Inside containers, tracking URIs look like http://{service_name}:{internal_port}
+    mlflow_server_local = "http://mlflow-server:5000"
+    # Network name as specified in MLflow docker compose config
+    network_name = "tesseract-mlflow-server"
+
+    # Run the Tesseract, logging to the running MLflow server
     run_cmd = [
         "tesseract",
         "run",
+        "--network",
+        network_name,
         "--env",
-        "TESSERACT_MLFLOW_TRACKING_URI=mlflow.db",
-        *(["--user", user] if user else []),
+        f"TESSERACT_MLFLOW_TRACKING_URI={mlflow_server_local}",
         mpa_test_image,
         "apply",
         '{"inputs": {}}',
-        "--output-path",
-        tmpdir,
     ]
-
     run_res = subprocess.run(
         run_cmd,
         capture_output=True,
@@ -1116,36 +1117,46 @@ def test_mpa_mlflow_backend(mpa_test_image, tmpdir, user):
     )
     assert run_res.returncode == 0, run_res.stderr
 
-    # Check for MLflow database file
-    mlflow_db_path = Path(tmpdir) / "mlflow.db"
-    assert mlflow_db_path.exists(), "Expected MLflow database file to exist"
+    # Use MLflow client to verify content was logged
+    mlflow.set_tracking_uri(mlflow_server)
 
-    # Query the database to verify content was logged
-    with sqlite3.connect(str(mlflow_db_path)) as conn:
-        cursor = conn.cursor()
+    # Query the server using the official MLflow client
+    from mlflow.tracking import MlflowClient
 
-        # Check parameters were logged
-        cursor.execute("SELECT key, value FROM params")
-        params = dict(cursor.fetchall())
-        assert params["test_parameter"] == "test_param"
-        assert params["steps_config"] == "5"  # MLflow stores params as strings
+    client = MlflowClient()
 
-        # Check metrics were logged
-        cursor.execute("SELECT key, value, step FROM metrics ORDER BY step")
-        metrics = cursor.fetchall()
-        assert len(metrics) == 5
+    # Get the default experiment (experiment_id="0")
+    experiment = client.get_experiment("0")
+    assert experiment is not None, "Default experiment not found"
 
-        # Verify some of the squared_step values
-        squared_metrics = [m for m in metrics if m[0] == "squared_step"]
-        assert len(squared_metrics) == 5
-        assert squared_metrics[0] == ("squared_step", 0.0, 0)
-        assert squared_metrics[1] == ("squared_step", 1.0, 1)
-        assert squared_metrics[4] == ("squared_step", 16.0, 4)
-
-        # Check artifacts were logged (MLflow stores artifact info in runs table)
-        cursor.execute("SELECT artifact_uri FROM runs")
-        artifact_uris = [row[0] for row in cursor.fetchall()]
-        assert len(artifact_uris) > 0  # At least one run with artifacts
+    runs = client.search_runs(experiment_ids=[experiment.experiment_id])
+    assert len(runs) > 0, "No runs found in MLflow"
found in MLflow" + + # Get the most recent run + print(runs) + run = runs[0] + run_id = run.info.run_id + + # Check parameters were logged + params = run.data.params + assert params["test_parameter"] == "test_param" + assert params["steps_config"] == "5" # MLflow stores params as strings + + # Check metrics were logged + metrics_history = client.get_metric_history(run_id, "squared_step") + assert len(metrics_history) == 5 + + # Verify some of the squared_step values + assert metrics_history[0].value == 0.0 + assert metrics_history[0].step == 0 + assert metrics_history[1].value == 1.0 + assert metrics_history[1].step == 1 + assert metrics_history[4].value == 16.0 + assert metrics_history[4].step == 4 + + # Check artifacts were logged + artifacts = client.list_artifacts(run_id) + assert len(artifacts) > 0, "Expected at least one artifact to be logged" def test_multi_helloworld_endtoend( diff --git a/tests/endtoend_tests/test_mlflow_client.py b/tests/endtoend_tests/test_mlflow_client.py new file mode 100644 index 00000000..00e95cc5 --- /dev/null +++ b/tests/endtoend_tests/test_mlflow_client.py @@ -0,0 +1,382 @@ +# Copyright 2025 Pasteur Labs. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end tests for MLflow HTTP client with real MLflow server.""" + +import time +from pathlib import Path + +import mlflow +import pytest + +from tesseract_core.runtime.mlflow_client import MLflowHTTPClient + + +def test_create_and_end_run(mlflow_server): + """Test creating and ending a run.""" + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run( + run_name="test_create_run", tags={"test": "true", "purpose": "endtoend"} + ) + + assert run_id is not None + assert isinstance(run_id, str) + assert client.run_id == run_id + + # Verify run was created using official MLflow client + run = mlflow.get_run(run_id) + assert run.info.run_id == run_id + assert run.info.status == "RUNNING" + assert run.data.tags["test"] == "true" + assert run.data.tags["purpose"] == "endtoend" + + # End the run + client.end_run(status="FINISHED") + assert client.run_id is None + + # Verify run was ended using official MLflow client + run = mlflow.get_run(run_id) + assert run.info.status == "FINISHED" + + +def test_log_parameters(mlflow_server): + """Test logging parameters.""" + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run(run_name="test_log_params") + + # Log various parameter types + client.log_param("learning_rate", 0.001) + client.log_param("optimizer", "adam") + client.log_param("batch_size", 32) + client.log_param("epochs", 100) + + # Verify params were logged using official MLflow client + run = mlflow.get_run(run_id) + params = run.data.params + + assert params["learning_rate"] == "0.001" + assert params["optimizer"] == "adam" + assert params["batch_size"] == "32" + assert params["epochs"] == "100" + + client.end_run() + + +def test_log_metrics(mlflow_server): + """Test logging metrics with steps.""" + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run(run_name="test_log_metrics") + + # Log metrics with steps + for step in range(5): + client.log_metric("accuracy", 0.5 + step * 0.1, step=step) + client.log_metric("loss", 1.0 - step * 0.15, step=step) + + # Verify metrics were logged using official MLflow client + # Get metric history to 
verify all steps were logged + mlflow_client = mlflow.MlflowClient(tracking_uri=mlflow_server) + accuracy_history = mlflow_client.get_metric_history(run_id, "accuracy") + loss_history = mlflow_client.get_metric_history(run_id, "loss") + + assert len(accuracy_history) == 5 + assert len(loss_history) == 5 + + # Verify values at specific steps + for i, metric in enumerate(accuracy_history): + assert metric.step == i + assert abs(metric.value - (0.5 + i * 0.1)) < 0.001 + + for i, metric in enumerate(loss_history): + assert metric.step == i + assert abs(metric.value - (1.0 - i * 0.15)) < 0.001 + + client.end_run() + + +def test_log_batch(mlflow_server): + """Test batch logging for efficiency.""" + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run(run_name="test_batch_logging") + + # Log batch of items + timestamp = int(time.time() * 1000) + client.log_batch( + metrics=[ + {"key": "train_acc", "value": 0.95, "timestamp": timestamp, "step": 0}, + {"key": "val_acc", "value": 0.92, "timestamp": timestamp, "step": 0}, + ], + params=[ + {"key": "model_type", "value": "cnn"}, + {"key": "layers", "value": "3"}, + ], + tags=[ + {"key": "environment", "value": "test"}, + ], + ) + + # Verify all items were logged using official MLflow client + run = mlflow.get_run(run_id) + + assert run.data.params["model_type"] == "cnn" + assert run.data.params["layers"] == "3" + + assert "train_acc" in run.data.metrics + assert "val_acc" in run.data.metrics + assert abs(run.data.metrics["train_acc"] - 0.95) < 0.001 + assert abs(run.data.metrics["val_acc"] - 0.92) < 0.001 + + assert run.data.tags["environment"] == "test" + + client.end_run() + + +def test_context_manager_normal_exit(mlflow_server): + """Test context manager ends run with FINISHED on normal exit.""" + mlflow.set_tracking_uri(mlflow_server) + + with MLflowHTTPClient(tracking_uri=mlflow_server) as client: + run_id = client.create_run(run_name="test_context_normal") + client.log_param("test", "value") + + # Run should be active + assert client.run_id == run_id + + # After exiting context, verify run was marked as FINISHED using official MLflow client + run = mlflow.get_run(run_id) + assert run.info.status == "FINISHED" + assert run.data.params["test"] == "value" + + +def test_context_manager_exception_exit(mlflow_server): + """Test context manager ends run with FAILED on exception.""" + mlflow.set_tracking_uri(mlflow_server) + run_id = None + + with pytest.raises(ValueError): + with MLflowHTTPClient(tracking_uri=mlflow_server) as client: + run_id = client.create_run(run_name="test_context_exception") + client.log_param("test", "value") + raise ValueError("Test exception") + + # Verify run was marked as FAILED using official MLflow client + assert run_id is not None + run = mlflow.get_run(run_id) + assert run.info.status == "FAILED" + assert run.data.params["test"] == "value" + + +def test_timestamps_in_milliseconds(mlflow_server): + """Test that timestamps are correctly handled in milliseconds.""" + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run(run_name="test_timestamps") + + # Log metric with explicit timestamp in milliseconds + timestamp_ms = int(time.time() * 1000) + client.log_metric("test_metric", 0.5, timestamp=timestamp_ms) + + # Verify the timestamp was preserved using official MLflow client + mlflow_client = mlflow.MlflowClient(tracking_uri=mlflow_server) + 
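+
+
+# Note: the MLflow REST API expects Unix *milliseconds* (int(time.time() * 1000));
+# passing seconds would be read as a date shortly after the 1970 epoch.
+# The following test pins down this behavior.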
+def test_timestamps_in_milliseconds(mlflow_server):
+    """Test that timestamps are correctly handled in milliseconds."""
+    client = MLflowHTTPClient(tracking_uri=mlflow_server)
+    mlflow.set_tracking_uri(mlflow_server)
+
+    with client:
+        run_id = client.create_run(run_name="test_timestamps")
+
+        # Log metric with explicit timestamp in milliseconds
+        timestamp_ms = int(time.time() * 1000)
+        client.log_metric("test_metric", 0.5, timestamp=timestamp_ms)
+
+        # Verify the timestamp was preserved using official MLflow client
+        mlflow_client = mlflow.MlflowClient(tracking_uri=mlflow_server)
+        metric_history = mlflow_client.get_metric_history(run_id, "test_metric")
+
+        assert len(metric_history) == 1
+        # MLflow should have stored the timestamp in milliseconds
+        assert metric_history[0].timestamp == timestamp_ms
+        assert abs(metric_history[0].value - 0.5) < 0.001
+
+        client.end_run()
+
+
+def test_experiment_id_as_string(mlflow_server):
+    """Test that experiment_id is correctly handled as string."""
+    # MLflow quirk: experiment_id must be a string, not an int
+    mlflow.set_tracking_uri(mlflow_server)
+    client = MLflowHTTPClient(
+        tracking_uri=mlflow_server,
+        experiment_id="0",  # Default experiment
+    )
+
+    with client:
+        run_id = client.create_run(run_name="test_experiment_id")
+
+        # Verify run was created in the correct experiment using official MLflow client
+        run = mlflow.get_run(run_id)
+        assert run.info.experiment_id == "0"
+
+        client.end_run()
+
+
+def test_tag_conversion_to_list_format(mlflow_server):
+    """Test that tags dict is converted to MLflow's list format."""
+    client = MLflowHTTPClient(tracking_uri=mlflow_server)
+    mlflow.set_tracking_uri(mlflow_server)
+
+    with client:
+        # Pass tags as dict
+        run_id = client.create_run(
+            run_name="test_tag_conversion",
+            tags={
+                "team": "ml-platform",
+                "project": "tesseract",
+                "version": "1.0",
+            },
+        )
+
+        # Verify tags were stored correctly using official MLflow client
+        run = mlflow.get_run(run_id)
+        tags = run.data.tags
+
+        assert tags["team"] == "ml-platform"
+        assert tags["project"] == "tesseract"
+        assert tags["version"] == "1.0"
+
+        client.end_run()
+
+
+def test_multiple_runs_sequential(mlflow_server):
+    """Test creating multiple runs sequentially."""
+    client = MLflowHTTPClient(tracking_uri=mlflow_server)
+    mlflow.set_tracking_uri(mlflow_server)
+
+    with client:
+        # First run
+        run_id_1 = client.create_run(run_name="test_multi_run_1")
+        client.log_param("run_number", "1")
+        client.end_run()
+
+        # Second run
+        run_id_2 = client.create_run(run_name="test_multi_run_2")
+        client.log_param("run_number", "2")
+        client.end_run()
+
+    # Verify both runs exist using official MLflow client
+    assert run_id_1 != run_id_2
+
+    run_1 = mlflow.get_run(run_id_1)
+    run_2 = mlflow.get_run(run_id_2)
+
+    assert run_1.data.params["run_number"] == "1"
+    assert run_2.data.params["run_number"] == "2"
+    assert run_1.info.status == "FINISHED"
+    assert run_2.info.status == "FINISHED"
+
+
+def test_error_handling_no_active_run(mlflow_server):
+    """Test that operations without active run raise clear errors."""
+    client = MLflowHTTPClient(tracking_uri=mlflow_server)
+
+    with client:
+        # Try to log without creating a run
+        with pytest.raises(RuntimeError, match="No active run"):
+            client.log_param("key", "value")
+
+        with pytest.raises(RuntimeError, match="No active run"):
+            client.log_metric("metric", 0.5)
+
+        with pytest.raises(RuntimeError, match="No active run"):
+            client.get_run()
+
+        with pytest.raises(RuntimeError, match="No active run"):
+            client.end_run()
+
+
+def test_update_run_status(mlflow_server):
+    """Test updating run status directly."""
+    client = MLflowHTTPClient(tracking_uri=mlflow_server)
+    mlflow.set_tracking_uri(mlflow_server)
+
+    with client:
+        run_id = client.create_run(run_name="test_update_status")
+
+        # Update to FAILED status
+        client.update_run(run_id, status="FAILED")
+
+        # Verify status was updated using official MLflow client
+        run = mlflow.get_run(run_id)
+        assert run.info.status == "FAILED"
+
+
+@pytest.mark.parametrize(
+    "suffix,content,mode",
+    [
+        (".txt", "Model configuration\nLayers: 3\nActivation: ReLU\n", "w"),
"w"), + (".json", '{"model": "cnn", "layers": 3, "activation": "ReLU"}', "w"), + (".yaml", "model:\n type: cnn\n layers: 3\n activation: ReLU\n", "w"), + (".csv", "epoch,loss,accuracy\n1,0.5,0.8\n2,0.3,0.9\n", "w"), + (".html", "

Model Report

", "w"), + ( + ".png", + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89", + "wb", + ), + ], +) +def test_log_artifact(tmp_path, mlflow_server, suffix, content, mode): + """Test artifact logging implementation correctness. + + This test validates that our HTTP client correctly implements the MLflow + artifact upload protocol, matching http_artifact_repo.py: + - Uses PUT request with file data + - Includes Content-Type header based on MIME type + - Constructs URL as: /api/2.0/mlflow-artifacts/artifacts/{run_id}/{path} + + Tests multiple MIME types to ensure proper Content-Type header handling. + Verifies artifacts are correctly logged using the official MLflow SDK. + """ + client = MLflowHTTPClient(tracking_uri=mlflow_server) + mlflow.set_tracking_uri(mlflow_server) + + with client: + run_id = client.create_run(run_name=f"test_artifact_upload_{suffix}") + + # Create test file with a predictable name + file_path = tmp_path / f"file{suffix}" + with open(file_path, mode=mode) as f: + f.write(content) + + artifact_filename = Path(file_path).name + + # Log artifact using our HTTP client + client.log_artifact(file_path) + + # Verify artifact was logged using official MLflow SDK + mlflow_client = mlflow.MlflowClient(tracking_uri=mlflow_server) + artifacts = mlflow_client.list_artifacts(run_id=run_id) + + # Check that the artifact exists + assert len(artifacts) > 0, f"No artifacts found for run {run_id}" + artifact_names = [artifact.path for artifact in artifacts] + assert artifact_filename in artifact_names, ( + f"Artifact {artifact_filename} not found in {artifact_names}" + ) + + # Download and verify the artifact content matches + download_path = mlflow_client.download_artifacts(run_id, artifact_filename) + + # Verify content + if mode == "wb": + # Binary file + with open(download_path, "rb") as f: + downloaded_content = f.read() + assert downloaded_content == content, "Binary artifact content mismatch" + else: + # Text file + with open(download_path) as f: + downloaded_content = f.read() + assert downloaded_content == content, "Text artifact content mismatch" + + client.end_run() diff --git a/tests/runtime_tests/test_mpa.py b/tests/runtime_tests/test_mpa.py index 86cbd5f1..fe97ceb9 100644 --- a/tests/runtime_tests/test_mpa.py +++ b/tests/runtime_tests/test_mpa.py @@ -1,12 +1,12 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. 
 # SPDX-License-Identifier: Apache-2.0
 
-"""Tests for the MPA library."""
+"""Tests for the MPA module."""
 
 import csv
 import json
-import os
-import sqlite3
+import threading
+from http.server import BaseHTTPRequestHandler, HTTPServer
 
 import pytest
 
@@ -20,6 +20,45 @@
 )
 
 
+class Always400Handler(BaseHTTPRequestHandler):
+    """HTTP request handler that always returns 400."""
+
+    def do_GET(self):
+        """Handle GET requests with 400."""
+        self.send_response(400)
+        self.send_header("Content-Type", "text/plain")
+        self.end_headers()
+        self.wfile.write(b"Bad Request")
+
+    def do_POST(self):
+        """Handle POST requests with 400."""
+        self.send_response(400)
+        self.send_header("Content-Type", "text/plain")
+        self.end_headers()
+        self.wfile.write(b"Bad Request")
+
+    def log_message(self, format, *args):
+        """Suppress log messages."""
+        pass
+
+
+@pytest.fixture(scope="module")
+def dummy_mlflow_server():
+    """Start a dummy HTTP server that always returns 400."""
+    server = HTTPServer(("localhost", 0), Always400Handler)
+    port = server.server_address[1]
+
+    # Start server in a background thread
+    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
+    server_thread.start()
+
+    try:
+        yield f"http://localhost:{port}"
+    finally:
+        # Shutdown server
+        server.shutdown()
+
+
 def test_start_run_context_manager():
     """Test that start_run works as a context manager."""
     with start_run():
@@ -140,124 +179,68 @@ def test_log_artifact_missing_file():
         backend.log_artifact("non_existent_file.txt")
 
 
-def test_mlflow_backend_creation(tmpdir):
-    """Test that MLflowBackend is created when mlflow_tracking_uri is set."""
-    pytest.importorskip("mlflow")  # Skip if MLflow is not installed
-    mlflow_db_file = tmpdir / "mlflow.db"
-    update_config(mlflow_tracking_uri=f"sqlite:///{mlflow_db_file}")
-    backend = mpa._create_backend(None)
-    assert isinstance(backend, mpa.MLflowBackend)
-
+def test_mlflow_backend_creation_fails_with_unreachable_server(dummy_mlflow_server):
+    """Test that MLflowBackend creation fails when server returns 400."""
+    update_config(mlflow_tracking_uri=dummy_mlflow_server)
+    with pytest.raises(
+        RuntimeError, match="Failed to connect to MLflow tracking server"
+    ):
+        mpa._create_backend(None)
 
-def test_mlflow_log_calls(tmpdir):
-    """Test MLflow backend logging functions with temporary directory."""
-    pytest.importorskip("mlflow")  # Skip if MLflow is not installed
-    mlflow_db_file = tmpdir / "mlflow.db"
-    update_config(mlflow_tracking_uri=f"sqlite:///{mlflow_db_file}")
-    with start_run():
-        log_parameter("model_type", "neural_network")
-        log_parameter("epochs", 100)
-
-        log_metric("accuracy", 0.85)
-        log_metric("loss", 0.25, step=1)
-
-        artifact_file = tmpdir / "model_config.json"
-        artifact_file.write_text("Test content", encoding="utf-8")
-        log_artifact(str(artifact_file))
-
-    # Verify MLflow database file was created
-    assert mlflow_db_file.exists()
-
-    # Query the database to verify content was logged
-    with sqlite3.connect(str(mlflow_db_file)) as conn:
-        cursor = conn.cursor()
-
-        # Check parameters were logged
-        cursor.execute("SELECT key, value FROM params")
-        params = dict(cursor.fetchall())
-        assert params["model_type"] == "neural_network"
-        assert params["epochs"] == "100"
-
-        # Check metrics were logged
-        cursor.execute("SELECT key, value, step FROM metrics ORDER BY step")
-        metrics = cursor.fetchall()
-        assert len(metrics) == 2
-        assert metrics[0] == ("accuracy", 0.85, 0)  # step defaults to 0
-        assert metrics[1] == ("loss", 0.25, 1)
-
-        # Check artifacts were logged (MLflow stores artifact info in runs table)
-        cursor.execute("SELECT artifact_uri FROM runs")
-        artifact_uris = [row[0] for row in cursor.fetchall()]
-        assert len(artifact_uris) > 0  # At least one run with artifacts
-
-        # Verify the artifact file was actually copied to the artifact location
-        artifact_found = False
-        for artifact_uri in artifact_uris:
-            if artifact_uri and os.path.exists(artifact_uri):
-                try:
-                    artifact_files = os.listdir(artifact_uri)
-                    if "model_config.json" in artifact_files:
-                        artifact_found = True
-                        break
-                except OSError:
-                    continue
-
-        assert artifact_found
-
-
-def test_build_tracking_uri_with_credentials():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_with_credentials(dummy_mlflow_server):
     update_config(
-        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_uri=dummy_mlflow_server,
         mlflow_tracking_username="testuser",
         mlflow_tracking_password="testpass",
     )
 
     tracking_uri = mpa.MLflowBackend._build_tracking_uri()
-    assert tracking_uri == "http://testuser:testpass@localhost:5000"
+    # Extract host:port from dummy_mlflow_server
+    expected_uri = dummy_mlflow_server.replace("http://", "http://testuser:testpass@")
+    assert tracking_uri == expected_uri
 
 
-def test_build_tracking_uri_without_credentials():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_without_credentials(dummy_mlflow_server):
     update_config(
-        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_uri=dummy_mlflow_server,
         mlflow_tracking_username="",
         mlflow_tracking_password="",
     )
 
     tracking_uri = mpa.MLflowBackend._build_tracking_uri()
-    assert tracking_uri == "http://localhost:5000"
+    assert tracking_uri == dummy_mlflow_server
 
 
-def test_build_tracking_uri_url_encoded_credentials():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_url_encoded_credentials(dummy_mlflow_server):
+    # Use a dummy HTTPS URL for testing URL encoding
+    dummy_https_url = dummy_mlflow_server.replace("http://", "https://")
     update_config(
-        mlflow_tracking_uri="https://mlflow.example.com",
+        mlflow_tracking_uri=dummy_https_url,
         mlflow_tracking_username="user@example.com",
         mlflow_tracking_password="p@ss:w0rd!",
     )
 
     tracking_uri = mpa.MLflowBackend._build_tracking_uri()
-    assert (
-        tracking_uri == "https://user%40example.com:p%40ss%3Aw0rd%21@mlflow.example.com"
-    )
+    # Verify that special characters are URL-encoded
+    assert "user%40example.com" in tracking_uri
+    assert "p%40ss%3Aw0rd%21" in tracking_uri
 
 
-def test_build_tracking_uri_with_path_and_query():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_with_path_and_query(dummy_mlflow_server):
+    # Add path and query to dummy server URL
+    uri_with_path = f"{dummy_mlflow_server}/api/mlflow?param=value"
     update_config(
-        mlflow_tracking_uri="http://localhost:5000/api/mlflow?param=value",
+        mlflow_tracking_uri=uri_with_path,
         mlflow_tracking_username="testuser",
         mlflow_tracking_password="testpass",
    )
 
     tracking_uri = mpa.MLflowBackend._build_tracking_uri()
-    assert (
-        tracking_uri == "http://testuser:testpass@localhost:5000/api/mlflow?param=value"
-    )
+    # Verify credentials are inserted correctly with path and query preserved
+    assert "testuser:testpass@" in tracking_uri
+    assert "/api/mlflow?param=value" in tracking_uri
 
 
-def test_build_tracking_uri_username_without_password():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_username_without_password(dummy_mlflow_server):
     update_config(
-        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_uri=dummy_mlflow_server,
         mlflow_tracking_username="testuser",
         mlflow_tracking_password="",
     )
@@ -268,10 +251,9 @@ def test_build_tracking_uri_username_without_password():
     mpa.MLflowBackend._build_tracking_uri()
 
 
-def test_build_tracking_uri_password_without_username():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_password_without_username(dummy_mlflow_server):
     update_config(
-        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_uri=dummy_mlflow_server,
         mlflow_tracking_username="",
         mlflow_tracking_password="testpass",
     )
@@ -282,14 +264,14 @@ def test_build_tracking_uri_password_without_username():
     mpa.MLflowBackend._build_tracking_uri()
 
 
-def test_build_tracking_uri_sqlite_ignores_credentials():
-    pytest.importorskip("mlflow")
+def test_build_tracking_uri_non_http_scheme_raises_error():
+    """Test that non-HTTP/HTTPS schemes raise an error."""
     update_config(
         mlflow_tracking_uri="sqlite:///mlflow.db",
-        mlflow_tracking_username="testuser",
-        mlflow_tracking_password="testpass",
+        mlflow_tracking_username="",
+        mlflow_tracking_password="",
     )
-    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
-    assert "testuser" not in tracking_uri
-    assert "testpass" not in tracking_uri
-    assert tracking_uri.startswith("sqlite:///")
+    with pytest.raises(
+        ValueError, match="MLflow logging only supports accessing MLflow via HTTP/HTTPS"
+    ):
+        mpa.MLflowBackend._build_tracking_uri()