From d5a137ca384d3c117b969466c082a95e971a9123 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 15 Apr 2026 19:56:41 -0700 Subject: [PATCH] gw: implement proxy protocol with server-side control Add PROXY protocol support to the gateway with two server-side config options instead of client-controlled SNI suffixes: - inbound_pp_enabled: read PP headers from upstream load balancers - outbound_pp_enabled: send PP headers to backend apps The original PR#361 used a 'p' suffix in the SNI subdomain to toggle outbound PP per-connection. This is a security flaw: a client could connect to a PP-expecting port without sending PP headers, allowing source address spoofing. Both flags are now server-side config only. --- Cargo.lock | 38 ++++++ Cargo.toml | 1 + gateway/Cargo.toml | 1 + gateway/gateway.toml | 6 + gateway/src/config.rs | 11 ++ gateway/src/main.rs | 1 + gateway/src/pp.rs | 178 ++++++++++++++++++++++++++++ gateway/src/proxy.rs | 24 +++- gateway/src/proxy/tls_passthough.rs | 10 +- gateway/src/proxy/tls_terminate.rs | 13 +- 10 files changed, 274 insertions(+), 9 deletions(-) create mode 100644 gateway/src/pp.rs diff --git a/Cargo.lock b/Cargo.lock index 59ac5fda..7c10c2e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2211,6 +2211,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" + [[package]] name = "documented" version = "0.9.2" @@ -2299,6 +2305,7 @@ dependencies = [ "or-panic", "parcelona", "pin-project", + "proxy-protocol", "ra-rpc", "ra-tls", "rand 0.8.5", @@ -5619,6 +5626,16 @@ dependencies = [ "prost 0.13.5", ] +[[package]] +name = "proxy-protocol" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e50c72c21c738f5c5f350cc33640aee30bf7cd20f9d9da20ed41bce2671d532" +dependencies = [ + "bytes", + "snafu", +] + [[package]] name = "prpc" version = "0.6.0" @@ -7239,6 +7256,27 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab12d3c261b2308b0d80c26fffb58d17eba81a4be97890101f416b478c79ca7" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1508efa03c362e23817f96cde18abed596a25219a8b2c66e8db33c03543d315b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.5.10" diff --git a/Cargo.toml b/Cargo.toml index d2d5a680..ae5be4f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -247,6 +247,7 @@ yaml-rust2 = "0.10.4" luks2 = "0.5.0" scopeguard = "1.2.0" tar = "0.4" +proxy-protocol = "0.5.0" [profile.release] panic = "abort" diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 1f57ebf0..a1aefe17 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -54,6 +54,7 @@ hyper-rustls.workspace = true http-body-util.workspace = true x509-parser.workspace = true jemallocator.workspace = true +proxy-protocol.workspace = true wavekv.workspace = true tdx-attest.workspace = true flate2.workspace = true diff --git a/gateway/gateway.toml b/gateway/gateway.toml index 09c669e5..87e8d997 100644 --- a/gateway/gateway.toml +++ b/gateway/gateway.toml @@ -58,6 +58,10 @@ workers = 32 external_port = 443 # Maximum concurrent connections per app. 0 means unlimited. max_connections_per_app = 2000 +# Whether to read PROXY protocol from inbound connections (e.g. from Cloudflare). +inbound_pp_enabled = false +# Whether to send PROXY protocol headers to backend apps. +outbound_pp_enabled = false [core.proxy.timeouts] # Timeout for establishing a connection to the target app. @@ -81,6 +85,8 @@ write = "5s" shutdown = "5s" # Timeout for total connection duration. total = "5h" +# Timeout for proxy protocol header. +pp_header = "5s" [core.recycle] enabled = true diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 45b899c0..29672c59 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -117,6 +117,14 @@ pub struct ProxyConfig { pub app_address_ns_compat: bool, /// Maximum concurrent connections per app. 0 means unlimited. pub max_connections_per_app: u64, + /// Whether to read PROXY protocol headers from inbound connections + /// (e.g. when behind a PP-aware load balancer like Cloudflare). + #[serde(default)] + pub inbound_pp_enabled: bool, + /// Whether to send PROXY protocol headers on outbound connections to backend apps. + /// This is a server-side setting; it must NOT be controlled by client input (e.g. SNI). + #[serde(default)] + pub outbound_pp_enabled: bool, } #[derive(Debug, Clone, Deserialize)] @@ -142,6 +150,9 @@ pub struct Timeouts { pub write: Duration, #[serde(with = "serde_duration")] pub shutdown: Duration, + /// Timeout for reading the proxy protocol header from inbound connections. + #[serde(with = "serde_duration")] + pub pp_header: Duration, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 52ceb0e4..1349d966 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -32,6 +32,7 @@ mod distributed_certbot; mod kv; mod main_service; mod models; +mod pp; mod proxy; mod web_routes; diff --git a/gateway/src/pp.rs b/gateway/src/pp.rs new file mode 100644 index 00000000..196ecdaa --- /dev/null +++ b/gateway/src/pp.rs @@ -0,0 +1,178 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use std::net::SocketAddr; + +use anyhow::{bail, Context, Result}; +use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader}; +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + net::TcpStream, +}; + +use crate::config::ProxyConfig; + +const V1_PROTOCOL_PREFIX: &str = "PROXY"; +const V1_PREFIX_LEN: usize = 5; +const V1_MAX_LENGTH: usize = 107; +const V1_TERMINATOR: &[u8] = b"\r\n"; + +const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n"; +const V2_PREFIX_LEN: usize = 12; +const V2_MINIMUM_LEN: usize = 16; +const V2_LENGTH_INDEX: usize = 14; +const READ_BUFFER_LEN: usize = 512; +const V2_MAX_LENGTH: usize = 2048; + +/// Read or synthesize the inbound proxy protocol header. +/// +/// When `inbound_pp_enabled` is true, reads a PP header from the stream (e.g. from an upstream +/// load balancer). When false, synthesizes one from the TCP peer address. +pub(crate) async fn get_inbound_pp_header( + inbound: TcpStream, + config: &ProxyConfig, +) -> Result<(TcpStream, ProxyHeader)> { + if config.inbound_pp_enabled { + read_proxy_header(inbound).await + } else { + let header = create_inbound_pp_header(&inbound); + Ok((inbound, header)) + } +} + +pub struct DisplayAddr<'a>(pub &'a ProxyHeader); + +impl std::fmt::Display for DisplayAddr<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ProxyHeader::Version2 { addresses, .. } => match addresses { + v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Unix { .. } => write!(f, ""), + v2::ProxyAddresses::Unspec => write!(f, ""), + }, + ProxyHeader::Version1 { addresses, .. } => match addresses { + v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Unknown => write!(f, ""), + }, + _ => write!(f, ""), + } + } +} + +fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader { + let peer_addr = inbound.peer_addr().ok(); + let local_addr = inbound.local_addr().ok(); + + match (peer_addr, local_addr) { + (Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv4 { + source, + destination, + }, + } + } + (Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv6 { + source, + destination, + }, + } + } + _ => ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Unspec, + }, + } +} + +async fn read_proxy_header(mut stream: I) -> Result<(I, ProxyHeader)> +where + I: AsyncRead + Unpin, +{ + let mut buffer = [0; READ_BUFFER_LEN]; + let mut dynamic_buffer = None; + + stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; + + if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() { + read_v1_header(&mut stream, &mut buffer).await?; + } else { + stream + .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) + .await?; + if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX { + dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; + } else { + bail!("no valid proxy protocol header detected"); + } + } + + let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); + + let header = + proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?; + Ok((stream, header)) +} + +async fn read_v2_header( + mut stream: I, + buffer: &mut [u8; READ_BUFFER_LEN], +) -> Result>> +where + I: AsyncRead + Unpin, +{ + let length = + u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; + let full_length = V2_MINIMUM_LEN + length; + + if full_length > V2_MAX_LENGTH { + bail!("v2 proxy protocol header is too long"); + } + + if full_length > READ_BUFFER_LEN { + let mut dynamic_buffer = Vec::with_capacity(full_length); + dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); + dynamic_buffer.resize(full_length, 0); + stream + .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(Some(dynamic_buffer)) + } else { + stream + .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(None) + } +} + +async fn read_v1_header(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()> +where + I: AsyncRead + Unpin, +{ + let mut end_found = false; + for i in V1_PREFIX_LEN..V1_MAX_LENGTH { + buffer[i] = stream.read_u8().await?; + + if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { + end_found = true; + break; + } + } + if !end_found { + bail!("no valid proxy protocol header detected"); + } + + Ok(()) +} diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index dd39d0ac..b530cd8d 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -23,7 +23,12 @@ use tokio::{ }; use tracing::{debug, error, info, info_span, Instrument}; -use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter}; +use crate::{ + config::ProxyConfig, + main_service::Proxy, + models::EnteredCounter, + pp::{get_inbound_pp_header, DisplayAddr}, +}; #[derive(Debug, Clone)] pub(crate) struct AddressInfo { @@ -123,8 +128,16 @@ fn parse_dst_info(subdomain: &str) -> Result { pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); -async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { +async fn handle_connection(inbound: TcpStream, state: Proxy) -> Result<()> { let timeouts = &state.config.proxy.timeouts; + + let pp_fut = get_inbound_pp_header(inbound, &state.config.proxy); + let (mut inbound, pp_header) = timeout(timeouts.pp_header, pp_fut) + .await + .context("proxy protocol header timeout")? + .context("failed to read proxy protocol header")?; + info!("client address: {}", DisplayAddr(&pp_header)); + let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound)) .await .context("take sni timeout")? @@ -138,14 +151,15 @@ async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { let dst = parse_dst_info(subdomain)?; debug!("dst: {dst:?}"); if dst.is_tls { - tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await + tls_passthough::proxy_to_app(state, inbound, pp_header, buffer, &dst.app_id, dst.port) + .await } else { state - .proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2) + .proxy(inbound, pp_header, buffer, &dst.app_id, dst.port, dst.is_h2) .await } } else { - tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await + tls_passthough::proxy_with_sni(state, inbound, pp_header, buffer, &sni).await } } diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index 57bb3830..589f861a 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -6,6 +6,7 @@ use std::fmt::Debug; use std::sync::atomic::Ordering; use anyhow::{bail, Context, Result}; +use proxy_protocol::ProxyHeader; use tokio::{io::AsyncWriteExt, net::TcpStream, task::JoinSet, time::timeout}; use tracing::{debug, info, warn}; @@ -96,6 +97,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result, sni: &str, ) -> Result<()> { @@ -107,7 +109,7 @@ pub(crate) async fn proxy_with_sni( .with_context(|| format!("DNS TXT resolve timeout for {sni}"))? .with_context(|| format!("failed to resolve app address for {sni}"))?; debug!("target address is {}:{}", addr.app_id, addr.port); - proxy_to_app(state, inbound, buffer, &addr.app_id, addr.port).await + proxy_to_app(state, inbound, pp_header, buffer, &addr.app_id, addr.port).await } /// Check if app has reached max connections limit @@ -171,6 +173,7 @@ pub(crate) async fn connect_multiple_hosts( pub(crate) async fn proxy_to_app( state: Proxy, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, app_id: &str, port: u16, @@ -184,6 +187,11 @@ pub(crate) async fn proxy_to_app( .await .with_context(|| format!("connecting timeout to app {app_id}: {addresses:?}:{port}"))? .with_context(|| format!("failed to connect to app {app_id}: {addresses:?}:{port}"))?; + if state.config.proxy.outbound_pp_enabled { + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } outbound .write_all(&buffer) .await diff --git a/gateway/src/proxy/tls_terminate.rs b/gateway/src/proxy/tls_terminate.rs index d6e1ec0f..81ca11fd 100644 --- a/gateway/src/proxy/tls_terminate.rs +++ b/gateway/src/proxy/tls_terminate.rs @@ -14,8 +14,9 @@ use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::tokio::TokioIo; use rustls::version::{TLS12, TLS13}; +use proxy_protocol::ProxyHeader; use serde::Serialize; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _, ReadBuf}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; @@ -268,9 +269,10 @@ impl Proxy { Ok(tls_stream) } - pub(crate) async fn proxy( + pub(super) async fn proxy( &self, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, app_id: &str, port: u16, @@ -289,13 +291,18 @@ impl Proxy { debug!("selected top n hosts: {addresses:?}"); let tls_stream = self.tls_accept(inbound, buffer, h2).await?; let max_connections = self.config.proxy.max_connections_per_app; - let (outbound, _counter) = timeout( + let (mut outbound, _counter) = timeout( self.config.proxy.timeouts.connect, connect_multiple_hosts(addresses, port, max_connections, app_id), ) .await .map_err(|_| anyhow!("connecting timeout"))? .context("failed to connect to app")?; + if self.config.proxy.outbound_pp_enabled { + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } bridge( IgnoreUnexpectedEofStream::new(tls_stream), outbound,