Skip to main content

lava_flow/
types.rs

1use serde::{Deserialize, Deserializer, Serialize};
2
3use crate::error::{LavaFlowError, Result, ValidationReason};
4
5const MAX_IDENTIFIER_LEN: usize = 64;
6const MAX_HOSTNAME_LEN: usize = 253;
7type ValidationError = (String, ValidationReason);
8type ValidationResult = std::result::Result<String, ValidationError>;
9
10/// Logical process name used by channels and diagnostics.
11#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)]
12pub struct ProcessName(String);
13
14impl ProcessName {
15    /// Creates a validated process name.
16    ///
17    /// Validation rules:
18    /// - non-empty
19    /// - max 64 characters
20    /// - first character is ASCII letter or digit
21    /// - remaining characters are lowercase ASCII letters, digits, `-`, or `_`
22    pub fn new(value: impl Into<String>) -> Result<Self> {
23        let value = validate_identifier(value.into())
24            .map_err(|(value, reason)| LavaFlowError::InvalidProcessName { value, reason })?;
25        Ok(Self(value))
26    }
27
28    /// Returns the process name as a borrowed string slice.
29    pub fn as_str(&self) -> &str {
30        &self.0
31    }
32
33    /// Consumes the wrapper and returns the owned inner string.
34    pub fn into_inner(self) -> String {
35        self.0
36    }
37}
38
39impl<'de> Deserialize<'de> for ProcessName {
40    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        let value = String::deserialize(deserializer)?;
45        ProcessName::new(value).map_err(serde::de::Error::custom)
46    }
47}
48
49/// Stable identifier for a communication channel.
50#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)]
51pub struct ChannelId(String);
52
53impl ChannelId {
54    /// Creates a validated channel identifier.
55    ///
56    /// Validation rules are the same as [`ProcessName::new`].
57    pub fn new(value: impl Into<String>) -> Result<Self> {
58        let value = validate_identifier(value.into())
59            .map_err(|(value, reason)| LavaFlowError::InvalidChannelId { value, reason })?;
60        Ok(Self(value))
61    }
62
63    /// Returns the channel identifier as a borrowed string slice.
64    pub fn as_str(&self) -> &str {
65        &self.0
66    }
67
68    /// Consumes the wrapper and returns the owned inner string.
69    pub fn into_inner(self) -> String {
70        self.0
71    }
72}
73
74impl<'de> Deserialize<'de> for ChannelId {
75    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
76    where
77        D: Deserializer<'de>,
78    {
79        let value = String::deserialize(deserializer)?;
80        ChannelId::new(value).map_err(serde::de::Error::custom)
81    }
82}
83
84/// Location metadata used for topology-aware routing.
85#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
86pub struct ProcessLocation {
87    #[serde(deserialize_with = "deserialize_hostname")]
88    hostname: String,
89    node_id: Option<u32>,
90    device_id: Option<u32>,
91}
92
93impl ProcessLocation {
94    /// Creates a validated location with hostname only.
95    ///
96    /// Hostname validation rules:
97    /// - non-empty
98    /// - max 253 characters
99    /// - first character is ASCII letter or digit
100    /// - remaining characters are ASCII letters, digits, `-`, `.`, or `_`
101    ///
102    /// Hostnames are normalized to lowercase for stable scope comparison.
103    pub fn new(hostname: impl Into<String>) -> Result<Self> {
104        let hostname = validate_hostname(hostname.into())
105            .map_err(|(value, reason)| LavaFlowError::InvalidHostname { value, reason })?;
106        Ok(Self {
107            hostname,
108            node_id: None,
109            device_id: None,
110        })
111    }
112
113    /// Creates a validated location with hostname plus optional node/device metadata.
114    pub fn with_ids(
115        hostname: impl Into<String>,
116        node_id: Option<u32>,
117        device_id: Option<u32>,
118    ) -> Result<Self> {
119        let hostname = validate_hostname(hostname.into())
120            .map_err(|(value, reason)| LavaFlowError::InvalidHostname { value, reason })?;
121        Ok(Self {
122            hostname,
123            node_id,
124            device_id,
125        })
126    }
127
128    /// Detects the local hostname via OS APIs and returns a location.
129    ///
130    /// Returns an error if hostname detection fails.
131    pub fn from_hostname() -> Result<Self> {
132        let hostname = hostname::get().map_err(LavaFlowError::HostnameDetection)?;
133        let hostname = hostname.to_string_lossy().trim().to_string();
134        Self::new(hostname)
135    }
136
137    /// Returns the normalized hostname.
138    pub fn hostname(&self) -> &str {
139        &self.hostname
140    }
141
142    /// Returns the optional scheduler/node index metadata.
143    pub fn node_id(&self) -> Option<u32> {
144        self.node_id
145    }
146
147    /// Returns the optional GPU/device index metadata.
148    pub fn device_id(&self) -> Option<u32> {
149        self.device_id
150    }
151}
152
153/// Communication scope derived from locations.
154#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
155pub enum CommunicationScope {
156    /// Sender and peer are on the same host.
157    Local,
158    /// Sender and peer are on different hosts, or host identity is ambiguous.
159    Remote,
160}
161
162impl CommunicationScope {
163    /// Computes communication scope from two process locations.
164    pub fn from_locations(my_location: &ProcessLocation, peer_location: &ProcessLocation) -> Self {
165        if my_location.hostname == peer_location.hostname {
166            Self::Local
167        } else {
168            Self::Remote
169        }
170    }
171}
172
/// Detects communication scope between two process locations.
///
/// Free-function form of [`CommunicationScope::from_locations`]; it delegates
/// directly and adds no behavior of its own.
pub fn detect_scope(
    my_location: &ProcessLocation,
    peer_location: &ProcessLocation,
) -> CommunicationScope {
    CommunicationScope::from_locations(my_location, peer_location)
}
180
181fn validate_identifier(value: String) -> ValidationResult {
182    // Shared normalization/validation path for name-like identifiers to keep
183    // behavioral rules consistent across all public wrapper types.
184    if value.is_empty() {
185        return Err((value, ValidationReason::Empty));
186    }
187    if value.len() > MAX_IDENTIFIER_LEN {
188        return Err((value, ValidationReason::IdentifierTooLong));
189    }
190    let mut chars = value.chars();
191    let first = chars.next().expect("identifier checked non-empty");
192    if !first.is_ascii_alphanumeric() {
193        return Err((value, ValidationReason::InvalidStartCharacter));
194    }
195    if !value
196        .chars()
197        .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
198    {
199        return Err((value, ValidationReason::InvalidCharacters));
200    }
201    Ok(value)
202}
203
204fn validate_hostname(value: String) -> ValidationResult {
205    let value = value.trim().to_string();
206    if value.is_empty() {
207        return Err((value, ValidationReason::Empty));
208    }
209    if value.len() > MAX_HOSTNAME_LEN {
210        return Err((value, ValidationReason::HostnameTooLong));
211    }
212    let mut chars = value.chars();
213    let first = chars.next().expect("hostname checked non-empty");
214    if !first.is_ascii_alphanumeric() {
215        return Err((value, ValidationReason::InvalidStartCharacter));
216    }
217    if !value
218        .chars()
219        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_')
220    {
221        return Err((value, ValidationReason::InvalidCharacters));
222    }
223    Ok(value.to_ascii_lowercase())
224}
225
226fn deserialize_hostname<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
227where
228    D: Deserializer<'de>,
229{
230    let value = String::deserialize(deserializer)?;
231    validate_hostname(value).map_err(|(value, reason)| {
232        serde::de::Error::custom(LavaFlowError::InvalidHostname { value, reason })
233    })
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn process_name_accepts_valid_identifiers() {
        let name = ProcessName::new("gpu_worker_01").expect("valid process name");
        assert_eq!(name.as_str(), "gpu_worker_01");
    }

    #[test]
    fn process_name_into_inner_returns_owned_value() {
        let name = ProcessName::new("gpu_worker_01").expect("valid process name");
        assert_eq!(name.into_inner(), "gpu_worker_01");
    }

    #[test]
    fn process_name_rejects_invalid_identifiers() {
        let err = ProcessName::new("GPU-WORKER").expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        ));
    }

    #[test]
    fn process_name_reports_uppercase_first_letter_as_invalid_characters() {
        // The start check only requires an ASCII alphanumeric; uppercase is then
        // rejected by the character-set check.
        let err = ProcessName::new("Gpu").expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        ));
    }

    #[test]
    fn process_name_rejects_empty_identifiers() {
        let err = ProcessName::new("").expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::Empty,
                ..
            }
        ));
    }

    #[test]
    fn process_name_rejects_too_long_identifiers() {
        let too_long = "a".repeat(MAX_IDENTIFIER_LEN + 1);
        let err = ProcessName::new(too_long).expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::IdentifierTooLong,
                ..
            }
        ));
    }

    #[test]
    fn identifiers_at_max_length_are_accepted() {
        let max = "a".repeat(MAX_IDENTIFIER_LEN);
        let name = ProcessName::new(max.clone()).expect("valid process name");
        assert_eq!(name.as_str(), max.as_str());
        let channel_id = ChannelId::new(max.clone()).expect("valid channel id");
        assert_eq!(channel_id.as_str(), max.as_str());
    }

    #[test]
    fn process_name_rejects_invalid_start_identifiers() {
        let err = ProcessName::new("-gpu").expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        ));
    }

    #[test]
    fn channel_id_accepts_valid_identifiers() {
        let channel_id = ChannelId::new("channel-0").expect("valid channel id");
        assert_eq!(channel_id.as_str(), "channel-0");
    }

    #[test]
    fn channel_id_into_inner_returns_owned_value() {
        let channel_id = ChannelId::new("channel-0").expect("valid channel id");
        assert_eq!(channel_id.into_inner(), "channel-0");
    }

    #[test]
    fn channel_id_rejects_invalid_identifiers() {
        let err = ChannelId::new("-invalid").expect_err("expected validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidChannelId {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        ));
    }

    #[test]
    fn process_location_with_ids_preserves_metadata() {
        let location =
            ProcessLocation::with_ids("gpu-node-0", Some(7), Some(3)).expect("valid location");
        assert_eq!(location.hostname(), "gpu-node-0");
        assert_eq!(location.node_id(), Some(7));
        assert_eq!(location.device_id(), Some(3));
    }

    #[test]
    fn process_location_with_ids_rejects_invalid_hostname() {
        let err = ProcessLocation::with_ids("gpu node 0", Some(7), Some(3))
            .expect_err("expected hostname validation error");
        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        ));
    }

    #[test]
    fn process_location_from_hostname_works() {
        let location = ProcessLocation::from_hostname().expect("hostname lookup should succeed");
        assert!(!location.hostname().is_empty());
    }

    #[test]
    fn process_location_trims_and_lowercases_hostname() {
        let location = ProcessLocation::new("  GPU-Node-0\n").expect("valid hostname");
        assert_eq!(location.hostname(), "gpu-node-0");
        assert_eq!(location.node_id(), None);
        assert_eq!(location.device_id(), None);
    }

    #[test]
    fn scope_detection_is_local_when_hostnames_match() {
        let my_location = ProcessLocation::new("gpu-node-0").expect("valid hostname");
        let peer_location = ProcessLocation::new("gpu-node-0").expect("valid hostname");

        let scope = CommunicationScope::from_locations(&my_location, &peer_location);

        assert_eq!(scope, CommunicationScope::Local);
    }

    #[test]
    fn scope_detection_is_remote_when_hostnames_differ() {
        let my_location = ProcessLocation::new("gpu-node-0").expect("valid hostname");
        let peer_location = ProcessLocation::new("gpu-node-1").expect("valid hostname");

        let scope = CommunicationScope::from_locations(&my_location, &peer_location);

        assert_eq!(scope, CommunicationScope::Remote);
    }

    #[test]
    fn scope_detection_is_case_insensitive_after_normalization() {
        let my_location = ProcessLocation::new("GPU-NODE-0").expect("valid hostname");
        let peer_location = ProcessLocation::new("gpu-node-0").expect("valid hostname");

        let scope = detect_scope(&my_location, &peer_location);

        assert_eq!(scope, CommunicationScope::Local);
    }

    #[test]
    fn scope_detection_ignores_node_and_device_ids() {
        let my_location =
            ProcessLocation::with_ids("gpu-node-0", Some(0), Some(0)).expect("valid location");
        let peer_location =
            ProcessLocation::with_ids("gpu-node-0", Some(1), Some(5)).expect("valid location");

        let scope = CommunicationScope::from_locations(&my_location, &peer_location);

        assert_eq!(scope, CommunicationScope::Local);
    }

    #[test]
    fn process_location_rejects_empty_hostname() {
        let err = ProcessLocation::new("").expect_err("expected hostname validation error");

        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::Empty,
                ..
            }
        ));
    }

    #[test]
    fn process_location_rejects_whitespace_only_hostname() {
        let err = ProcessLocation::new("   ").expect_err("expected hostname validation error");

        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::Empty,
                ..
            }
        ));
    }

    #[test]
    fn process_location_rejects_invalid_hostname_characters() {
        let err =
            ProcessLocation::new("gpu node 0").expect_err("expected hostname validation error");

        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        ));
    }

    #[test]
    fn process_location_rejects_invalid_start_hostname() {
        let err =
            ProcessLocation::new("-gpu-node-0").expect_err("expected hostname validation error");

        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        ));
    }

    #[test]
    fn process_location_rejects_too_long_hostname() {
        let too_long = "a".repeat(MAX_HOSTNAME_LEN + 1);
        let err = ProcessLocation::new(too_long).expect_err("expected hostname validation error");

        assert!(matches!(
            err,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::HostnameTooLong,
                ..
            }
        ));
    }

    #[test]
    fn process_location_accepts_hostname_at_max_length() {
        let max = "a".repeat(MAX_HOSTNAME_LEN);
        let location = ProcessLocation::new(max.clone()).expect("valid hostname");
        assert_eq!(location.hostname(), max.as_str());
    }

    #[test]
    fn process_name_deserialize_rejects_invalid_value() {
        let err = serde_json::from_str::<ProcessName>("\"GPU-WORKER\"")
            .expect_err("expected deserialization validation error");
        assert!(err.to_string().contains("invalid process name"));
    }

    #[test]
    fn channel_id_deserialize_rejects_invalid_value() {
        let err = serde_json::from_str::<ChannelId>("\"-invalid\"")
            .expect_err("expected deserialization validation error");
        assert!(err.to_string().contains("invalid channel id"));
    }

    #[test]
    fn identifiers_serialize_as_plain_strings_and_round_trip() {
        let name = ProcessName::new("worker-1").expect("valid process name");
        let json = serde_json::to_string(&name).expect("serializable process name");
        assert_eq!(json, "\"worker-1\"");
        let back: ProcessName = serde_json::from_str(&json).expect("valid deserialization");
        assert_eq!(back, name);

        let channel_id = ChannelId::new("channel-0").expect("valid channel id");
        let json = serde_json::to_string(&channel_id).expect("serializable channel id");
        assert_eq!(json, "\"channel-0\"");
        let back: ChannelId = serde_json::from_str(&json).expect("valid deserialization");
        assert_eq!(back, channel_id);
    }

    #[test]
    fn process_location_deserialize_rejects_invalid_hostname() {
        let payload = json!({
            "hostname": "gpu node 0",
            "node_id": 1,
            "device_id": 0
        })
        .to_string();
        let err = serde_json::from_str::<ProcessLocation>(&payload)
            .expect_err("expected deserialization validation error");
        assert!(err.to_string().contains("invalid hostname"));
    }

    #[test]
    fn process_location_deserialize_normalizes_hostname() {
        let payload = json!({
            "hostname": "GPU-NODE-0",
            "node_id": 1,
            "device_id": 0
        })
        .to_string();
        let location =
            serde_json::from_str::<ProcessLocation>(&payload).expect("valid deserialization");
        assert_eq!(location.hostname(), "gpu-node-0");
        assert_eq!(location.node_id(), Some(1));
        assert_eq!(location.device_id(), Some(0));
    }

    #[test]
    fn process_location_deserialize_defaults_missing_ids_to_none() {
        let payload = json!({ "hostname": "gpu-node-0" }).to_string();
        let location =
            serde_json::from_str::<ProcessLocation>(&payload).expect("valid deserialization");
        assert_eq!(location.hostname(), "gpu-node-0");
        assert_eq!(location.node_id(), None);
        assert_eq!(location.device_id(), None);
    }
}
480}