Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ yaml-rust2 = "0.10.4"
luks2 = "0.5.0"
scopeguard = "1.2.0"
tar = "0.4"
proxy-protocol = "0.5.0"

[profile.release]
panic = "abort"
1 change: 1 addition & 0 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ hyper-rustls.workspace = true
http-body-util.workspace = true
x509-parser.workspace = true
jemallocator.workspace = true
proxy-protocol.workspace = true
wavekv.workspace = true
tdx-attest.workspace = true
flate2.workspace = true
Expand Down
6 changes: 6 additions & 0 deletions gateway/gateway.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ workers = 32
external_port = 443
# Maximum concurrent connections per app. 0 means unlimited.
max_connections_per_app = 2000
# Whether to read PROXY protocol from inbound connections (e.g. from Cloudflare).
inbound_pp_enabled = false
# Whether to send PROXY protocol headers to backend apps.
outbound_pp_enabled = false

[core.proxy.timeouts]
# Timeout for establishing a connection to the target app.
Expand All @@ -81,6 +85,8 @@ write = "5s"
shutdown = "5s"
# Timeout for total connection duration.
total = "5h"
# Timeout for proxy protocol header.
pp_header = "5s"

[core.recycle]
enabled = true
Expand Down
11 changes: 11 additions & 0 deletions gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ pub struct ProxyConfig {
pub app_address_ns_compat: bool,
/// Maximum concurrent connections per app. 0 means unlimited.
pub max_connections_per_app: u64,
/// Whether to read PROXY protocol headers from inbound connections
/// (e.g. when behind a PP-aware load balancer like Cloudflare).
#[serde(default)]
pub inbound_pp_enabled: bool,
/// Whether to send PROXY protocol headers on outbound connections to backend apps.
/// This is a server-side setting; it must NOT be controlled by client input (e.g. SNI).
#[serde(default)]
pub outbound_pp_enabled: bool,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -142,6 +150,9 @@ pub struct Timeouts {
pub write: Duration,
#[serde(with = "serde_duration")]
pub shutdown: Duration,
/// Timeout for reading the proxy protocol header from inbound connections.
#[serde(with = "serde_duration")]
pub pp_header: Duration,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down
1 change: 1 addition & 0 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod distributed_certbot;
mod kv;
mod main_service;
mod models;
mod pp;
mod proxy;
mod web_routes;

Expand Down
178 changes: 178 additions & 0 deletions gateway/src/pp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// SPDX-FileCopyrightText: © 2024-2025 Phala Network <dstack@phala.network>
//
// SPDX-License-Identifier: Apache-2.0

use std::net::SocketAddr;

use anyhow::{bail, Context, Result};
use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader};
use tokio::{
io::{AsyncRead, AsyncReadExt},
net::TcpStream,
};

use crate::config::ProxyConfig;

const V1_PROTOCOL_PREFIX: &str = "PROXY";
const V1_PREFIX_LEN: usize = 5;
const V1_MAX_LENGTH: usize = 107;
const V1_TERMINATOR: &[u8] = b"\r\n";

const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n";
const V2_PREFIX_LEN: usize = 12;
const V2_MINIMUM_LEN: usize = 16;
const V2_LENGTH_INDEX: usize = 14;
const READ_BUFFER_LEN: usize = 512;
const V2_MAX_LENGTH: usize = 2048;

/// Read or synthesize the inbound proxy protocol header.
///
/// When `inbound_pp_enabled` is true, reads a PP header from the stream (e.g. from an upstream
/// load balancer). When false, synthesizes one from the TCP peer address.
pub(crate) async fn get_inbound_pp_header(
inbound: TcpStream,
config: &ProxyConfig,
) -> Result<(TcpStream, ProxyHeader)> {
if config.inbound_pp_enabled {
read_proxy_header(inbound).await
} else {
let header = create_inbound_pp_header(&inbound);
Ok((inbound, header))
}
}

pub struct DisplayAddr<'a>(pub &'a ProxyHeader);

impl std::fmt::Display for DisplayAddr<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ProxyHeader::Version2 { addresses, .. } => match addresses {
v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
v2::ProxyAddresses::Unix { .. } => write!(f, "<unix>"),
v2::ProxyAddresses::Unspec => write!(f, "<unspec>"),
},
ProxyHeader::Version1 { addresses, .. } => match addresses {
v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
v1::ProxyAddresses::Unknown => write!(f, "<unknown>"),
},
_ => write!(f, "<unknown ver>"),
}
}
}

fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader {
let peer_addr = inbound.peer_addr().ok();
let local_addr = inbound.local_addr().ok();

match (peer_addr, local_addr) {
(Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => {
ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Ipv4 {
source,
destination,
},
}
}
(Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => {
ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Ipv6 {
source,
destination,
},
}
}
_ => ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Unspec,
},
}
}

async fn read_proxy_header<I>(mut stream: I) -> Result<(I, ProxyHeader)>
where
I: AsyncRead + Unpin,
{
let mut buffer = [0; READ_BUFFER_LEN];
let mut dynamic_buffer = None;

stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;

if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() {
read_v1_header(&mut stream, &mut buffer).await?;
} else {
stream
.read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN])
.await?;
if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX {
dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
} else {
bail!("no valid proxy protocol header detected");
}
}

let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);

let header =
proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?;
Ok((stream, header))
}

async fn read_v2_header<I>(
mut stream: I,
buffer: &mut [u8; READ_BUFFER_LEN],
) -> Result<Option<Vec<u8>>>
where
I: AsyncRead + Unpin,
{
let length =
u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
let full_length = V2_MINIMUM_LEN + length;

if full_length > V2_MAX_LENGTH {
bail!("v2 proxy protocol header is too long");
}

if full_length > READ_BUFFER_LEN {
let mut dynamic_buffer = Vec::with_capacity(full_length);
dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
dynamic_buffer.resize(full_length, 0);
stream
.read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
.await?;

Ok(Some(dynamic_buffer))
} else {
stream
.read_exact(&mut buffer[V2_MINIMUM_LEN..full_length])
.await?;

Ok(None)
}
}

async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()>
where
I: AsyncRead + Unpin,
{
let mut end_found = false;
for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
buffer[i] = stream.read_u8().await?;

if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
end_found = true;
break;
}
}
if !end_found {
bail!("no valid proxy protocol header detected");
}

Ok(())
}
24 changes: 19 additions & 5 deletions gateway/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ use tokio::{
};
use tracing::{debug, error, info, info_span, Instrument};

use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter};
use crate::{
config::ProxyConfig,
main_service::Proxy,
models::EnteredCounter,
pp::{get_inbound_pp_header, DisplayAddr},
};

#[derive(Debug, Clone)]
pub(crate) struct AddressInfo {
Expand Down Expand Up @@ -123,8 +128,16 @@ fn parse_dst_info(subdomain: &str) -> Result<DstInfo> {

pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0);

async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> {
async fn handle_connection(inbound: TcpStream, state: Proxy) -> Result<()> {
let timeouts = &state.config.proxy.timeouts;

let pp_fut = get_inbound_pp_header(inbound, &state.config.proxy);
let (mut inbound, pp_header) = timeout(timeouts.pp_header, pp_fut)
.await
.context("proxy protocol header timeout")?
.context("failed to read proxy protocol header")?;
info!("client address: {}", DisplayAddr(&pp_header));

let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound))
.await
.context("take sni timeout")?
Expand All @@ -138,14 +151,15 @@ async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> {
let dst = parse_dst_info(subdomain)?;
debug!("dst: {dst:?}");
if dst.is_tls {
tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await
tls_passthough::proxy_to_app(state, inbound, pp_header, buffer, &dst.app_id, dst.port)
.await
} else {
state
.proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2)
.proxy(inbound, pp_header, buffer, &dst.app_id, dst.port, dst.is_h2)
.await
}
} else {
tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await
tls_passthough::proxy_with_sni(state, inbound, pp_header, buffer, &sni).await
}
}

Expand Down
Loading
Loading