1use serde::{Deserialize, Deserializer, Serialize};
2
3use crate::error::{LavaFlowError, Result, ValidationReason};
4
/// Upper bound on the length of a process name or channel id, measured in bytes
/// (the validator compares it against `String::len`).
const MAX_IDENTIFIER_LEN: usize = 64;
/// Upper bound on the length of a hostname, measured in bytes after trimming.
const MAX_HOSTNAME_LEN: usize = 253;
/// A validation failure: the rejected string paired with the reason it was rejected.
type ValidationError = (String, ValidationReason);
/// Outcome of a validator: the accepted string on success, a [`ValidationError`] otherwise.
type ValidationResult = std::result::Result<String, ValidationError>;
9
/// Name of a process, wrapping a string that has passed identifier validation.
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand in this file so
/// that deserialized values go through [`ProcessName::new`] and are validated too.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)]
pub struct ProcessName(String);
13
14impl ProcessName {
15 pub fn new(value: impl Into<String>) -> Result<Self> {
23 let value = validate_identifier(value.into())
24 .map_err(|(value, reason)| LavaFlowError::InvalidProcessName { value, reason })?;
25 Ok(Self(value))
26 }
27
28 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32
33 pub fn into_inner(self) -> String {
35 self.0
36 }
37}
38
39impl<'de> Deserialize<'de> for ProcessName {
40 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41 where
42 D: Deserializer<'de>,
43 {
44 let value = String::deserialize(deserializer)?;
45 ProcessName::new(value).map_err(serde::de::Error::custom)
46 }
47}
48
/// Identifier of a channel, wrapping a string that has passed identifier validation.
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand in this file so
/// that deserialized values go through [`ChannelId::new`] and are validated too.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)]
pub struct ChannelId(String);
52
53impl ChannelId {
54 pub fn new(value: impl Into<String>) -> Result<Self> {
58 let value = validate_identifier(value.into())
59 .map_err(|(value, reason)| LavaFlowError::InvalidChannelId { value, reason })?;
60 Ok(Self(value))
61 }
62
63 pub fn as_str(&self) -> &str {
65 &self.0
66 }
67
68 pub fn into_inner(self) -> String {
70 self.0
71 }
72}
73
74impl<'de> Deserialize<'de> for ChannelId {
75 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
76 where
77 D: Deserializer<'de>,
78 {
79 let value = String::deserialize(deserializer)?;
80 ChannelId::new(value).map_err(serde::de::Error::custom)
81 }
82}
83
/// Location of a process: a hostname plus optional node and device identifiers.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProcessLocation {
    /// Hostname as returned by `validate_hostname`: trimmed and ASCII-lowercased.
    /// Deserialization is routed through `deserialize_hostname` so it applies the same
    /// validation and normalization as the constructors.
    #[serde(deserialize_with = "deserialize_hostname")]
    hostname: String,
    /// Optional node identifier; stored and returned as given, not interpreted here.
    node_id: Option<u32>,
    /// Optional device identifier; stored and returned as given, not interpreted here.
    device_id: Option<u32>,
}
92
93impl ProcessLocation {
94 pub fn new(hostname: impl Into<String>) -> Result<Self> {
104 let hostname = validate_hostname(hostname.into())
105 .map_err(|(value, reason)| LavaFlowError::InvalidHostname { value, reason })?;
106 Ok(Self {
107 hostname,
108 node_id: None,
109 device_id: None,
110 })
111 }
112
113 pub fn with_ids(
115 hostname: impl Into<String>,
116 node_id: Option<u32>,
117 device_id: Option<u32>,
118 ) -> Result<Self> {
119 let hostname = validate_hostname(hostname.into())
120 .map_err(|(value, reason)| LavaFlowError::InvalidHostname { value, reason })?;
121 Ok(Self {
122 hostname,
123 node_id,
124 device_id,
125 })
126 }
127
128 pub fn from_hostname() -> Result<Self> {
132 let hostname = hostname::get().map_err(LavaFlowError::HostnameDetection)?;
133 let hostname = hostname.to_string_lossy().trim().to_string();
134 Self::new(hostname)
135 }
136
137 pub fn hostname(&self) -> &str {
139 &self.hostname
140 }
141
142 pub fn node_id(&self) -> Option<u32> {
144 self.node_id
145 }
146
147 pub fn device_id(&self) -> Option<u32> {
149 self.device_id
150 }
151}
152
/// Whether two process locations are considered to be on the same host.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CommunicationScope {
    /// Both locations carry the same hostname.
    Local,
    /// The locations carry different hostnames.
    Remote,
}
161
162impl CommunicationScope {
163 pub fn from_locations(my_location: &ProcessLocation, peer_location: &ProcessLocation) -> Self {
165 if my_location.hostname == peer_location.hostname {
166 Self::Local
167 } else {
168 Self::Remote
169 }
170 }
171}
172
/// Free-function form of [`CommunicationScope::from_locations`].
///
/// Returns the scope that `from_locations` computes for the same pair of locations.
pub fn detect_scope(
    my_location: &ProcessLocation,
    peer_location: &ProcessLocation,
) -> CommunicationScope {
    CommunicationScope::from_locations(my_location, peer_location)
}
180
181fn validate_identifier(value: String) -> ValidationResult {
182 if value.is_empty() {
185 return Err((value, ValidationReason::Empty));
186 }
187 if value.len() > MAX_IDENTIFIER_LEN {
188 return Err((value, ValidationReason::IdentifierTooLong));
189 }
190 let mut chars = value.chars();
191 let first = chars.next().expect("identifier checked non-empty");
192 if !first.is_ascii_alphanumeric() {
193 return Err((value, ValidationReason::InvalidStartCharacter));
194 }
195 if !value
196 .chars()
197 .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
198 {
199 return Err((value, ValidationReason::InvalidCharacters));
200 }
201 Ok(value)
202}
203
204fn validate_hostname(value: String) -> ValidationResult {
205 let value = value.trim().to_string();
206 if value.is_empty() {
207 return Err((value, ValidationReason::Empty));
208 }
209 if value.len() > MAX_HOSTNAME_LEN {
210 return Err((value, ValidationReason::HostnameTooLong));
211 }
212 let mut chars = value.chars();
213 let first = chars.next().expect("hostname checked non-empty");
214 if !first.is_ascii_alphanumeric() {
215 return Err((value, ValidationReason::InvalidStartCharacter));
216 }
217 if !value
218 .chars()
219 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_')
220 {
221 return Err((value, ValidationReason::InvalidCharacters));
222 }
223 Ok(value.to_ascii_lowercase())
224}
225
226fn deserialize_hostname<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
227where
228 D: Deserializer<'de>,
229{
230 let value = String::deserialize(deserializer)?;
231 validate_hostname(value).map_err(|(value, reason)| {
232 serde::de::Error::custom(LavaFlowError::InvalidHostname { value, reason })
233 })
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Expects `ProcessName::new` to reject `input` and returns the resulting error.
    fn process_name_rejection(input: impl Into<String>) -> LavaFlowError {
        ProcessName::new(input).expect_err("expected validation error")
    }

    /// Expects `ProcessLocation::new` to reject `input` and returns the resulting error.
    fn hostname_rejection(input: impl Into<String>) -> LavaFlowError {
        ProcessLocation::new(input).expect_err("expected hostname validation error")
    }

    /// Builds a location that is expected to be valid.
    fn location(hostname: &str) -> ProcessLocation {
        ProcessLocation::new(hostname).expect("valid hostname")
    }

    /// Renders a JSON location object with the given hostname, node id 1 and device id 0.
    fn location_payload(hostname: &str) -> String {
        json!({
            "hostname": hostname,
            "node_id": 1,
            "device_id": 0
        })
        .to_string()
    }

    // --- ProcessName -------------------------------------------------------------

    #[test]
    fn process_name_accepts_valid_identifiers() {
        let accepted = ProcessName::new("gpu_worker_01").expect("valid process name");
        assert_eq!(accepted.as_str(), "gpu_worker_01");
    }

    #[test]
    fn process_name_into_inner_returns_owned_value() {
        let accepted = ProcessName::new("gpu_worker_01").expect("valid process name");
        let owned = accepted.into_inner();
        assert_eq!(owned, "gpu_worker_01");
    }

    #[test]
    fn process_name_rejects_invalid_identifiers() {
        let rejection = process_name_rejection("GPU-WORKER");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_name_rejects_empty_identifiers() {
        let rejection = process_name_rejection("");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::Empty,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_name_rejects_too_long_identifiers() {
        let rejection = process_name_rejection("a".repeat(MAX_IDENTIFIER_LEN + 1));
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::IdentifierTooLong,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_name_rejects_invalid_start_identifiers() {
        let rejection = process_name_rejection("-gpu");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidProcessName {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        );
        assert!(is_expected);
    }

    // --- ChannelId ---------------------------------------------------------------

    #[test]
    fn channel_id_accepts_valid_identifiers() {
        let accepted = ChannelId::new("channel-0").expect("valid channel id");
        assert_eq!(accepted.as_str(), "channel-0");
    }

    #[test]
    fn channel_id_into_inner_returns_owned_value() {
        let accepted = ChannelId::new("channel-0").expect("valid channel id");
        let owned = accepted.into_inner();
        assert_eq!(owned, "channel-0");
    }

    #[test]
    fn channel_id_rejects_invalid_identifiers() {
        let rejection = ChannelId::new("-invalid").expect_err("expected validation error");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidChannelId {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        );
        assert!(is_expected);
    }

    // --- ProcessLocation construction --------------------------------------------

    #[test]
    fn process_location_with_ids_preserves_metadata() {
        let placed =
            ProcessLocation::with_ids("gpu-node-0", Some(7), Some(3)).expect("valid location");
        assert_eq!(placed.hostname(), "gpu-node-0");
        assert_eq!(placed.node_id(), Some(7));
        assert_eq!(placed.device_id(), Some(3));
    }

    #[test]
    fn process_location_with_ids_rejects_invalid_hostname() {
        let rejection = ProcessLocation::with_ids("gpu node 0", Some(7), Some(3))
            .expect_err("expected hostname validation error");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_location_from_hostname_works() {
        let detected = ProcessLocation::from_hostname().expect("hostname lookup should succeed");
        assert!(!detected.hostname().is_empty());
    }

    // --- Scope detection ---------------------------------------------------------

    #[test]
    fn scope_detection_is_local_when_hostnames_match() {
        let mine = location("gpu-node-0");
        let peer = location("gpu-node-0");

        assert_eq!(
            CommunicationScope::from_locations(&mine, &peer),
            CommunicationScope::Local
        );
    }

    #[test]
    fn scope_detection_is_remote_when_hostnames_differ() {
        let mine = location("gpu-node-0");
        let peer = location("gpu-node-1");

        assert_eq!(
            CommunicationScope::from_locations(&mine, &peer),
            CommunicationScope::Remote
        );
    }

    #[test]
    fn scope_detection_is_case_insensitive_after_normalization() {
        let mine = location("GPU-NODE-0");
        let peer = location("gpu-node-0");

        assert_eq!(detect_scope(&mine, &peer), CommunicationScope::Local);
    }

    // --- Hostname validation through ProcessLocation::new ------------------------

    #[test]
    fn process_location_rejects_empty_hostname() {
        let rejection = hostname_rejection("");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::Empty,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_location_rejects_invalid_hostname_characters() {
        let rejection = hostname_rejection("gpu node 0");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidCharacters,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_location_rejects_invalid_start_hostname() {
        let rejection = hostname_rejection("-gpu-node-0");
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::InvalidStartCharacter,
                ..
            }
        );
        assert!(is_expected);
    }

    #[test]
    fn process_location_rejects_too_long_hostname() {
        let rejection = hostname_rejection("a".repeat(MAX_HOSTNAME_LEN + 1));
        let is_expected = matches!(
            rejection,
            LavaFlowError::InvalidHostname {
                reason: ValidationReason::HostnameTooLong,
                ..
            }
        );
        assert!(is_expected);
    }

    // --- Deserialization ---------------------------------------------------------

    #[test]
    fn process_name_deserialize_rejects_invalid_value() {
        let failure = serde_json::from_str::<ProcessName>("\"GPU-WORKER\"")
            .expect_err("expected deserialization validation error");
        let message = failure.to_string();
        assert!(message.contains("invalid process name"));
    }

    #[test]
    fn channel_id_deserialize_rejects_invalid_value() {
        let failure = serde_json::from_str::<ChannelId>("\"-invalid\"")
            .expect_err("expected deserialization validation error");
        let message = failure.to_string();
        assert!(message.contains("invalid channel id"));
    }

    #[test]
    fn process_location_deserialize_rejects_invalid_hostname() {
        let body = location_payload("gpu node 0");
        let failure = serde_json::from_str::<ProcessLocation>(&body)
            .expect_err("expected deserialization validation error");
        let message = failure.to_string();
        assert!(message.contains("invalid hostname"));
    }

    #[test]
    fn process_location_deserialize_normalizes_hostname() {
        let body = location_payload("GPU-NODE-0");
        let parsed = serde_json::from_str::<ProcessLocation>(&body).expect("valid deserialization");
        assert_eq!(parsed.hostname(), "gpu-node-0");
        assert_eq!(parsed.node_id(), Some(1));
        assert_eq!(parsed.device_id(), Some(0));
    }
}