diff --git a/Cargo.lock b/Cargo.lock index 59ac5fda0..7c10c2e96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2211,6 +2211,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" + [[package]] name = "documented" version = "0.9.2" @@ -2299,6 +2305,7 @@ dependencies = [ "or-panic", "parcelona", "pin-project", + "proxy-protocol", "ra-rpc", "ra-tls", "rand 0.8.5", @@ -5619,6 +5626,16 @@ dependencies = [ "prost 0.13.5", ] +[[package]] +name = "proxy-protocol" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e50c72c21c738f5c5f350cc33640aee30bf7cd20f9d9da20ed41bce2671d532" +dependencies = [ + "bytes", + "snafu", +] + [[package]] name = "prpc" version = "0.6.0" @@ -7239,6 +7256,27 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab12d3c261b2308b0d80c26fffb58d17eba81a4be97890101f416b478c79ca7" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1508efa03c362e23817f96cde18abed596a25219a8b2c66e8db33c03543d315b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.5.10" diff --git a/Cargo.toml b/Cargo.toml index d2d5a680d..ae5be4f96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -247,6 +247,7 @@ yaml-rust2 = "0.10.4" luks2 = "0.5.0" scopeguard = "1.2.0" tar = "0.4" +proxy-protocol = "0.5.0" [profile.release] panic = "abort" diff --git a/dstack-types/src/lib.rs b/dstack-types/src/lib.rs index 514e84fc1..82463cbf3 100644 --- a/dstack-types/src/lib.rs +++ b/dstack-types/src/lib.rs @@ -45,6 +45,18 @@ pub struct AppCompose { pub storage_fs: Option, #[serde(default, with = "human_size")] pub swap_size: u64, + /// Per-port attributes consumed by the gateway (e.g. PROXY protocol). + #[serde(default)] + pub ports: Vec, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct PortAttrs { + pub port: u16, + /// Whether the gateway should send a PROXY protocol header on outbound + /// connections to this port. + #[serde(default)] + pub pp: bool, } fn default_true() -> bool { diff --git a/dstack-util/src/system_setup.rs b/dstack-util/src/system_setup.rs index b7f59cfd1..d41d78bcd 100644 --- a/dstack-util/src/system_setup.rs +++ b/dstack-util/src/system_setup.rs @@ -53,7 +53,8 @@ use crate::{ use cert_client::CertRequestClient; use cmd_lib::run_fun as cmd; use dstack_gateway_rpc::{ - gateway_client::GatewayClient, RegisterCvmRequest, RegisterCvmResponse, WireGuardPeer, + gateway_client::GatewayClient, PortAttrs as RpcPortAttrs, PortAttrsList, RegisterCvmRequest, + RegisterCvmResponse, WireGuardPeer, }; use ra_tls::rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}; use serde_human_bytes as hex_bytes; @@ -446,11 +447,24 @@ impl<'a> GatewayContext<'a> { gateway_url: &str, key_store: &GatewayKeyStore, ) -> Result { + let port_attrs = PortAttrsList { + attrs: self + .shared + .app_compose + .ports + .iter() + .map(|p| RpcPortAttrs { + port: p.port as u32, + pp: p.pp, + }) + .collect(), + }; let client = self.create_gateway_client(gateway_url, &key_store.client_key, &key_store.client_cert)?; let result = client .register_cvm(RegisterCvmRequest { client_public_key: key_store.wg_pk.clone(), + port_attrs: Some(port_attrs.clone()), }) .await .context("Failed to register CVM"); @@ -471,6 +485,7 @@ impl<'a> GatewayContext<'a> { client .register_cvm(RegisterCvmRequest { client_public_key: key_store.wg_pk.clone(), + port_attrs: Some(port_attrs), }) .await .context("Failed to register CVM") diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 1f57ebf06..a1aefe17f 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -54,6 +54,7 @@ hyper-rustls.workspace = true http-body-util.workspace = true x509-parser.workspace = true jemallocator.workspace = true +proxy-protocol.workspace = true wavekv.workspace = true tdx-attest.workspace = true flate2.workspace = true diff --git a/gateway/dstack-app/builder/entrypoint.sh b/gateway/dstack-app/builder/entrypoint.sh index 915c289e6..29c02c964 100755 --- a/gateway/dstack-app/builder/entrypoint.sh +++ b/gateway/dstack-app/builder/entrypoint.sh @@ -111,6 +111,7 @@ localhost_enabled = false app_address_ns_compat = true workers = ${PROXY_WORKERS:-32} max_connections_per_app = ${MAX_CONNECTIONS_PER_APP:-0} +inbound_pp_enabled = ${INBOUND_PP_ENABLED:-false} [core.proxy.timeouts] connect = "${TIMEOUT_CONNECT:-5s}" @@ -122,6 +123,13 @@ idle = "${TIMEOUT_IDLE:-10m}" write = "${TIMEOUT_WRITE:-5s}" shutdown = "${TIMEOUT_SHUTDOWN:-5s}" total = "${TIMEOUT_TOTAL:-5h}" +pp_header = "${TIMEOUT_PP_HEADER:-5s}" + +[core.proxy.port_attrs_fetch] +timeout = "${PORT_ATTRS_FETCH_TIMEOUT:-10s}" +max_retries = ${PORT_ATTRS_FETCH_MAX_RETRIES:-5} +backoff_initial = "${PORT_ATTRS_FETCH_BACKOFF_INITIAL:-1s}" +backoff_max = "${PORT_ATTRS_FETCH_BACKOFF_MAX:-30s}" [core.recycle] enabled = true diff --git a/gateway/dstack-app/deploy-to-vmm.sh b/gateway/dstack-app/deploy-to-vmm.sh index 65a61a185..51ee1c16f 100755 --- a/gateway/dstack-app/deploy-to-vmm.sh +++ b/gateway/dstack-app/deploy-to-vmm.sh @@ -31,6 +31,7 @@ if [ -f ".env" ]; then # Load variables from .env echo "Loading environment variables from .env file..." set -a + # shellcheck disable=SC1091 source .env set +a else @@ -92,7 +93,14 @@ GUEST_AGENT_ADDR=127.0.0.1:9206 WG_ADDR=0.0.0.0:9202 # The token used to launch the App -APP_LAUNCH_TOKEN=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 32 | head -n 1) +APP_LAUNCH_TOKEN=$(tr -dc 'a-zA-Z0-9' < /dev/urandom | fold -w 32 | head -n 1) + +# PROXY protocol: read v1/v2 header from inbound connections (e.g. when this +# gateway sits behind a PP-aware L4 LB such as Cloudflare Spectrum or haproxy +# with send-proxy). Set to "true" only if the upstream LB is configured to +# send PROXY headers; otherwise leave disabled or every connection will be +# rejected. +# INBOUND_PP_ENABLED=false EOF echo "Please edit the .env file and set the required variables, then run this script again." @@ -125,7 +133,7 @@ done CLI="../../vmm/src/vmm-cli.py --url $VMM_RPC" -WG_PORT=$(echo $WG_ADDR | cut -d':' -f2) +WG_PORT=$(echo "$WG_ADDR" | cut -d':' -f2) COMPOSE_TMP=$(mktemp) cp docker-compose.yaml "$COMPOSE_TMP" @@ -175,6 +183,7 @@ APP_LAUNCH_TOKEN=$APP_LAUNCH_TOKEN RPC_DOMAIN=$RPC_DOMAIN NODE_ID=$NODE_ID PROXY_LISTEN_PORT=$PROXY_LISTEN_PORT +INBOUND_PP_ENABLED=${INBOUND_PP_ENABLED:-false} EOF if [ -n "$APP_COMPOSE_FILE" ]; then diff --git a/gateway/dstack-app/docker-compose.yaml b/gateway/dstack-app/docker-compose.yaml index 1518834d4..76cd6e827 100644 --- a/gateway/dstack-app/docker-compose.yaml +++ b/gateway/dstack-app/docker-compose.yaml @@ -41,6 +41,12 @@ services: - TIMEOUT_TOTAL=${TIMEOUT_TOTAL:-5h} - ADMIN_LISTEN_ADDR=${ADMIN_LISTEN_ADDR:-0.0.0.0} - ADMIN_LISTEN_PORT=${ADMIN_LISTEN_PORT:-8001} + - INBOUND_PP_ENABLED=${INBOUND_PP_ENABLED:-false} + - TIMEOUT_PP_HEADER=${TIMEOUT_PP_HEADER:-5s} + - PORT_ATTRS_FETCH_TIMEOUT=${PORT_ATTRS_FETCH_TIMEOUT:-10s} + - PORT_ATTRS_FETCH_MAX_RETRIES=${PORT_ATTRS_FETCH_MAX_RETRIES:-5} + - PORT_ATTRS_FETCH_BACKOFF_INITIAL=${PORT_ATTRS_FETCH_BACKOFF_INITIAL:-1s} + - PORT_ATTRS_FETCH_BACKOFF_MAX=${PORT_ATTRS_FETCH_BACKOFF_MAX:-30s} restart: always volumes: diff --git a/gateway/gateway.toml b/gateway/gateway.toml index 09c669e5d..2f9a18cb1 100644 --- a/gateway/gateway.toml +++ b/gateway/gateway.toml @@ -58,6 +58,18 @@ workers = 32 external_port = 443 # Maximum concurrent connections per app. 0 means unlimited. max_connections_per_app = 2000 +# Whether to read PROXY protocol from inbound connections (e.g. from Cloudflare). +inbound_pp_enabled = false + +[core.proxy.port_attrs_fetch] +# Background lazy-fetch of port_attrs from legacy CVM agents. +# Single Info() RPC timeout. +timeout = "10s" +# Retries cover the WireGuard / agent warmup window after registration. +max_retries = 5 +# Exponential backoff between retries; doubles each attempt up to backoff_max. +backoff_initial = "1s" +backoff_max = "30s" [core.proxy.timeouts] # Timeout for establishing a connection to the target app. @@ -81,6 +93,8 @@ write = "5s" shutdown = "5s" # Timeout for total connection duration. total = "5h" +# Timeout for proxy protocol header. +pp_header = "5s" [core.recycle] enabled = true diff --git a/gateway/rpc/proto/gateway_rpc.proto b/gateway/rpc/proto/gateway_rpc.proto index f85d7f877..da0e5e7c3 100644 --- a/gateway/rpc/proto/gateway_rpc.proto +++ b/gateway/rpc/proto/gateway_rpc.proto @@ -12,6 +12,25 @@ package gateway; message RegisterCvmRequest { // The public key of the WireGuard interface of the CVM. string client_public_key = 1; + // Per-port attributes the gateway should apply when proxying to this CVM. + // Wrapped in a message so we can distinguish "not reported" (old CVM → + // gateway falls back to fetching app-compose via Info()) from "reported + // empty" (new CVM with no special port behaviour). + optional PortAttrsList port_attrs = 2; +} + +// PortAttrsList wraps a list of PortAttrs so it can be optional on the wire. +message PortAttrsList { + repeated PortAttrs attrs = 1; +} + +// PortAttrs declares per-port behaviour for the gateway. +message PortAttrs { + // The CVM port these attributes apply to. + uint32 port = 1; + // Whether the gateway should send a PROXY protocol header on outbound + // connections to this port. + bool pp = 2; } // DebugRegisterCvmRequest is the request for DebugRegisterCvm (only works when debug_mode is enabled). diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 45b899c01..965c887e0 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -117,6 +117,32 @@ pub struct ProxyConfig { pub app_address_ns_compat: bool, /// Maximum concurrent connections per app. 0 means unlimited. pub max_connections_per_app: u64, + /// Port the dstack guest-agent listens on inside each CVM. Used by the + /// gateway to fetch app metadata (e.g. port_attrs for legacy CVMs). + pub agent_port: u16, + /// Whether to read PROXY protocol headers from inbound connections + /// (e.g. when behind a PP-aware load balancer like Cloudflare). + #[serde(default)] + pub inbound_pp_enabled: bool, + /// Background lazy-fetch behaviour for `port_attrs` (legacy CVMs). + pub port_attrs_fetch: PortAttrsFetchConfig, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PortAttrsFetchConfig { + /// Timeout for a single `Info()` RPC attempt. + #[serde(with = "serde_duration")] + pub timeout: Duration, + /// Maximum number of attempts after the initial try (0 = no retry). + /// Retries cover the window where a freshly-registered CVM hasn't + /// finished its WireGuard handshake yet. + pub max_retries: u32, + /// Delay before the first retry; doubles on each subsequent retry, + /// capped at `backoff_max`. + #[serde(with = "serde_duration")] + pub backoff_initial: Duration, + #[serde(with = "serde_duration")] + pub backoff_max: Duration, } #[derive(Debug, Clone, Deserialize)] @@ -142,6 +168,9 @@ pub struct Timeouts { pub write: Duration, #[serde(with = "serde_duration")] pub shutdown: Duration, + /// Timeout for reading the proxy protocol header from inbound connections. + #[serde(with = "serde_duration")] + pub pp_header: Duration, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/gateway/src/debug_service.rs b/gateway/src/debug_service.rs index 137ebd739..b00176a14 100644 --- a/gateway/src/debug_service.rs +++ b/gateway/src/debug_service.rs @@ -35,6 +35,8 @@ impl DebugRpc for DebugRpcHandler { &request.app_id, &request.instance_id, &request.client_public_key, + "", + None, ) } diff --git a/gateway/src/kv/mod.rs b/gateway/src/kv/mod.rs index 97b195c96..6994b18ee 100644 --- a/gateway/src/kv/mod.rs +++ b/gateway/src/kv/mod.rs @@ -42,6 +42,14 @@ use serde::{Deserialize, Serialize}; use tokio::sync::watch; use wavekv::{node::NodeState, types::NodeId, Node}; +/// Per-port flags applied by the gateway when proxying to a CVM port. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct PortFlags { + /// Send a PROXY protocol header on outbound connections to this port. + #[serde(default)] + pub pp: bool, +} + /// Instance core data (persistent) #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct InstanceData { @@ -49,6 +57,16 @@ pub struct InstanceData { pub ip: Ipv4Addr, pub public_key: String, pub reg_time: u64, + /// Per-port flags reported at registration. `None` means "not reported" + /// (legacy CVM); the gateway will fall back to fetching app-compose via + /// Info() on first connection and populate this lazily. + #[serde(default)] + pub port_attrs: Option>, + /// Hex-encoded compose_hash that `port_attrs` was learned against. + /// When a re-registration presents a different compose_hash (app upgrade), + /// the cache is invalidated and re-fetched lazily. + #[serde(default)] + pub port_attrs_hash: String, } /// Gateway node status (stored separately for independent updates) diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 52ceb0e4b..1349d9667 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -32,6 +32,7 @@ mod distributed_certbot; mod kv; mod main_service; mod models; +mod pp; mod proxy; mod web_routes; diff --git a/gateway/src/main_service.rs b/gateway/src/main_service.rs index 7e733ec6d..55c76554d 100644 --- a/gateway/src/main_service.rs +++ b/gateway/src/main_service.rs @@ -28,7 +28,10 @@ use rinja::Template as _; use safe_write::safe_write; use serde::{Deserialize, Serialize}; use smallvec::{smallvec, SmallVec}; -use tokio::sync::Notify; +use tokio::sync::{ + mpsc::{unbounded_channel, UnboundedSender}, + Notify, +}; use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, warn}; @@ -37,7 +40,7 @@ use crate::{ config::{Config, TlsConfig}, kv::{ fetch_peers_from_bootnode, AppIdValidator, HttpsClientConfig, InstanceData, KvStore, - NodeData, NodeStatus, WaveKvSyncService, + NodeData, NodeStatus, PortFlags, WaveKvSyncService, }, models::{InstanceInfo, WgConf}, proxy::{create_acceptor_with_cert_resolver, AddressGroup, AddressInfo}, @@ -75,6 +78,10 @@ pub struct ProxyInner { pub(crate) wavekv_sync: Option>, /// HTTPS client config for mTLS (used for bootnode peer discovery) https_config: Option, + /// Sender for the background port_attrs lazy-fetch worker. The proxy fast + /// path enqueues unknown instance_ids and immediately returns `pp=false` + /// so a missing cache never blocks a connection. + pub(crate) port_attrs_tx: UnboundedSender, } #[derive(Debug, Serialize, Deserialize, Default)] @@ -103,9 +110,13 @@ pub struct ProxyOptions { impl Proxy { pub async fn new(options: ProxyOptions) -> Result { - Ok(Self { - _inner: Arc::new(ProxyInner::new(options).await?), - }) + let (port_attrs_tx, port_attrs_rx) = unbounded_channel(); + let inner = ProxyInner::new(options, port_attrs_tx).await?; + let proxy = Self { + _inner: Arc::new(inner), + }; + crate::proxy::port_attrs::spawn_fetcher(proxy.clone(), port_attrs_rx); + Ok(proxy) } } @@ -114,7 +125,10 @@ impl ProxyInner { self.state.lock().or_panic("Failed to lock AppState") } - pub async fn new(options: ProxyOptions) -> Result { + pub async fn new( + options: ProxyOptions, + port_attrs_tx: UnboundedSender, + ) -> Result { let ProxyOptions { config, my_app_id, @@ -270,6 +284,7 @@ impl ProxyInner { kv_store, wavekv_sync, https_config: Some(https_config), + port_attrs_tx, }) } @@ -378,12 +393,20 @@ impl Proxy { }) } - /// Register a CVM with the given app_id, instance_id and client_public_key + /// Register a CVM with the given app_id, instance_id and client_public_key. + /// + /// `port_attrs = None` means the CVM didn't report port attributes (legacy + /// CVM). The gateway will lazily fetch them via Info() on first connection. + /// + /// `compose_hash` is the attested compose_hash — used to invalidate any + /// cached `port_attrs` when the app is upgraded. pub fn do_register_cvm( &self, app_id: &str, instance_id: &str, client_public_key: &str, + compose_hash: &str, + port_attrs: Option>, ) -> Result { let mut state = self.lock(); @@ -403,11 +426,23 @@ impl Proxy { bail!("[{instance_id}] client public key is empty"); } let client_info = state - .new_client_by_id(instance_id, app_id, client_public_key) + .new_client_by_id( + instance_id, + app_id, + client_public_key, + compose_hash, + port_attrs, + ) .context("failed to allocate IP address for client")?; if let Err(err) = state.reconfigure() { error!("failed to reconfigure: {err:?}"); } + // Capture the prewarm decision before continuing under the lock. + // If the instance arrived without port_attrs (legacy CVM, or + // compose_hash mismatch invalidated the cache), enqueue a + // background fetch so the first proxied connection isn't the one + // that triggers it. The fetcher dedupes, so this is safe. + let needs_prewarm = client_info.port_attrs.is_none(); let gateways = state.get_active_nodes(); let servers = gateways .iter() @@ -425,12 +460,16 @@ impl Proxy { }), agent: Some(GuestAgentConfig { external_port: port.into(), - internal_port: 8090, + internal_port: state.config.proxy.agent_port.into(), domain: base_domain, app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), }), gateways, }; + drop(state); + if needs_prewarm { + let _ = self.port_attrs_tx.send(instance_id.to_string()); + } self.notify_state_updated.notify_one(); Ok(response) } @@ -449,6 +488,8 @@ fn build_state_from_kv_store(instances: BTreeMap) -> Proxy reg_time: UNIX_EPOCH .checked_add(Duration::from_secs(data.reg_time)) .unwrap_or(UNIX_EPOCH), + port_attrs: data.port_attrs, + port_attrs_hash: data.port_attrs_hash, connections: Default::default(), }; state.allocated_addresses.insert(data.ip); @@ -742,6 +783,8 @@ fn reload_instances_from_kv_store(proxy: &Proxy, store: &KvStore) -> Result<()> reg_time: UNIX_EPOCH .checked_add(Duration::from_secs(data.reg_time)) .unwrap_or(UNIX_EPOCH), + port_attrs: data.port_attrs.clone(), + port_attrs_hash: data.port_attrs_hash.clone(), connections: Default::default(), }; @@ -823,6 +866,8 @@ impl ProxyState { id: &str, app_id: &str, public_key: &str, + compose_hash: &str, + port_attrs: Option>, ) -> Result { if id.is_empty() { bail!("instance_id is empty (no_instance_id is set?)"); @@ -841,6 +886,23 @@ impl ProxyState { // Update reg_time so other nodes will pick up the change existing.reg_time = SystemTime::now(); } + // App upgrade detection: a different attested compose_hash invalidates + // any cached port_attrs from the previous code. + if existing.port_attrs_hash != compose_hash { + info!( + "compose_hash changed for instance {id} ({} -> {compose_hash}), \ + invalidating cached port_attrs", + existing.port_attrs_hash + ); + existing.port_attrs = None; + existing.port_attrs_hash = compose_hash.to_string(); + } + // Only override cached port_attrs when the caller actually reports + // them. A `None` request (legacy CVM) means "I don't know" — let + // the lazy fetch path run again. + if port_attrs.is_some() { + existing.port_attrs = port_attrs.clone(); + } let existing = existing.clone(); if self.valid_ip(existing.ip) { // Sync existing instance to KvStore (might be from legacy state) @@ -849,6 +911,8 @@ impl ProxyState { ip: existing.ip, public_key: existing.public_key.clone(), reg_time: encode_ts(existing.reg_time), + port_attrs: existing.port_attrs.clone(), + port_attrs_hash: existing.port_attrs_hash.clone(), }; if let Err(err) = self.kv_store.sync_instance(&existing.id, &data) { error!("failed to sync existing instance to KvStore: {err:?}"); @@ -867,12 +931,52 @@ impl ProxyState { ip, public_key: public_key.to_string(), reg_time: SystemTime::now(), + port_attrs, + port_attrs_hash: compose_hash.to_string(), connections: Default::default(), }; self.add_instance(host_info.clone()); Ok(host_info) } + /// Lookup an instance's IP. Returns `None` if the instance is unknown. + pub(crate) fn instance_ip(&self, instance_id: &str) -> Option { + self.state.instances.get(instance_id).map(|i| i.ip) + } + + /// Lookup an instance's port_attrs. `None` means the CVM never reported + /// them (legacy CVM), so the caller should fall back to fetching via Info(). + pub(crate) fn instance_port_attrs( + &self, + instance_id: &str, + ) -> Option<&BTreeMap> { + self.state.instances.get(instance_id)?.port_attrs.as_ref() + } + + /// Update an instance's port_attrs (used after a lazy fetch via Info()). + /// Persists to the WaveKV store so other gateway nodes pick it up. + pub(crate) fn update_instance_port_attrs( + &mut self, + instance_id: &str, + attrs: BTreeMap, + ) { + let Some(info) = self.state.instances.get_mut(instance_id) else { + return; + }; + info.port_attrs = Some(attrs.clone()); + let data = InstanceData { + app_id: info.app_id.clone(), + ip: info.ip, + public_key: info.public_key.clone(), + reg_time: encode_ts(info.reg_time), + port_attrs: Some(attrs), + port_attrs_hash: info.port_attrs_hash.clone(), + }; + if let Err(err) = self.kv_store.sync_instance(instance_id, &data) { + error!("failed to sync updated port_attrs to KvStore: {err:?}"); + } + } + fn add_instance(&mut self, info: InstanceInfo) { // Sync to KvStore let data = InstanceData { @@ -880,6 +984,8 @@ impl ProxyState { ip: info.ip, public_key: info.public_key.clone(), reg_time: encode_ts(info.reg_time), + port_attrs: info.port_attrs.clone(), + port_attrs_hash: info.port_attrs_hash.clone(), }; if let Err(err) = self.kv_store.sync_instance(&info.id, &data) { error!("failed to sync instance to KvStore: {err:?}"); @@ -921,6 +1027,7 @@ impl ProxyState { return Ok(smallvec![AddressInfo { ip: Ipv4Addr::new(127, 0, 0, 1), counter: Default::default(), + instance_id: "localhost".to_string(), }]); } let n = self.config.proxy.connect_top_n; @@ -928,6 +1035,7 @@ impl ProxyState { return Ok(smallvec![AddressInfo { ip: instance.ip, counter: instance.connections.clone(), + instance_id: instance.id.clone(), }]); }; let app_instances = self.state.apps.get(id).context("app not found")?; @@ -955,7 +1063,12 @@ impl ProxyState { .filter_map(|instance_id| { let instance = self.state.instances.get(instance_id)?; let (_, elapsed) = handshakes.get(&instance.public_key)?; - Some((instance.ip, *elapsed, instance.connections.clone())) + Some(( + instance.ip, + *elapsed, + instance.connections.clone(), + instance.id.clone(), + )) }) .collect::>(), }; @@ -963,7 +1076,11 @@ impl ProxyState { instances.truncate(n); Ok(instances .into_iter() - .map(|(ip, _, counter)| AddressInfo { ip, counter }) + .map(|(ip, _, counter, instance_id)| AddressInfo { + ip, + counter, + instance_id, + }) .collect()) } @@ -973,6 +1090,7 @@ impl ProxyState { return Some(smallvec![AddressInfo { ip: info.ip, counter: info.connections.clone(), + instance_id: info.id.clone(), }]); } @@ -999,6 +1117,7 @@ impl ProxyState { smallvec![AddressInfo { ip: info.ip, counter: info.connections.clone(), + instance_id: info.id.clone(), }] }) } @@ -1271,8 +1390,25 @@ impl GatewayRpc for RpcHandler { .context("App authorization failed")?; let app_id = hex::encode(&app_info.app_id); let instance_id = hex::encode(&app_info.instance_id); - self.state - .do_register_cvm(&app_id, &instance_id, &request.client_public_key) + let compose_hash = hex::encode(&app_info.compose_hash); + let port_attrs = request.port_attrs.map(|list| { + list.attrs + .into_iter() + .filter_map(|p| { + // Wire format is uint32 to avoid varint shenanigans, but valid TCP + // ports fit in u16. Drop out-of-range entries instead of truncating. + let port = u16::try_from(p.port).ok()?; + Some((port, PortFlags { pp: p.pp })) + }) + .collect::>() + }); + self.state.do_register_cvm( + &app_id, + &instance_id, + &request.client_public_key, + &compose_hash, + port_attrs, + ) } async fn acme_info(self) -> Result { diff --git a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap index f211b458a..d82452413 100644 --- a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap @@ -12,5 +12,7 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, + port_attrs: None, + port_attrs_hash: "", connections: 0, } diff --git a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap index 5b07304c0..a664c42db 100644 --- a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap @@ -12,5 +12,7 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, + port_attrs: None, + port_attrs_hash: "", connections: 0, } diff --git a/gateway/src/main_service/tests.rs b/gateway/src/main_service/tests.rs index 1a43b1540..1457a64de 100644 --- a/gateway/src/main_service/tests.rs +++ b/gateway/src/main_service/tests.rs @@ -55,14 +55,14 @@ async fn test_config() { let state = create_test_state().await; let mut info = state .lock() - .new_client_by_id("test-id-0", "app-id-0", "test-pubkey-0") + .new_client_by_id("test-id-0", "app-id-0", "test-pubkey-0", "", None) .unwrap(); info.reg_time = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info); let mut info1 = state .lock() - .new_client_by_id("test-id-1", "app-id-1", "test-pubkey-1") + .new_client_by_id("test-id-1", "app-id-1", "test-pubkey-1", "", None) .unwrap(); info1.reg_time = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info1); diff --git a/gateway/src/models.rs b/gateway/src/models.rs index 37caa274e..4bb06e3f7 100644 --- a/gateway/src/models.rs +++ b/gateway/src/models.rs @@ -15,6 +15,8 @@ use std::{ time::SystemTime, }; +use crate::kv::PortFlags; + mod filters { pub fn hex(data: impl AsRef<[u8]>) -> rinja::Result { Ok(hex::encode(data)) @@ -60,6 +62,14 @@ pub struct InstanceInfo { pub ip: Ipv4Addr, pub public_key: String, pub reg_time: SystemTime, + /// Per-port flags. `None` means the CVM didn't report any (legacy); + /// gateway will lazily populate via Info() on first proxied connection. + #[serde(default)] + pub port_attrs: Option>, + /// Hex-encoded compose_hash that `port_attrs` was learned against. The + /// cache is invalidated when a new registration presents a different hash. + #[serde(default)] + pub port_attrs_hash: String, #[serde(skip)] pub connections: Arc, } diff --git a/gateway/src/pp.rs b/gateway/src/pp.rs new file mode 100644 index 000000000..f6c6e09f2 --- /dev/null +++ b/gateway/src/pp.rs @@ -0,0 +1,305 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use std::net::SocketAddr; + +use anyhow::{bail, Context, Result}; +use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader}; +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + net::TcpStream, +}; + +use crate::config::ProxyConfig; + +const V1_PROTOCOL_PREFIX: &str = "PROXY"; +const V1_PREFIX_LEN: usize = 5; +const V1_MAX_LENGTH: usize = 107; +const V1_TERMINATOR: &[u8] = b"\r\n"; + +const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n"; +const V2_PREFIX_LEN: usize = 12; +const V2_MINIMUM_LEN: usize = 16; +const V2_LENGTH_INDEX: usize = 14; +const READ_BUFFER_LEN: usize = 512; +const V2_MAX_LENGTH: usize = 2048; + +/// Read or synthesize the inbound proxy protocol header. +/// +/// When `inbound_pp_enabled` is true, reads a PP header from the stream (e.g. from an upstream +/// load balancer). When false, synthesizes one from the TCP peer address. +pub(crate) async fn get_inbound_pp_header( + inbound: TcpStream, + config: &ProxyConfig, +) -> Result<(TcpStream, ProxyHeader)> { + if config.inbound_pp_enabled { + read_proxy_header(inbound).await + } else { + let header = create_inbound_pp_header(&inbound); + Ok((inbound, header)) + } +} + +pub struct DisplayAddr<'a>(pub &'a ProxyHeader); + +impl std::fmt::Display for DisplayAddr<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ProxyHeader::Version2 { addresses, .. } => match addresses { + v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Unix { .. } => write!(f, ""), + v2::ProxyAddresses::Unspec => write!(f, ""), + }, + ProxyHeader::Version1 { addresses, .. } => match addresses { + v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Unknown => write!(f, ""), + }, + _ => write!(f, ""), + } + } +} + +fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader { + let peer_addr = inbound.peer_addr().ok(); + let local_addr = inbound.local_addr().ok(); + + match (peer_addr, local_addr) { + (Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv4 { + source, + destination, + }, + } + } + (Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv6 { + source, + destination, + }, + } + } + _ => ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Unspec, + }, + } +} + +async fn read_proxy_header(mut stream: I) -> Result<(I, ProxyHeader)> +where + I: AsyncRead + Unpin, +{ + let mut buffer = [0; READ_BUFFER_LEN]; + let mut dynamic_buffer = None; + + stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; + + if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() { + read_v1_header(&mut stream, &mut buffer).await?; + } else { + stream + .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) + .await?; + if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX { + dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; + } else { + bail!("no valid proxy protocol header detected"); + } + } + + let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); + + let header = + proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?; + Ok((stream, header)) +} + +async fn read_v2_header( + mut stream: I, + buffer: &mut [u8; READ_BUFFER_LEN], +) -> Result>> +where + I: AsyncRead + Unpin, +{ + let length = + u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; + let full_length = V2_MINIMUM_LEN + length; + + if full_length > V2_MAX_LENGTH { + bail!("v2 proxy protocol header is too long"); + } + + if full_length > READ_BUFFER_LEN { + let mut dynamic_buffer = Vec::with_capacity(full_length); + dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); + dynamic_buffer.resize(full_length, 0); + stream + .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(Some(dynamic_buffer)) + } else { + stream + .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(None) + } +} + +async fn read_v1_header(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()> +where + I: AsyncRead + Unpin, +{ + let mut end_found = false; + for i in V1_PREFIX_LEN..V1_MAX_LENGTH { + buffer[i] = stream.read_u8().await?; + + if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { + end_found = true; + break; + } + } + if !end_found { + bail!("no valid proxy protocol header detected (v1 terminator not found)"); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader}; + + fn extract_v4(header: ProxyHeader) -> (std::net::SocketAddrV4, std::net::SocketAddrV4) { + match header { + ProxyHeader::Version1 { + addresses: + v1::ProxyAddresses::Ipv4 { + source, + destination, + }, + .. + } => (source, destination), + ProxyHeader::Version2 { + addresses: + v2::ProxyAddresses::Ipv4 { + source, + destination, + }, + .. + } => (source, destination), + other => panic!("expected ipv4 header, got {other:?}"), + } + } + + #[tokio::test] + async fn parses_v1_ipv4() { + // v1 is ASCII: "PROXY TCP4 \r\n" + let header = b"PROXY TCP4 1.2.3.4 5.6.7.8 11111 22222\r\n"; + let (_stream, parsed) = read_proxy_header(&header[..]).await.expect("v1 parse"); + let (src, dst) = extract_v4(parsed); + assert_eq!(src.ip().octets(), [1, 2, 3, 4]); + assert_eq!(src.port(), 11111); + assert_eq!(dst.ip().octets(), [5, 6, 7, 8]); + assert_eq!(dst.port(), 22222); + } + + #[tokio::test] + async fn parses_v2_ipv4() { + // v2 magic + ver/cmd 0x21 + family/proto 0x11 (TCP/IPv4) + len 12 + let mut header = Vec::new(); + header.extend_from_slice(V2_PROTOCOL_PREFIX); + header.extend_from_slice(&[0x21, 0x11, 0x00, 0x0c]); + header.extend_from_slice(&[1, 2, 3, 4]); // src ip + header.extend_from_slice(&[5, 6, 7, 8]); // dst ip + header.extend_from_slice(&11111u16.to_be_bytes()); // src port + header.extend_from_slice(&22222u16.to_be_bytes()); // dst port + + let (_stream, parsed) = read_proxy_header(&header[..]).await.expect("v2 parse"); + let (src, dst) = extract_v4(parsed); + assert_eq!(src.ip().octets(), [1, 2, 3, 4]); + assert_eq!(src.port(), 11111); + assert_eq!(dst.ip().octets(), [5, 6, 7, 8]); + assert_eq!(dst.port(), 22222); + } + + #[tokio::test] + async fn rejects_no_prefix() { + // Looks neither like v1 ("PROXY") nor v2 magic. + let bytes = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let err = read_proxy_header(&bytes[..]).await.unwrap_err(); + assert!( + format!("{err:#}").contains("no valid proxy protocol header"), + "unexpected error: {err:#}" + ); + } + + #[tokio::test] + async fn rejects_v1_without_terminator() { + // PROXY prefix matched but no \r\n terminator within V1_MAX_LENGTH bytes. + let bytes = vec![b'P'; V1_MAX_LENGTH + 8]; // all 'P' — never closes + let mut head = b"PROXY".to_vec(); + head.extend(std::iter::repeat(b'A').take(V1_MAX_LENGTH)); + let err = read_proxy_header(&head[..]).await.unwrap_err(); + let msg = format!("{err:#}"); + assert!( + msg.contains("v1 terminator not found") || msg.contains("no valid proxy"), + "unexpected error: {msg}" + ); + // Sanity: the longer no-terminator buffer would also fail (read past) + let _ = bytes; + } + + #[tokio::test] + async fn rejects_v2_oversize_length() { + // v2 prefix with a length field exceeding V2_MAX_LENGTH. + let mut header = Vec::new(); + header.extend_from_slice(V2_PROTOCOL_PREFIX); + header.extend_from_slice(&[0x21, 0x11]); + // length = V2_MAX_LENGTH bytes -> total = MIN + that, blows the cap + header.extend_from_slice(&(V2_MAX_LENGTH as u16).to_be_bytes()); + let err = read_proxy_header(&header[..]).await.unwrap_err(); + assert!( + format!("{err:#}").contains("too long"), + "unexpected error: {err:#}" + ); + } + + #[test] + fn synthesizes_unspec_when_no_addrs() { + // We can't construct a real TcpStream in a unit test cheaply; just + // assert the helper returns Unspec for the all-None branch by going + // through the public Display impl. + let header = ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Unspec, + }; + assert_eq!(format!("{}", DisplayAddr(&header)), ""); + } + + #[test] + fn display_v2_ipv4_source() { + let header = ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv4 { + source: "9.8.7.6:1234".parse().unwrap(), + destination: "1.2.3.4:80".parse().unwrap(), + }, + }; + assert_eq!(format!("{}", DisplayAddr(&header)), "9.8.7.6:1234"); + } +} diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index dd39d0ac9..e3a26f310 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -23,17 +23,26 @@ use tokio::{ }; use tracing::{debug, error, info, info_span, Instrument}; -use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter}; +use crate::{ + config::ProxyConfig, + main_service::Proxy, + models::EnteredCounter, + pp::{get_inbound_pp_header, DisplayAddr}, +}; #[derive(Debug, Clone)] pub(crate) struct AddressInfo { pub ip: Ipv4Addr, pub counter: Arc, + /// Instance id this address belongs to. Used to look up per-instance state + /// (e.g. port_attrs) after the racing connect picks a winner. + pub instance_id: String, } pub(crate) type AddressGroup = smallvec::SmallVec<[AddressInfo; 4]>; mod io_bridge; +pub(crate) mod port_attrs; mod sni; mod tls_passthough; mod tls_terminate; @@ -123,8 +132,16 @@ fn parse_dst_info(subdomain: &str) -> Result { pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); -async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { +async fn handle_connection(inbound: TcpStream, state: Proxy) -> Result<()> { let timeouts = &state.config.proxy.timeouts; + + let pp_fut = get_inbound_pp_header(inbound, &state.config.proxy); + let (mut inbound, pp_header) = timeout(timeouts.pp_header, pp_fut) + .await + .context("proxy protocol header timeout")? + .context("failed to read proxy protocol header")?; + info!("client address: {}", DisplayAddr(&pp_header)); + let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound)) .await .context("take sni timeout")? @@ -138,14 +155,15 @@ async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { let dst = parse_dst_info(subdomain)?; debug!("dst: {dst:?}"); if dst.is_tls { - tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await + tls_passthough::proxy_to_app(state, inbound, pp_header, buffer, &dst.app_id, dst.port) + .await } else { state - .proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2) + .proxy(inbound, pp_header, buffer, &dst.app_id, dst.port, dst.is_h2) .await } } else { - tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await + tls_passthough::proxy_with_sni(state, inbound, pp_header, buffer, &sni).await } } diff --git a/gateway/src/proxy/port_attrs.rs b/gateway/src/proxy/port_attrs.rs new file mode 100644 index 000000000..db21dd331 --- /dev/null +++ b/gateway/src/proxy/port_attrs.rs @@ -0,0 +1,178 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Per-port attribute lookup with background lazy fetch from legacy CVMs. +//! +//! Two paths: +//! +//! - Fast path (`should_send_pp`): a synchronous, non-blocking lookup. On a +//! cache miss it enqueues the instance for the background worker and +//! optimistically returns `pp = false` so the connection isn't blocked. +//! - Slow path ([`spawn_fetcher`]): a single background task that drains the +//! queue, dedupes in-flight instances, calls the agent's `Info()` RPC with +//! a timeout, and writes the result back to WaveKV. + +use std::collections::{BTreeMap, HashSet}; +use std::net::Ipv4Addr; +use std::sync::{Arc, Mutex}; + +use anyhow::{Context, Result}; +use dstack_guest_agent_rpc::dstack_guest_client::DstackGuestClient; +use dstack_types::AppCompose; +use http_client::prpc::PrpcClient; +use or_panic::ResultOrPanic; +use tokio::sync::mpsc::UnboundedReceiver; +use tracing::{debug, warn}; + +use crate::{kv::PortFlags, main_service::Proxy}; + +/// Outcome of a single fetch attempt, distinguishing what we can usefully retry. +enum FetchError { + /// Transient: connection failed, RPC timed out, agent returned 5xx, etc. + /// The CVM might just be warming up — retry with backoff. + Transient(anyhow::Error), + /// Permanent: instance is gone, or the CVM responded with data we can't + /// parse (malformed tcb_info, schema mismatch). Retrying won't help. + Permanent(anyhow::Error), +} + +/// Decide whether the gateway should send a PROXY protocol header on the +/// outbound connection to (`instance_id`, `port`). +/// +/// Cache hit returns the declared value. Cache miss returns `false` and asks +/// the background worker to populate the cache for the next request — this +/// keeps the data path off the critical Info() RPC. +pub(crate) fn should_send_pp(state: &Proxy, instance_id: &str, port: u16) -> bool { + if let Some(attrs) = state.lock().instance_port_attrs(instance_id) { + return attrs.get(&port).map(|f| f.pp).unwrap_or(false); + } + // Best-effort enqueue. If the channel is closed (shutdown) or the worker + // is gone, silently drop — `false` is the conservative default anyway. + let _ = state.port_attrs_tx.send(instance_id.to_string()); + false +} + +/// Spawn the background lazy-fetch worker. Should be called once at startup. +pub(crate) fn spawn_fetcher(state: Proxy, mut rx: UnboundedReceiver) { + let in_flight: Arc>> = Default::default(); + tokio::spawn(async move { + while let Some(instance_id) = rx.recv().await { + // Dedupe: only one fetch per instance at a time. The entry is + // removed once the retry loop terminates (success, exhausted, + // or unknown-instance), so a later registration with a new + // compose_hash can re-trigger via the same path. + { + let mut in_flight = in_flight.lock().or_panic("port_attrs in_flight poisoned"); + if !in_flight.insert(instance_id.clone()) { + continue; + } + } + let state = state.clone(); + let in_flight = in_flight.clone(); + let id = instance_id.clone(); + tokio::spawn(async move { + fetch_with_retry(&state, &id).await; + in_flight + .lock() + .or_panic("port_attrs in_flight poisoned") + .remove(&id); + }); + } + }); +} + +async fn fetch_with_retry(state: &Proxy, instance_id: &str) { + let cfg = &state.config.proxy.port_attrs_fetch; + let mut attempt = 0u32; + let mut backoff = cfg.backoff_initial; + loop { + let result = + match tokio::time::timeout(cfg.timeout, fetch_and_store(state, instance_id)).await { + Ok(r) => r, + // The Info() RPC took too long. Treat as transient — the CVM + // may just be slow to come up. + Err(_) => Err(FetchError::Transient(anyhow::anyhow!( + "Info() rpc timed out after {:?}", + cfg.timeout + ))), + }; + match result { + Ok(()) => { + debug!("port_attrs cached for instance {instance_id} (attempt {attempt})"); + return; + } + Err(FetchError::Permanent(err)) => { + // Either the instance was recycled while queued, or the + // agent responded with data we can't parse. Retrying won't + // change either; bail. + debug!("port_attrs fetch for {instance_id}: permanent failure: {err:#}"); + return; + } + Err(FetchError::Transient(err)) => { + warn!("port_attrs fetch for {instance_id} failed (attempt {attempt}): {err:#}"); + } + } + if attempt >= cfg.max_retries { + warn!( + "port_attrs fetch for {instance_id} giving up after {} attempts", + attempt + 1 + ); + return; + } + tokio::time::sleep(backoff).await; + attempt += 1; + backoff = (backoff * 2).min(cfg.backoff_max); + } +} + +async fn fetch_and_store(state: &Proxy, instance_id: &str) -> Result<(), FetchError> { + let (ip, agent_port) = { + let guard = state.lock(); + let ip = guard + .instance_ip(instance_id) + // Instance was recycled — never coming back under this id. + .ok_or_else(|| FetchError::Permanent(anyhow::anyhow!("unknown instance")))?; + (ip, guard.config.proxy.agent_port) + }; + let attrs = fetch_port_attrs(ip, agent_port).await?; + state.lock().update_instance_port_attrs(instance_id, attrs); + Ok(()) +} + +async fn fetch_port_attrs( + ip: Ipv4Addr, + agent_port: u16, +) -> Result, FetchError> { + let url = format!("http://{ip}:{agent_port}/prpc"); + let client = DstackGuestClient::new(PrpcClient::new(url)); + // Network/RPC errors here are transient: agent might still be coming up. + let info = client + .info() + .await + .context("agent Info() rpc failed") + .map_err(FetchError::Transient)?; + + // Anything below this point is the agent telling us something we can't + // turn into port_attrs — treat as permanent. + if info.tcb_info.is_empty() { + // Legacy CVM with public_tcbinfo=false; we cannot inspect app-compose + // remotely. Cache an empty map so we don't keep retrying. + return Ok(BTreeMap::new()); + } + let tcb: serde_json::Value = serde_json::from_str(&info.tcb_info) + .context("invalid tcb_info json") + .map_err(FetchError::Permanent)?; + let raw = tcb + .get("app_compose") + .and_then(|v| v.as_str()) + .ok_or_else(|| FetchError::Permanent(anyhow::anyhow!("tcb_info missing app_compose")))?; + let app_compose: AppCompose = serde_json::from_str(raw) + .context("failed to parse app_compose from tcb_info") + .map_err(FetchError::Permanent)?; + Ok(app_compose + .ports + .into_iter() + .map(|p| (p.port, PortFlags { pp: p.pp })) + .collect()) +} diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index 57bb3830c..04a02994e 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -6,6 +6,7 @@ use std::fmt::Debug; use std::sync::atomic::Ordering; use anyhow::{bail, Context, Result}; +use proxy_protocol::ProxyHeader; use tokio::{io::AsyncWriteExt, net::TcpStream, task::JoinSet, time::timeout}; use tracing::{debug, info, warn}; @@ -14,7 +15,7 @@ use crate::{ models::{Counting, EnteredCounter}, }; -use super::{io_bridge::bridge, AddressGroup}; +use super::{io_bridge::bridge, port_attrs::should_send_pp, AddressGroup}; #[derive(Debug)] struct AppAddress { @@ -96,6 +97,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result, sni: &str, ) -> Result<()> { @@ -107,7 +109,7 @@ pub(crate) async fn proxy_with_sni( .with_context(|| format!("DNS TXT resolve timeout for {sni}"))? .with_context(|| format!("failed to resolve app address for {sni}"))?; debug!("target address is {}:{}", addr.app_id, addr.port); - proxy_to_app(state, inbound, buffer, &addr.app_id, addr.port).await + proxy_to_app(state, inbound, pp_header, buffer, &addr.app_id, addr.port).await } /// Check if app has reached max connections limit @@ -134,56 +136,70 @@ fn check_connection_limit( } /// connect to multiple hosts simultaneously and return the first successful connection +/// along with the instance_id of the winning address. pub(crate) async fn connect_multiple_hosts( addresses: AddressGroup, port: u16, max_connections: u64, app_id: &str, -) -> Result<(TcpStream, EnteredCounter)> { +) -> Result<(TcpStream, EnteredCounter, String)> { check_connection_limit(&addresses, max_connections, app_id)?; let mut join_set = JoinSet::new(); for addr in addresses { let counter = addr.counter.enter(); - let addr = addr.ip; - debug!("connecting to {addr}:{port}"); - let future = TcpStream::connect((addr, port)); - join_set.spawn(async move { (future.await.map_err(|e| (e, addr, port)), counter) }); + let ip = addr.ip; + let instance_id = addr.instance_id; + debug!("connecting to {ip}:{port}"); + let future = TcpStream::connect((ip, port)); + join_set.spawn(async move { + ( + future.await.map_err(|e| (e, ip, port)), + counter, + instance_id, + ) + }); } // select the first successful connection - let (connection, counter) = loop { - let (result, counter) = join_set + let (connection, counter, instance_id) = loop { + let (result, counter, instance_id) = join_set .join_next() .await .context("No connection success")? .context("Failed to join the connect task")?; match result { - Ok(connection) => break (connection, counter), + Ok(connection) => break (connection, counter, instance_id), Err((e, addr, port)) => { info!("failed to connect to app@{addr}:{port}: {e}"); } } }; debug!("connected to {:?}", connection.peer_addr()); - Ok((connection, counter)) + Ok((connection, counter, instance_id)) } pub(crate) async fn proxy_to_app( state: Proxy, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, app_id: &str, port: u16, ) -> Result<()> { let addresses = state.lock().select_top_n_hosts(app_id)?; let max_connections = state.config.proxy.max_connections_per_app; - let (mut outbound, _counter) = timeout( + let (mut outbound, _counter, instance_id) = timeout( state.config.proxy.timeouts.connect, connect_multiple_hosts(addresses.clone(), port, max_connections, app_id), ) .await .with_context(|| format!("connecting timeout to app {app_id}: {addresses:?}:{port}"))? .with_context(|| format!("failed to connect to app {app_id}: {addresses:?}:{port}"))?; + if should_send_pp(&state, &instance_id, port) { + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } outbound .write_all(&buffer) .await diff --git a/gateway/src/proxy/tls_terminate.rs b/gateway/src/proxy/tls_terminate.rs index d6e1ec0fc..f0c5adcd5 100644 --- a/gateway/src/proxy/tls_terminate.rs +++ b/gateway/src/proxy/tls_terminate.rs @@ -13,9 +13,10 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::tokio::TokioIo; +use proxy_protocol::ProxyHeader; use rustls::version::{TLS12, TLS13}; use serde::Serialize; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _, ReadBuf}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; @@ -27,6 +28,7 @@ use crate::config::{CryptoProvider, ProxyConfig, TlsVersion}; use crate::main_service::Proxy; use super::io_bridge::bridge; +use super::port_attrs::should_send_pp; use super::tls_passthough::connect_multiple_hosts; #[pin_project::pin_project] @@ -268,9 +270,10 @@ impl Proxy { Ok(tls_stream) } - pub(crate) async fn proxy( + pub(super) async fn proxy( &self, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, app_id: &str, port: u16, @@ -289,13 +292,18 @@ impl Proxy { debug!("selected top n hosts: {addresses:?}"); let tls_stream = self.tls_accept(inbound, buffer, h2).await?; let max_connections = self.config.proxy.max_connections_per_app; - let (outbound, _counter) = timeout( + let (mut outbound, _counter, instance_id) = timeout( self.config.proxy.timeouts.connect, connect_multiple_hosts(addresses, port, max_connections, app_id), ) .await .map_err(|_| anyhow!("connecting timeout"))? .context("failed to connect to app")?; + if should_send_pp(self, &instance_id, port) { + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } bridge( IgnoreUnexpectedEofStream::new(tls_stream), outbound, diff --git a/guest-agent/src/rpc_service.rs b/guest-agent/src/rpc_service.rs index 202ad73e6..ef4262834 100644 --- a/guest-agent/src/rpc_service.rs +++ b/guest-agent/src/rpc_service.rs @@ -718,6 +718,7 @@ mod tests { secure_time: false, storage_fs: None, swap_size: 0, + ports: Vec::new(), }; let dummy_appcompose_wrapper = AppComposeWrapper {