diff --git a/config/charts/inferencepool/SLO-ROUTING-README.md b/config/charts/inferencepool/SLO-ROUTING-README.md new file mode 100644 index 000000000..ea8c8d801 --- /dev/null +++ b/config/charts/inferencepool/SLO-ROUTING-README.md @@ -0,0 +1,296 @@ +# SLO-Aware Routing with Latency Prediction + +This document describes the modifications made to the InferencePool Helm chart to support SLO-aware routing with latency prediction sidecars. + +## Overview + +The SLO-aware routing feature enables intelligent request routing based on predicted latency using machine learning models. The system consists of: + +1. **EPP (Endpoint Picker) Container**: Main routing logic with latency prediction enabled +2. **Training Server Sidecar**: Continuously trains XGBoost models on observed latency metrics +3. **Prediction Server Sidecars**: Multiple replicas that serve latency predictions for TTFT (Time to First Token) and TPOT (Time Per Output Token) + +## Architecture + +``` +┌─────────────────────────────────────────────────────┐ +│ EPP Pod │ +├──────────────┬──────────────┬──────────────────────┤ +│ EPP │ Training │ Prediction Servers │ +│ Container │ Server │ (3 replicas) │ +│ │ │ │ +│ Port 9002 │ Port 8000 │ Ports 8001-8003 │ +│ (ext-proc) │ (training) │ (prediction) │ +└──────────────┴──────────────┴──────────────────────┘ + │ │ │ + │ └──────┬───────────┘ + │ │ + │ Model Training + │ & Synchronization + │ + Routing Decision + (with latency prediction) +``` + +## Modified Files + +### 1. `templates/epp-deployment.yaml` +- Added support for `sidecars.trainingServer` configuration +- Added support for `sidecars.predictionServers` with configurable replicas +- Automatically creates volumes for model storage +- Injects ConfigMaps for training and prediction server configuration + +### 2. `templates/epp-service.yaml` +- Automatically exposes ports for training server (8000) +- Automatically exposes ports for prediction servers (8001-8003 by default) +- Ports are only added when sidecars are enabled + +### 3. `templates/latency-predictor-config.yaml` (NEW) +- Creates ConfigMap for training server configuration +- Creates ConfigMap for prediction server configuration +- Supports customizable model paths, retraining intervals, and other parameters + +### 4. `values.yaml` +- Added comprehensive `sidecars` section with commented examples +- Supports configuration for training and prediction server images, resources, and behavior + +### 5. 
`values-slo-example.yaml` (NEW) +- Complete working example of SLO-aware routing configuration +- Demonstrates all required settings including EPP flags, environment variables, and plugin configuration + +## Usage + +### Quick Start with Example Configuration + +```bash +# Install with SLO-aware routing enabled +helm install my-slo-pool oci://registry.k8s.io/gateway-api-inference-extension/charts/inferencepool \ + --namespace inference \ + --values values-slo-example.yaml \ + --set inferencePool.modelServers.matchLabels.app=my-model-server +``` + +### Custom Configuration + +Create a custom values file: + +```yaml +inferenceExtension: + image: + hub: quay.io/your-org + name: epp + tag: slo-experimental + + flags: + - name: enable-latency-predictor + value: "true" + - name: v + value: "4" + + env: + - name: PREDICTION_SERVER_URL + value: "http://localhost:8001,http://localhost:8002,http://localhost:8003" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + - name: LATENCY_MAX_SAMPLE_SIZE + value: "10000" + + pluginsCustomConfig: + slo-plugins.yaml: | + apiVersion: inference.networking.x-k8s.io/v1alpha1 + kind: EndpointPickerConfig + plugins: + - type: slo-request-tracker + - type: slo-scorer + - type: slo-aware-profile-handler + schedulingProfiles: + - name: slo + plugins: + - pluginRef: slo-request-tracker + - pluginRef: slo-scorer + + sidecars: + trainingServer: + enabled: true + image: + hub: quay.io/your-org + name: latency-training + tag: latest + resources: + requests: + cpu: "2000m" + memory: "4Gi" + limits: + cpu: "4000m" + memory: "8Gi" + + predictionServers: + enabled: true + replicas: 3 + image: + hub: quay.io/your-org + name: latency-prediction + tag: latest + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" +``` + +## Configuration Reference + +### Training Server Configuration + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `sidecars.trainingServer.enabled` | Enable training server sidecar | `false` | +| `sidecars.trainingServer.image.hub` | Container registry | `us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension` | +| `sidecars.trainingServer.image.name` | Image name | `latency-training` | +| `sidecars.trainingServer.image.tag` | Image tag | `latest` | +| `sidecars.trainingServer.config.retrainingIntervalSec` | Retraining interval in seconds | `1` | +| `sidecars.trainingServer.config.minSamplesForRetrain` | Minimum samples before retraining | `100` | +| `sidecars.trainingServer.config.modelType` | ML model type | `xgboost` | +| `sidecars.trainingServer.persistence.enabled` | Enable persistent storage for models | `false` | + +### Prediction Server Configuration + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `sidecars.predictionServers.enabled` | Enable prediction server sidecars | `false` | +| `sidecars.predictionServers.replicas` | Number of prediction server replicas | `3` | +| `sidecars.predictionServers.image.hub` | Container registry | `us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension` | +| `sidecars.predictionServers.image.name` | Image name | `latency-prediction` | +| `sidecars.predictionServers.image.tag` | Image tag | `latest` | +| `sidecars.predictionServers.config.modelSyncIntervalSec` | Model sync interval in seconds | `10` | +| `sidecars.predictionServers.config.modelType` | ML model type | `xgboost` | + +### EPP Environment Variables + +| Variable | Description | Default | 
+|----------|-------------|---------|
+| `PREDICTION_SERVER_URL` | Comma-separated prediction server URLs | `http://localhost:8001,http://localhost:8002,http://localhost:8003` |
+| `TRAINING_SERVER_URL` | Training server URL | `http://localhost:8000` |
+| `LATENCY_MAX_SAMPLE_SIZE` | Maximum sample size for latency prediction | `10000` |
+| `NEG_HEADROOM_TPOT_WEIGHT` | Weight for TPOT in negative headroom calculation | `0.2` |
+| `NEG_HEADROOM_TTFT_WEIGHT` | Weight for TTFT in negative headroom calculation | `0.8` |
+
+## Building Container Images
+
+### Prerequisites
+
+```bash
+cd /path/to/gateway-api-inference-extension
+git checkout slo-prediction-experimental
+```
+
+### Build EPP Image
+
+```bash
+export IMAGE_REGISTRY="quay.io/your-org"
+export EPP_TAG="slo-experimental"
+make image-build image-push
+```
+
+### Build Latency Predictor Images
+
+```bash
+cd latencypredictor-v1
+
+# Edit build-deploy.sh to set your registry
+# Then build and push:
+./build-deploy.sh build
+
+# Tag and push manually
+docker tag latencypredictor-v2-training-server:latest ${IMAGE_REGISTRY}/latency-training:slo-experimental
+docker tag latencypredictor-v2-prediction-server:latest ${IMAGE_REGISTRY}/latency-prediction:slo-experimental
+docker push ${IMAGE_REGISTRY}/latency-training:slo-experimental
+docker push ${IMAGE_REGISTRY}/latency-prediction:slo-experimental
+```
+
+## Verification
+
+After deployment, verify that all containers are running:
+
+```bash
+# Check pod status
+kubectl get pods -n your-namespace
+
+# Expected: 1 pod with 5 containers (1 EPP + 1 training + 3 prediction)
+
+# Check EPP logs
+kubectl logs <epp-pod-name> -n your-namespace -c epp
+
+# Check training server logs
+kubectl logs <epp-pod-name> -n your-namespace -c training-server
+
+# Check prediction server logs
+kubectl logs <epp-pod-name> -n your-namespace -c prediction-server-1
+```
+
+## Service Ports
+
+When sidecars are enabled, the service automatically exposes these ports:
+
+- `9002`: EPP gRPC ext-proc (always)
+- `9090`: EPP metrics (always)
+- `8000`: Training server (when `trainingServer.enabled: true`)
+- `8001-800N`: Prediction servers (when `predictionServers.enabled: true`, N = replicas)
+
+## Plugins
+
+SLO-aware routing requires these plugins:
+
+- `slo-request-tracker`: Tracks request SLO requirements
+- `slo-scorer`: Scores endpoints based on predicted latency vs SLO
+- `slo-aware-profile-handler`: Handles different scheduling profiles
+- `max-score-picker`: Selects the endpoint with the maximum score
+
+### Scheduling Profiles
+
+- **default**: Standard routing with queue and kv-cache scoring
+- **slo**: SLO-aware routing using latency predictions
+
+## Troubleshooting
+
+### Sidecars Not Starting
+
+Check whether the images are accessible:
+```bash
+kubectl describe pod <epp-pod-name> -n your-namespace
+```
+
+### Training Server Issues
+
+Check the ConfigMap and logs:
+```bash
+kubectl get configmap latency-predictor-config -n your-namespace -o yaml
+kubectl logs <epp-pod-name> -c training-server -n your-namespace
+```
+
+### Prediction Server Issues
+
+Verify that the prediction servers can reach the training server:
+```bash
+kubectl exec <epp-pod-name> -c prediction-server-1 -n your-namespace -- \
+  curl http://localhost:8000/healthz
+```
+
+## Integration with llm-d
+
+To use this chart in llm-d, update your helmfile:
+
+```yaml
+releases:
+  - name: gaie-slo
+    namespace: llm-d-slo
+    chart: oci://quay.io/your-org/charts/inferencepool
+    version: v1.0.1-slo
+    values:
+      - gaie-slo/values.yaml
+      - gaie-slo/values-slo.yaml
+```
+
+See the main documentation for complete integration instructions.
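+
+## Worked Example: KV Transfer Time and TTFT Budgets
+
+Both `estimate_kv_transfer_latency()` in `prediction_server.py` and the P/D SLO tracker in the EPP model KV transfer time as cache size divided by bandwidth plus a fixed setup overhead, and the EPP then splits the TTFT SLO between the prefill phase and the KV transfer. The following is a minimal, standalone Python sketch of that arithmetic using the defaults from this change (`KV_TRANSFER_OVERHEAD_MS=5`, `TTFT_PREFILL_BUDGET_RATIO=0.7`, `PD_SLO_BUFFER_FACTOR=0.9`); it is illustrative only and not part of either server.
+
+```python
+# Mirrors estimate_kv_transfer_latency() in prediction_server.py and the
+# budget split in pkg/epp/requestcontrol/pd_slo_tracker.go (defaults shown).
+KV_TRANSFER_OVERHEAD_MS = 5.0      # fixed setup cost (env: KV_TRANSFER_OVERHEAD_MS)
+TTFT_PREFILL_BUDGET_RATIO = 0.7    # share of the TTFT SLO allocated to prefill
+PD_SLO_BUFFER_FACTOR = 0.9         # keep a 10% safety margin on each budget
+
+def kv_transfer_ms(kv_cache_size_mb: float, bandwidth_gbps: float = 100.0) -> float:
+    """Transfer time plus fixed overhead; 1 Gbps = 125 MB/s."""
+    bandwidth_mb_per_sec = bandwidth_gbps * 125.0
+    return (kv_cache_size_mb / bandwidth_mb_per_sec) * 1000.0 + KV_TRANSFER_OVERHEAD_MS
+
+def ttft_budgets(ttft_slo_ms: float) -> tuple[float, float]:
+    """Split a TTFT SLO into (prefill_budget_ms, kv_transfer_budget_ms)."""
+    prefill = ttft_slo_ms * TTFT_PREFILL_BUDGET_RATIO * PD_SLO_BUFFER_FACTOR
+    kv = ttft_slo_ms * (1 - TTFT_PREFILL_BUDGET_RATIO) * PD_SLO_BUFFER_FACTOR
+    return prefill, kv
+
+# 500 MB of KV cache over 100 Gbps is 500 / 12500 s = 40 ms, plus 5 ms overhead:
+print(kv_transfer_ms(500.0))   # 45.0
+# A 200 ms TTFT SLO yields roughly a 126 ms prefill budget and a 54 ms KV budget:
+print(ttft_budgets(200.0))     # approximately (126.0, 54.0)
+```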
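+
+## Smoke-Testing a Prediction Server
+
+Beyond the `/healthz` probe, the `/predict/pd` endpoint added in `prediction_server.py` can be exercised directly. The field names below are taken from `PDPredictionRequest` in this change; the sample values are arbitrary, and the sketch assumes a prediction port has been forwarded locally, e.g. `kubectl port-forward <epp-pod-name> 8001:8001 -n your-namespace`.
+
+```python
+# Minimal client for the /predict/pd endpoint (requires the third-party
+# "requests" package: pip install requests).
+import requests
+
+payload = {
+    "kv_cache_percentage": 0.35,
+    "input_token_length": 512,
+    "num_request_waiting": 2,
+    "num_request_running": 4,
+    "num_tokens_generated": 64,
+    "prefix_cache_score": 0.5,
+    "kv_cache_size_mb": 500.0,
+    "network_bandwidth_gbps": 100.0,
+    "pod_role": "prefill",
+}
+
+resp = requests.post("http://localhost:8001/predict/pd", json=payload, timeout=10)
+resp.raise_for_status()
+body = resp.json()
+# total_ttft_ms = prefill_ttft_ms + kv_transfer_ms (see PDPredictionResponse)
+print(body["prefill_ttft_ms"], body["kv_transfer_ms"], body["total_ttft_ms"])
+```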
diff --git a/config/charts/inferencepool/templates/epp-deployment.yaml b/config/charts/inferencepool/templates/epp-deployment.yaml index f01699a96..556727366 100644 --- a/config/charts/inferencepool/templates/epp-deployment.yaml +++ b/config/charts/inferencepool/templates/epp-deployment.yaml @@ -31,9 +31,12 @@ spec: - "json" - --config-file - "/config/{{ .Values.inferenceExtension.pluginsConfigFile }}" + {{- if ne .Values.inferencePool.apiVersion "inference.networking.k8s.io/v1" }} + - --pool-group + - "{{ (split "/" .Values.inferencePool.apiVersion)._0 }}" + {{- end }} {{- range .Values.inferenceExtension.flags }} - - "--{{ .name }}" - - "{{ .value }}" + - "--{{ .name }}={{ .value }}" {{- end }} {{- if eq (.Values.inferencePool.modelServerType | default "vllm") "triton-tensorrt-llm" }} - --total-queued-requests-metric @@ -84,10 +87,142 @@ spec: volumeMounts: - name: plugins-config-volume mountPath: "/config" + {{- if .Values.inferenceExtension.sidecars }} + {{- if .Values.inferenceExtension.sidecars.trainingServer }} + {{- if .Values.inferenceExtension.sidecars.trainingServer.enabled }} + # Training Server Sidecar Container + - name: training-server + image: {{ .Values.inferenceExtension.sidecars.trainingServer.image.hub }}/{{ .Values.inferenceExtension.sidecars.trainingServer.image.name }}:{{ .Values.inferenceExtension.sidecars.trainingServer.image.tag }} + imagePullPolicy: {{ .Values.inferenceExtension.sidecars.trainingServer.image.pullPolicy | default "Always" }} + ports: + - containerPort: 8000 + name: training-port + livenessProbe: + httpGet: + path: /healthz + port: 8000 + initialDelaySeconds: 30 + periodSeconds: 20 + readinessProbe: + httpGet: + path: /readyz + port: 8000 + initialDelaySeconds: 45 + periodSeconds: 10 + {{- if .Values.inferenceExtension.sidecars.trainingServer.resources }} + resources: + {{- toYaml .Values.inferenceExtension.sidecars.trainingServer.resources | nindent 10 }} + {{- end }} + envFrom: + {{- if .Values.inferenceExtension.sidecars.trainingServer.envFrom }} + {{- toYaml .Values.inferenceExtension.sidecars.trainingServer.envFrom | nindent 10 }} + {{- else }} + - configMapRef: + name: latency-predictor-config + {{- end }} + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "training" + {{- if .Values.inferenceExtension.sidecars.trainingServer.env }} + {{- toYaml .Values.inferenceExtension.sidecars.trainingServer.env | nindent 8 }} + {{- end }} + volumeMounts: + - name: training-server-storage + mountPath: /models + {{- end }} + {{- end }} + {{- if .Values.inferenceExtension.sidecars.predictionServers }} + {{- if .Values.inferenceExtension.sidecars.predictionServers.enabled }} + {{- $replicas := int (.Values.inferenceExtension.sidecars.predictionServers.replicas | default 3) }} + {{- range $i := until $replicas }} + # Prediction Server Sidecar Container {{ add $i 1 }} + - name: prediction-server-{{ add $i 1 }} + image: {{ $.Values.inferenceExtension.sidecars.predictionServers.image.hub }}/{{ $.Values.inferenceExtension.sidecars.predictionServers.image.name }}:{{ $.Values.inferenceExtension.sidecars.predictionServers.image.tag }} + imagePullPolicy: {{ $.Values.inferenceExtension.sidecars.predictionServers.image.pullPolicy | default "Always" }} + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "{{ add 8001 $i }}"] + ports: + - containerPort: {{ add 8001 $i }} + name: predict-port-{{ add $i 1 }} + livenessProbe: + httpGet: + path: /healthz + port: {{ add 8001 $i }} + 
initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: {{ add 8001 $i }} + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + {{- if $.Values.inferenceExtension.sidecars.predictionServers.resources }} + resources: + {{- toYaml $.Values.inferenceExtension.sidecars.predictionServers.resources | nindent 10 }} + {{- end }} + envFrom: + {{- if $.Values.inferenceExtension.sidecars.predictionServers.envFrom }} + {{- toYaml $.Values.inferenceExtension.sidecars.predictionServers.envFrom | nindent 10 }} + {{- else }} + - configMapRef: + name: prediction-server-config + {{- end }} + env: + - name: PREDICT_PORT + value: "{{ add 8001 $i }}" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-{{ add $i 1 }}" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + {{- if $.Values.inferenceExtension.sidecars.predictionServers.env }} + {{- toYaml $.Values.inferenceExtension.sidecars.predictionServers.env | nindent 8 }} + {{- end }} + volumeMounts: + - name: prediction-server-{{ add $i 1 }}-storage + mountPath: /server_models + {{- end }} + {{- end }} + {{- end }} + {{- end }} volumes: - name: plugins-config-volume configMap: name: {{ include "gateway-api-inference-extension.name" . }} + {{- if .Values.inferenceExtension.sidecars }} + {{- if .Values.inferenceExtension.sidecars.trainingServer }} + {{- if .Values.inferenceExtension.sidecars.trainingServer.enabled }} + - name: training-server-storage + {{- if .Values.inferenceExtension.sidecars.trainingServer.persistence }} + {{- if .Values.inferenceExtension.sidecars.trainingServer.persistence.enabled }} + persistentVolumeClaim: + claimName: {{ .Values.inferenceExtension.sidecars.trainingServer.persistence.claimName | default "training-models-pvc" }} + {{- else }} + emptyDir: {} + {{- end }} + {{- else }} + emptyDir: {} + {{- end }} + {{- end }} + {{- end }} + {{- if .Values.inferenceExtension.sidecars.predictionServers }} + {{- if .Values.inferenceExtension.sidecars.predictionServers.enabled }} + {{- $replicas := int (.Values.inferenceExtension.sidecars.predictionServers.replicas | default 3) }} + {{- range $i := until $replicas }} + - name: prediction-server-{{ add $i 1 }}-storage + emptyDir: {} + {{- end }} + {{- end }} + {{- end }} + {{- end }} {{- if .Values.inferenceExtension.affinity }} affinity: {{- toYaml .Values.inferenceExtension.affinity | nindent 8 }} diff --git a/config/charts/inferencepool/templates/epp-service.yaml b/config/charts/inferencepool/templates/epp-service.yaml index b1a48df91..b39a58c82 100644 --- a/config/charts/inferencepool/templates/epp-service.yaml +++ b/config/charts/inferencepool/templates/epp-service.yaml @@ -12,9 +12,33 @@ spec: - name: grpc-ext-proc protocol: TCP port: {{ .Values.inferenceExtension.extProcPort | default 9002 }} + targetPort: {{ .Values.inferenceExtension.extProcPort | default 9002 }} + appProtocol: http2 - name: http-metrics protocol: TCP port: {{ .Values.inferenceExtension.metricsPort | default 9090 }} + targetPort: {{ .Values.inferenceExtension.metricsPort | default 9090 }} + {{- if .Values.inferenceExtension.sidecars }} + {{- if .Values.inferenceExtension.sidecars.trainingServer }} + {{- if .Values.inferenceExtension.sidecars.trainingServer.enabled }} + - name: latency-training + protocol: TCP + port: 8000 + targetPort: 8000 + {{- end }} + {{- end }} + {{- if .Values.inferenceExtension.sidecars.predictionServers }} + {{- if 
.Values.inferenceExtension.sidecars.predictionServers.enabled }} + {{- $replicas := int (.Values.inferenceExtension.sidecars.predictionServers.replicas | default 3) }} + {{- range $i := until $replicas }} + - name: latency-predict-{{ add $i 1 }} + protocol: TCP + port: {{ add 8001 $i }} + targetPort: {{ add 8001 $i }} + {{- end }} + {{- end }} + {{- end }} + {{- end }} {{- with .Values.inferenceExtension.extraServicePorts }} {{- toYaml . | nindent 4 }} {{- end }} diff --git a/config/charts/inferencepool/templates/latency-predictor-config.yaml b/config/charts/inferencepool/templates/latency-predictor-config.yaml new file mode 100644 index 000000000..d54a388bc --- /dev/null +++ b/config/charts/inferencepool/templates/latency-predictor-config.yaml @@ -0,0 +1,41 @@ +{{- if .Values.inferenceExtension.sidecars }} +{{- if or .Values.inferenceExtension.sidecars.trainingServer.enabled .Values.inferenceExtension.sidecars.predictionServers.enabled }} +--- +# ConfigMap for Training Server Configuration +apiVersion: v1 +kind: ConfigMap +metadata: + name: latency-predictor-config + namespace: {{ .Release.Namespace }} + labels: + {{- include "gateway-api-inference-extension.labels" . | nindent 4 }} +data: + LATENCY_RETRAINING_INTERVAL_SEC: {{ .Values.inferenceExtension.sidecars.trainingServer.config.retrainingIntervalSec | default "1" | quote }} + LATENCY_MIN_SAMPLES_FOR_RETRAIN: {{ .Values.inferenceExtension.sidecars.trainingServer.config.minSamplesForRetrain | default "100" | quote }} + LATENCY_TTFT_MODEL_PATH: {{ .Values.inferenceExtension.sidecars.trainingServer.config.ttftModelPath | default "/models/ttft.joblib" | quote }} + LATENCY_TPOT_MODEL_PATH: {{ .Values.inferenceExtension.sidecars.trainingServer.config.tpotModelPath | default "/models/tpot.joblib" | quote }} + LATENCY_TTFT_SCALER_PATH: {{ .Values.inferenceExtension.sidecars.trainingServer.config.ttftScalerPath | default "/models/ttft_scaler.joblib" | quote }} + LATENCY_TPOT_SCALER_PATH: {{ .Values.inferenceExtension.sidecars.trainingServer.config.tpotScalerPath | default "/models/tpot_scaler.joblib" | quote }} + LATENCY_MODEL_TYPE: {{ .Values.inferenceExtension.sidecars.trainingServer.config.modelType | default "xgboost" | quote }} +--- +# ConfigMap for Prediction Server Configuration +apiVersion: v1 +kind: ConfigMap +metadata: + name: prediction-server-config + namespace: {{ .Release.Namespace }} + labels: + {{- include "gateway-api-inference-extension.labels" . 
| nindent 4 }} +data: + MODEL_SYNC_INTERVAL_SEC: {{ .Values.inferenceExtension.sidecars.predictionServers.config.modelSyncIntervalSec | default "10" | quote }} + LATENCY_MODEL_TYPE: {{ .Values.inferenceExtension.sidecars.predictionServers.config.modelType | default "xgboost" | quote }} + PREDICT_HOST: {{ .Values.inferenceExtension.sidecars.predictionServers.config.host | default "0.0.0.0" | quote }} + PREDICT_PORT: {{ .Values.inferenceExtension.sidecars.predictionServers.config.port | default "8001" | quote }} + TRAINING_SERVER_URL: {{ .Values.inferenceExtension.sidecars.predictionServers.config.trainingServerUrl | default "http://localhost:8000" | quote }} + LOCAL_TTFT_MODEL_PATH: {{ .Values.inferenceExtension.sidecars.predictionServers.config.localTtftModelPath | default "/local_models/ttft.joblib" | quote }} + LOCAL_TPOT_MODEL_PATH: {{ .Values.inferenceExtension.sidecars.predictionServers.config.localTpotModelPath | default "/local_models/tpot.joblib" | quote }} + LOCAL_TTFT_SCALER_PATH: {{ .Values.inferenceExtension.sidecars.predictionServers.config.localTtftScalerPath | default "/local_models/ttft_scaler.joblib" | quote }} + LOCAL_TPOT_SCALER_PATH: {{ .Values.inferenceExtension.sidecars.predictionServers.config.localTpotScalerPath | default "/local_models/tpot_scaler.joblib" | quote }} + HTTP_TIMEOUT: {{ .Values.inferenceExtension.sidecars.predictionServers.config.httpTimeout | default "30" | quote }} +{{- end }} +{{- end }} diff --git a/config/charts/inferencepool/values-slo-example.yaml b/config/charts/inferencepool/values-slo-example.yaml new file mode 100644 index 000000000..04d5016be --- /dev/null +++ b/config/charts/inferencepool/values-slo-example.yaml @@ -0,0 +1,124 @@ +# Example values file for SLO-aware routing with latency prediction +# This file demonstrates how to enable and configure the SLO prediction sidecars + +inferenceExtension: + replicas: 1 + image: + name: epp + hub: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension + tag: main + pullPolicy: Always + extProcPort: 9002 + pluginsConfigFile: "slo-plugins.yaml" # Use custom SLO plugins config + + # Enable latency prediction flag + flags: + - name: enable-latency-predictor + value: "true" + - name: v + value: "4" + + # EPP environment variables for SLO prediction + env: + - name: PREDICTION_SERVER_URL + value: "http://localhost:8001,http://localhost:8002,http://localhost:8003" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + - name: LATENCY_MAX_SAMPLE_SIZE + value: "10000" + - name: NEG_HEADROOM_TPOT_WEIGHT + value: "0.2" + - name: NEG_HEADROOM_TTFT_WEIGHT + value: "0.8" + + # Custom plugins configuration for SLO routing + pluginsCustomConfig: + slo-plugins.yaml: | + apiVersion: inference.networking.x-k8s.io/v1alpha1 + kind: EndpointPickerConfig + plugins: + - type: queue-scorer + - type: kv-cache-utilization-scorer + - type: slo-request-tracker + - type: slo-scorer + - type: slo-aware-profile-handler + - type: max-score-picker + schedulingProfiles: + - name: default + plugins: + - pluginRef: slo-request-tracker + - pluginRef: queue-scorer + - pluginRef: kv-cache-utilization-scorer + - pluginRef: max-score-picker + - name: slo + plugins: + - pluginRef: slo-request-tracker + - pluginRef: slo-scorer + - pluginRef: max-score-picker + + # Enable SLO prediction sidecars + sidecars: + trainingServer: + enabled: true + image: + hub: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension + name: latency-training + tag: latest + pullPolicy: Always + resources: 
+ requests: + cpu: "2000m" + memory: "4Gi" + limits: + cpu: "4000m" + memory: "8Gi" + config: + retrainingIntervalSec: "1" + minSamplesForRetrain: "100" + ttftModelPath: "/models/ttft.joblib" + tpotModelPath: "/models/tpot.joblib" + ttftScalerPath: "/models/ttft_scaler.joblib" + tpotScalerPath: "/models/tpot_scaler.joblib" + modelType: "xgboost" + persistence: + enabled: false # Set to true if you want persistent model storage + # claimName: "training-models-pvc" + + predictionServers: + enabled: true + replicas: 3 # Number of prediction server replicas + image: + hub: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension + name: latency-prediction + tag: latest + pullPolicy: Always + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + config: + modelSyncIntervalSec: "10" + modelType: "xgboost" + host: "0.0.0.0" + port: "8001" + trainingServerUrl: "http://localhost:8000" + localTtftModelPath: "/local_models/ttft.joblib" + localTpotModelPath: "/local_models/tpot.joblib" + localTtftScalerPath: "/local_models/ttft_scaler.joblib" + localTpotScalerPath: "/local_models/tpot_scaler.joblib" + httpTimeout: "30" + +inferencePool: + targetPorts: + - number: 8000 + modelServerType: vllm + apiVersion: inference.networking.k8s.io/v1 + # modelServers: + # matchLabels: + # app: vllm-llama3-8b-instruct + +provider: + name: none diff --git a/config/charts/inferencepool/values.yaml b/config/charts/inferencepool/values.yaml index d45e6ed39..13e57cb79 100644 --- a/config/charts/inferencepool/values.yaml +++ b/config/charts/inferencepool/values.yaml @@ -40,6 +40,66 @@ inferenceExtension: tolerations: [] + # SLO-aware routing with latency prediction sidecars + # Uncomment and configure to enable SLO prediction + # sidecars: + # trainingServer: + # enabled: false + # image: + # hub: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension + # name: latency-training + # tag: latest + # pullPolicy: Always + # resources: + # requests: + # cpu: "2000m" + # memory: "4Gi" + # limits: + # cpu: "4000m" + # memory: "8Gi" + # config: + # retrainingIntervalSec: "1" + # minSamplesForRetrain: "100" + # ttftModelPath: "/models/ttft.joblib" + # tpotModelPath: "/models/tpot.joblib" + # ttftScalerPath: "/models/ttft_scaler.joblib" + # tpotScalerPath: "/models/tpot_scaler.joblib" + # modelType: "xgboost" + # persistence: + # enabled: false + # claimName: "training-models-pvc" + # env: [] + # envFrom: [] + # + # predictionServers: + # enabled: false + # replicas: 3 + # image: + # hub: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension + # name: latency-prediction + # tag: latest + # pullPolicy: Always + # resources: + # requests: + # cpu: "500m" + # memory: "1Gi" + # limits: + # cpu: "1000m" + # memory: "2Gi" + # config: + # modelSyncIntervalSec: "10" + # modelType: "xgboost" + # host: "0.0.0.0" + # port: "8001" + # trainingServerUrl: "http://localhost:8000" + # localTtftModelPath: "/local_models/ttft.joblib" + # localTpotModelPath: "/local_models/tpot.joblib" + # localTtftScalerPath: "/local_models/ttft_scaler.joblib" + # localTpotScalerPath: "/local_models/tpot_scaler.joblib" + # httpTimeout: "30" + # env: [] + # envFrom: [] + inferencePool: targetPorts: - number: 8000 diff --git a/latencypredictor-v1/prediction_server.py b/latencypredictor-v1/prediction_server.py index ca841fb58..5eb7261ad 100644 --- a/latencypredictor-v1/prediction_server.py +++ b/latencypredictor-v1/prediction_server.py @@ -731,6 +731,114 @@ 
async def readiness_check(): } +# P/D Disaggregation Prediction Models + +class PDPredictionRequest(BaseModel): + """Request for P/D disaggregation multi-phase prediction.""" + kv_cache_percentage: float = Field(..., ge=0.0, le=1.0) + input_token_length: int = Field(..., ge=0) + num_request_waiting: int = Field(..., ge=0) + num_request_running: int = Field(..., ge=0) + num_tokens_generated: int = Field(..., ge=0, description="For TPOT prediction") + prefix_cache_score: float = Field(..., ge=0.0, le=1.0, description="Prefix cache hit ratio score") + kv_cache_size_mb: float = Field(..., ge=0.0, description="KV cache size in MB for transfer prediction") + network_bandwidth_gbps: float = Field(100.0, gt=0.0, description="Network bandwidth in Gbps (default: 100 for RDMA)") + pod_role: str = Field(..., pattern="^(prefill|decode)$", description="Pod role: 'prefill' or 'decode'") + + +class PDPredictionResponse(BaseModel): + """Response for P/D disaggregation multi-phase prediction.""" + prefill_ttft_ms: float = Field(..., description="Predicted prefill phase TTFT") + kv_transfer_ms: float = Field(..., description="Predicted KV transfer latency") + decode_tpot_ms: float = Field(..., description="Predicted decode phase TPOT") + total_ttft_ms: float = Field(..., description="Total TTFT (prefill + KV transfer)") + predicted_at: datetime + model_type: str + quantile: float + kv_transfer_overhead_ms: float = Field(..., description="Fixed KV transfer overhead") + + +def estimate_kv_transfer_latency(kv_cache_size_mb: float, bandwidth_gbps: float) -> Tuple[float, float]: + """ + Estimate KV transfer latency based on cache size and bandwidth. + + Args: + kv_cache_size_mb: Size of KV cache in MB + bandwidth_gbps: Network bandwidth in Gbps + + Returns: + Tuple of (transfer_time_ms, overhead_ms) + """ + # Fixed overhead for setup (connection, handshake, etc.) + overhead_ms = float(os.getenv("KV_TRANSFER_OVERHEAD_MS", "5.0")) + + # Convert bandwidth from Gbps to MB/s: 1 Gbps = 125 MB/s + bandwidth_MB_per_sec = bandwidth_gbps * 125.0 + + # Transfer time = size / bandwidth (in seconds), convert to ms + transfer_time_ms = (kv_cache_size_mb / bandwidth_MB_per_sec) * 1000.0 + + return transfer_time_ms, overhead_ms + + +@app.post("/predict/pd", response_model=PDPredictionResponse) +async def predict_pd_endpoint(request: PDPredictionRequest): + """ + Make multi-phase latency predictions for P/D disaggregation. + + This endpoint predicts: + - Prefill phase TTFT + - KV transfer latency + - Decode phase TPOT + + The predictions account for the role of the pod (prefill vs decode) and + estimate KV transfer time based on cache size and network bandwidth. 
+ """ + try: + # Get base predictions for prefill and decode + features = request.dict() + + # Predict prefill TTFT (with prefix cache score) + prefill_ttft, _ = predictor.predict(features) + prefill_ttft = max(0, prefill_ttft) + + # Predict decode TPOT (without prefix cache impact) + decode_features = features.copy() + decode_features['prefix_cache_score'] = 0.0 # Decode doesn't benefit from prefix cache + _, decode_tpot = predictor.predict(decode_features) + decode_tpot = max(0, decode_tpot) + + # Estimate KV transfer latency + transfer_time_ms, overhead_ms = estimate_kv_transfer_latency( + request.kv_cache_size_mb, + request.network_bandwidth_gbps + ) + kv_transfer_ms = transfer_time_ms + overhead_ms + + # Calculate total TTFT (prefill + KV transfer) + total_ttft_ms = prefill_ttft + kv_transfer_ms + + logging.info(f"P/D Prediction: prefill={prefill_ttft:.2f}ms, kv_transfer={kv_transfer_ms:.2f}ms, " + f"decode_tpot={decode_tpot:.2f}ms, total_ttft={total_ttft_ms:.2f}ms") + + return PDPredictionResponse( + prefill_ttft_ms=prefill_ttft, + kv_transfer_ms=kv_transfer_ms, + decode_tpot_ms=decode_tpot, + total_ttft_ms=total_ttft_ms, + predicted_at=datetime.now(timezone.utc), + model_type=predictor.model_type.value, + quantile=predictor.quantile, + kv_transfer_overhead_ms=overhead_ms + ) + + except HTTPException: + raise + except Exception as e: + logging.error(f"P/D prediction failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"P/D prediction failed: {str(e)}") + + @app.get("/", include_in_schema=False) async def root(): """Root endpoint.""" @@ -741,7 +849,8 @@ async def root(): "description": f"Predicting {predictor.quantile:.0%} quantile for TTFT and TPOT latencies", "is_ready": predictor.is_ready, "sync_interval": settings.MODEL_SYNC_INTERVAL_SEC, - "training_server": settings.TRAINING_SERVER_URL + "training_server": settings.TRAINING_SERVER_URL, + "pd_support": "enabled" } diff --git a/latencypredictor-v1/training_server.py b/latencypredictor-v1/training_server.py index fa8a81118..51425bfe0 100644 --- a/latencypredictor-v1/training_server.py +++ b/latencypredictor-v1/training_server.py @@ -628,22 +628,22 @@ def train(self): if len(df_ttft) >= settings.MIN_SAMPLES_FOR_RETRAIN: # Updated TTFT features to include prefix_cache_score ttft_feature_cols_tree = [ - 'kv_cache_percentage','input_token_length','num_request_waiting', - 'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket' - ] - ttft_feature_cols_br = [ - 'kv_cache_percentage','input_token_length','num_request_waiting', - 'num_request_running','prefix_cache_score','effective_input_tokens' - ] - - # Build X_ttft for all model types, then trim for BR - X_ttft = df_ttft[ttft_feature_cols_tree] - if self.model_type == ModelType.BAYESIAN_RIDGE: - X_ttft = X_ttft[ttft_feature_cols_br] + 'kv_cache_percentage','input_token_length','num_request_waiting', + 'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket' + ] + ttft_feature_cols_br = [ + 'kv_cache_percentage','input_token_length','num_request_waiting', + 'num_request_running','prefix_cache_score','effective_input_tokens' + ] + + # Build X_ttft for all model types, then trim for BR + X_ttft = df_ttft[ttft_feature_cols_tree] + if self.model_type == ModelType.BAYESIAN_RIDGE: + X_ttft = X_ttft[ttft_feature_cols_br] - y_ttft = raw_ttft['actual_ttft_ms'] + y_ttft = raw_ttft['actual_ttft_ms'] - try: + try: # raw_ttft still has the original columns including 'prefix_cache_score' 
raw_ttft['_prefix_bucket'] = raw_ttft['prefix_cache_score'].clip(0, 1).apply( lambda s: min(int(s * self.prefix_buckets), self.prefix_buckets - 1) @@ -677,8 +677,6 @@ def train(self): new_ttft_model, new_ttft_scaler, test_records, cols, 'actual_ttft_ms' ) - - if ql is not None: self.ttft_quantile_loss_scores.append(ql) self.ttft_coverage_scores.append(coverage) @@ -690,7 +688,7 @@ def train(self): else: logging.info(f"TTFT model trained on {len(df_ttft)} samples. Quantile metrics = N/A (insufficient test data)") - except Exception: + except Exception: logging.error("Error training TTFT model", exc_info=True) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 0e0a1d03d..ba400e181 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -111,6 +111,29 @@ type RequestContext struct { TPOTObservations []float64 PredictedTPOTObservations []float64 + // -- P/D disaggregation SLO tracking fields -- + PDMode bool // True if prefill/decode disaggregation is enabled + PrefillPodName string // Name of the prefill pod + DecodePodName string // Name of the decode pod + PrefillStartTime time.Time // When prefill phase started + PrefillEndTime time.Time // When prefill phase completed + KVTransferStartTime time.Time // When KV transfer started + KVTransferEndTime time.Time // When KV transfer completed + DecodeStartTime time.Time // When decode phase started + PrefillTTFTBudget float64 // Allocated TTFT budget for prefill phase (ms) + KVTransferBudget float64 // Allocated TTFT budget for KV transfer (ms) + DecodeTPOTBudget float64 // TPOT budget for decode phase (ms) + ActualPrefillLatency float64 // Actual prefill latency (ms) + ActualKVTransferLatency float64 // Actual KV transfer latency (ms) + ActualDecodeTPOT float64 // Actual decode TPOT (ms) + RemainingTTFTBudget float64 // Remaining TTFT budget after prefill and KV transfer (ms) + RemainingTPOTBudget float64 // Remaining TPOT budget (ms) + PredictedPrefillTTFT float64 // Predicted prefill TTFT (ms) + PredictedKVTransferMs float64 // Predicted KV transfer latency (ms) + PredictedDecodeTPOT float64 // Predicted decode TPOT (ms) + PDSLOViolation bool // True if any phase violated its SLO budget + PDSLOViolationPhase string // Which phase violated SLO ("prefill", "kv_transfer", "decode") + Response *Response reqHeaderResp *extProcPb.ProcessingResponse diff --git a/pkg/epp/requestcontrol/pd_slo_tracker.go b/pkg/epp/requestcontrol/pd_slo_tracker.go new file mode 100644 index 000000000..192898c29 --- /dev/null +++ b/pkg/epp/requestcontrol/pd_slo_tracker.go @@ -0,0 +1,315 @@ +/* +© 2025 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers for P/D disaggregation SLO tracking. 
+package requestcontrol + +import ( + "context" + "os" + "strconv" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// Configuration constants for P/D SLO tracking + +// EnablePDSLOTracking enables/disables P/D SLO tracking +var EnablePDSLOTracking = func() bool { + if value, exists := os.LookupEnv("ENABLE_PD_SLO_TRACKING"); exists { + if parsedValue, err := strconv.ParseBool(value); err == nil { + return parsedValue + } + } + return true // default: enabled +}() + +// TTFTPrefillBudgetRatio is the fraction of TTFT SLO allocated to prefill phase +// The remainder (1 - ratio) is allocated to KV transfer +var TTFTPrefillBudgetRatio = func() float64 { + if value, exists := os.LookupEnv("TTFT_PREFILL_BUDGET_RATIO"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 && parsedValue <= 1 { + return parsedValue + } + } + return 0.7 // default: 70% for prefill, 30% for KV transfer +}() + +// KVTransferOverheadMs is the fixed overhead for KV transfer setup (ms) +var KVTransferOverheadMs = func() float64 { + if value, exists := os.LookupEnv("KV_TRANSFER_OVERHEAD_MS"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 5.0 // default: 5ms overhead +}() + +// KVTransferBandwidthGbps is the estimated network bandwidth for KV transfer (Gbps) +var KVTransferBandwidthGbps = func() float64 { + if value, exists := os.LookupEnv("KV_TRANSFER_BANDWIDTH_GBPS"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 { + return parsedValue + } + } + return 100.0 // default: 100 Gbps (RDMA/RoCE) +}() + +// PDSLOBufferFactor is the safety margin for P/D SLO budgets +var PDSLOBufferFactor = func() float64 { + if value, exists := os.LookupEnv("PD_SLO_BUFFER_FACTOR"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { + return parsedValue + } + } + return 0.9 // default: 90% of budget (10% safety margin) +}() + +// PDPredictionRequest extends the standard prediction request with P/D specific fields +type PDPredictionRequest struct { + latencypredictor.PredictionRequest + PodRole string // "prefill" | "decode" + KVCacheSizeMB float64 // Size of KV cache to transfer + NetworkBandwidthGbps float64 // Network bandwidth + IsDisaggregated bool // True if using P/D disaggregation +} + +// PDPredictionResponse contains predictions for all P/D phases +type PDPredictionResponse struct { + PrefillTTFT float64 // Prefill phase latency (ms) + KVTransferMs float64 // KV transfer latency (ms) + DecodeTPOT float64 // Decode phase TPOT (ms) + TotalTTFT float64 // prefill_ttft + kv_transfer_ms + PredictionTime time.Time +} + +// AllocatePDSLOBudgets allocates TTFT and TPOT SLO budgets across P/D phases +func AllocatePDSLOBudgets(ctx context.Context, reqCtx *handlers.RequestContext, ttftSLO, tpotSLO float64) { + logger := log.FromContext(ctx) + + if !reqCtx.PDMode { + logger.V(logutil.DEBUG).Info("P/D mode not enabled, skipping budget allocation") + return + } + + // Allocate TTFT budget: prefill gets TTFTPrefillBudgetRatio, KV transfer gets remainder + reqCtx.PrefillTTFTBudget = ttftSLO * TTFTPrefillBudgetRatio * PDSLOBufferFactor + reqCtx.KVTransferBudget = ttftSLO * (1 - 
TTFTPrefillBudgetRatio) * PDSLOBufferFactor + + // TPOT budget goes entirely to decode phase + reqCtx.DecodeTPOTBudget = tpotSLO * PDSLOBufferFactor + + // Initialize remaining budgets + reqCtx.RemainingTTFTBudget = ttftSLO + reqCtx.RemainingTPOTBudget = tpotSLO + + // Initialize phase start times + // Prefill starts immediately when budget is allocated + reqCtx.PrefillStartTime = time.Now() + // KV transfer will start when prefill completes (set in UpdatePrefillPhase) + // Decode starts after KV transfer (set in UpdateKVTransferPhase) + + logger.V(logutil.DEBUG).Info("Allocated P/D SLO budgets", + "ttft_slo", ttftSLO, + "tpot_slo", tpotSLO, + "prefill_budget", reqCtx.PrefillTTFTBudget, + "kv_transfer_budget", reqCtx.KVTransferBudget, + "decode_tpot_budget", reqCtx.DecodeTPOTBudget, + "prefill_ratio", TTFTPrefillBudgetRatio, + "buffer_factor", PDSLOBufferFactor) +} + +// UpdatePrefillPhase records prefill phase completion and updates budgets +func UpdatePrefillPhase(ctx context.Context, reqCtx *handlers.RequestContext) { + logger := log.FromContext(ctx) + + if !reqCtx.PDMode || reqCtx.PrefillStartTime.IsZero() { + return + } + + reqCtx.PrefillEndTime = time.Now() + reqCtx.ActualPrefillLatency = float64(reqCtx.PrefillEndTime.Sub(reqCtx.PrefillStartTime).Milliseconds()) + + // KV transfer starts immediately after prefill completes + reqCtx.KVTransferStartTime = reqCtx.PrefillEndTime + + // Update remaining TTFT budget + reqCtx.RemainingTTFTBudget -= reqCtx.ActualPrefillLatency + + // Check if prefill violated its budget + if reqCtx.ActualPrefillLatency > reqCtx.PrefillTTFTBudget { + reqCtx.PDSLOViolation = true + reqCtx.PDSLOViolationPhase = "prefill" + logger.Info("Prefill phase violated SLO budget", + "actual", reqCtx.ActualPrefillLatency, + "budget", reqCtx.PrefillTTFTBudget, + "overage", reqCtx.ActualPrefillLatency-reqCtx.PrefillTTFTBudget) + } + + logger.V(logutil.DEBUG).Info("Prefill phase completed", + "latency_ms", reqCtx.ActualPrefillLatency, + "budget_ms", reqCtx.PrefillTTFTBudget, + "remaining_ttft_budget", reqCtx.RemainingTTFTBudget, + "slo_violated", reqCtx.PDSLOViolation) +} + +// UpdateKVTransferPhase records KV transfer completion and updates budgets +func UpdateKVTransferPhase(ctx context.Context, reqCtx *handlers.RequestContext) { + logger := log.FromContext(ctx) + + if !reqCtx.PDMode || reqCtx.KVTransferStartTime.IsZero() { + return + } + + reqCtx.KVTransferEndTime = time.Now() + reqCtx.ActualKVTransferLatency = float64(reqCtx.KVTransferEndTime.Sub(reqCtx.KVTransferStartTime).Milliseconds()) + + // Decode starts immediately after KV transfer completes + reqCtx.DecodeStartTime = reqCtx.KVTransferEndTime + + // Update remaining TTFT budget + reqCtx.RemainingTTFTBudget -= reqCtx.ActualKVTransferLatency + + // Check if KV transfer violated its budget + if reqCtx.ActualKVTransferLatency > reqCtx.KVTransferBudget { + reqCtx.PDSLOViolation = true + reqCtx.PDSLOViolationPhase = "kv_transfer" + logger.Info("KV transfer phase violated SLO budget", + "actual", reqCtx.ActualKVTransferLatency, + "budget", reqCtx.KVTransferBudget, + "overage", reqCtx.ActualKVTransferLatency-reqCtx.KVTransferBudget) + } + + logger.V(logutil.DEBUG).Info("KV transfer phase completed", + "latency_ms", reqCtx.ActualKVTransferLatency, + "budget_ms", reqCtx.KVTransferBudget, + "remaining_ttft_budget", reqCtx.RemainingTTFTBudget, + "slo_violated", reqCtx.PDSLOViolation) +} + +// UpdateDecodePhase records decode phase TPOT and checks against budget +func UpdateDecodePhase(ctx context.Context, reqCtx 
*handlers.RequestContext, tpot float64) { + logger := log.FromContext(ctx) + + if !reqCtx.PDMode { + return + } + + reqCtx.ActualDecodeTPOT = tpot + reqCtx.RemainingTPOTBudget -= tpot + + // Check if decode violated its budget + if reqCtx.ActualDecodeTPOT > reqCtx.DecodeTPOTBudget { + reqCtx.PDSLOViolation = true + reqCtx.PDSLOViolationPhase = "decode" + logger.Info("Decode phase violated SLO budget", + "actual_tpot", reqCtx.ActualDecodeTPOT, + "budget", reqCtx.DecodeTPOTBudget, + "overage", reqCtx.ActualDecodeTPOT-reqCtx.DecodeTPOTBudget) + } + + logger.V(logutil.DEBUG).Info("Decode phase TPOT recorded", + "tpot_ms", reqCtx.ActualDecodeTPOT, + "budget_ms", reqCtx.DecodeTPOTBudget, + "remaining_tpot_budget", reqCtx.RemainingTPOTBudget, + "slo_violated", reqCtx.PDSLOViolation) +} + +// EstimateKVTransferLatency estimates KV transfer latency based on cache size and bandwidth +func EstimateKVTransferLatency(kvCacheSizeMB float64, bandwidthGbps float64) float64 { + if bandwidthGbps <= 0 { + bandwidthGbps = KVTransferBandwidthGbps + } + + // Convert bandwidth from Gbps to MB/s (megabytes per second) + // 1 Gbps = 1,000,000,000 bits/s = 125,000,000 bytes/s = 125 MB/s + bandwidthMBPerSec := bandwidthGbps * 125.0 + + // Transfer time = size / bandwidth (in seconds), convert to ms + transferTimeMs := (kvCacheSizeMB / bandwidthMBPerSec) * 1000.0 + + // Add fixed overhead + return transferTimeMs + KVTransferOverheadMs +} + +// PredictPDPhases predicts latency for all P/D phases +// This is a placeholder - actual implementation will call the extended latency predictor +func PredictPDPhases( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + prefillMetrics, decodeMetrics interface{}, + prompt string, + kvCacheSizeMB float64, +) (*PDPredictionResponse, error) { + logger := log.FromContext(ctx) + + // For now, use the existing predictor for prefill and decode + // TODO: Extend the actual latency predictor to support P/D phase predictions + + response := &PDPredictionResponse{ + PredictionTime: time.Now(), + } + + // Estimate KV transfer latency + response.KVTransferMs = EstimateKVTransferLatency(kvCacheSizeMB, KVTransferBandwidthGbps) + + // Use existing predictor for prefill TTFT (will be enhanced later) + // For now, return placeholder values + response.PrefillTTFT = 0 + response.DecodeTPOT = 0 + response.TotalTTFT = response.PrefillTTFT + response.KVTransferMs + + logger.V(logutil.DEBUG).Info("Predicted P/D phases", + "prefill_ttft", response.PrefillTTFT, + "kv_transfer_ms", response.KVTransferMs, + "decode_tpot", response.DecodeTPOT, + "total_ttft", response.TotalTTFT) + + return response, nil +} + +// GetPDSLOMetrics returns end-to-end P/D SLO compliance metrics +func GetPDSLOMetrics(reqCtx *handlers.RequestContext) map[string]interface{} { + if !reqCtx.PDMode { + return nil + } + + return map[string]interface{}{ + "pd_mode": reqCtx.PDMode, + "prefill_pod": reqCtx.PrefillPodName, + "decode_pod": reqCtx.DecodePodName, + "prefill_latency_ms": reqCtx.ActualPrefillLatency, + "kv_transfer_latency_ms": reqCtx.ActualKVTransferLatency, + "decode_tpot_ms": reqCtx.ActualDecodeTPOT, + "prefill_budget_ms": reqCtx.PrefillTTFTBudget, + "kv_transfer_budget_ms": reqCtx.KVTransferBudget, + "decode_tpot_budget_ms": reqCtx.DecodeTPOTBudget, + "remaining_ttft_budget_ms": reqCtx.RemainingTTFTBudget, + "remaining_tpot_budget_ms": reqCtx.RemainingTPOTBudget, + "slo_violation": reqCtx.PDSLOViolation, + "slo_violation_phase": reqCtx.PDSLOViolationPhase, + "total_ttft_ms": reqCtx.ActualPrefillLatency + 
reqCtx.ActualKVTransferLatency, + "predicted_prefill_ttft_ms": reqCtx.PredictedPrefillTTFT, + "predicted_kv_transfer_ms": reqCtx.PredictedKVTransferMs, + "predicted_decode_tpot_ms": reqCtx.PredictedDecodeTPOT, + "prefill_prediction_error_ms": reqCtx.ActualPrefillLatency - reqCtx.PredictedPrefillTTFT, + "kv_prediction_error_ms": reqCtx.ActualKVTransferLatency - reqCtx.PredictedKVTransferMs, + "decode_prediction_error_ms": reqCtx.ActualDecodeTPOT - reqCtx.PredictedDecodeTPOT, + } +} diff --git a/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go b/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go index cc57a9963..ae4a2b933 100644 --- a/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go +++ b/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go @@ -104,6 +104,42 @@ func (t *SLORequestTracker) PostResponse(ctx context.Context, reqCtx *handlers.R return } + // Detect PD disaggregation mode and allocate SLO budgets + if reqCtx.SchedulingResult != nil && !reqCtx.PDMode { + // Check if both prefill and decode profiles exist in results + _, hasPrefill := reqCtx.SchedulingResult.ProfileResults["prefill"] + _, hasDecode := reqCtx.SchedulingResult.ProfileResults["decode"] + + if hasPrefill && hasDecode { + // PD disaggregation is active + reqCtx.PDMode = true + + // Extract pod names from profile results + if prefillResult := reqCtx.SchedulingResult.ProfileResults["prefill"]; prefillResult != nil && len(prefillResult.TargetPods) > 0 { + reqCtx.PrefillPodName = prefillResult.TargetPods[0].GetPod().NamespacedName.Name + } + if decodeResult := reqCtx.SchedulingResult.ProfileResults["decode"]; decodeResult != nil && len(decodeResult.TargetPods) > 0 { + reqCtx.DecodePodName = decodeResult.TargetPods[0].GetPod().NamespacedName.Name + } + + // Allocate PD SLO budgets if SLOs are specified + if reqCtx.SchedulingRequest != nil { + ttftSLO := reqCtx.SchedulingRequest.TTFTSLO + tpotSLO := reqCtx.SchedulingRequest.AvgTPOTSLO + + if ttftSLO > 0 || tpotSLO > 0 { + logger.V(logutil.DEBUG).Info("Allocating PD SLO budgets", + "ttft_slo", ttftSLO, + "tpot_slo", tpotSLO, + "prefill_pod", reqCtx.PrefillPodName, + "decode_pod", reqCtx.DecodePodName) + + requestcontrol.AllocatePDSLOBudgets(ctx, reqCtx, ttftSLO, tpotSLO) + } + } + } + } + if err := requestcontrol.ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, reqCtx); err != nil { logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") } @@ -120,7 +156,19 @@ func (t *SLORequestTracker) PostResponseChunk(ctx context.Context, reqCtx *handl now := time.Now() if reqCtx.TTFT == 0 { + // First token received - this marks end of prefill phase requestcontrol.ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, reqCtx, now) + + // For PD disaggregation, track prefill and KV transfer phases + if reqCtx.PDMode { + logger.V(logutil.DEBUG).Info("First token received in PD mode, updating prefill and KV transfer phases") + + // Update prefill phase tracking + requestcontrol.UpdatePrefillPhase(ctx, reqCtx) + + // Update KV transfer phase tracking + requestcontrol.UpdateKVTransferPhase(ctx, reqCtx) + } } else { requestcontrol.ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, reqCtx, now) } @@ -135,6 +183,32 @@ func (t *SLORequestTracker) PostResponseComplete(ctx context.Context, reqCtx *ha return } + // For PD disaggregation, update decode phase tracking and check for SLO violations + if reqCtx.PDMode { + logger.V(logutil.DEBUG).Info("Response complete in PD mode, updating decode 
phase") + requestcontrol.UpdateDecodePhase(ctx, reqCtx, reqCtx.AvgTPOT) + + // Log PD SLO tracking summary + if reqCtx.PDSLOViolation { + logger.Info("PD SLO violation detected", + "violation_phase", reqCtx.PDSLOViolationPhase, + "prefill_latency_ms", reqCtx.ActualPrefillLatency, + "prefill_budget_ms", reqCtx.PrefillTTFTBudget, + "kv_transfer_latency_ms", reqCtx.ActualKVTransferLatency, + "kv_transfer_budget_ms", reqCtx.KVTransferBudget, + "decode_tpot_ms", reqCtx.ActualDecodeTPOT, + "decode_budget_ms", reqCtx.DecodeTPOTBudget) + } else { + logger.V(logutil.DEBUG).Info("PD SLO met successfully", + "prefill_latency_ms", reqCtx.ActualPrefillLatency, + "prefill_budget_ms", reqCtx.PrefillTTFTBudget, + "kv_transfer_latency_ms", reqCtx.ActualKVTransferLatency, + "kv_transfer_budget_ms", reqCtx.KVTransferBudget, + "decode_tpot_ms", reqCtx.ActualDecodeTPOT, + "decode_budget_ms", reqCtx.DecodeTPOTBudget) + } + } + if reqCtx.TTFT > 0 { logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", reqCtx.TTFT, "avgPredictedTTFT", reqCtx.PredictedTTFT) metrics.RecordRequestTTFT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.TTFT/1000)