-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
43 changed files
with
95,085 additions
and
0 deletions.
There are no files selected for viewing
151 changes: 151 additions & 0 deletions
151
modules/sc-mesh-secure-deployment/src/2_0/features/jamming/channel_quality_estimator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from options import Options, VALID_CHANNELS | ||
from util import map_freq_to_channel | ||
|
||
|
||
class ChannelQualityEstimator: | ||
""" | ||
A class for estimating the quality of communication channels using a pre-trained ResCNN model. | ||
This class utilizes a pre-trained ResCNN model to estimate the channel quality based on input features. | ||
The channel quality is computed as a score that reflects the difference between good and jamming states. | ||
Positive channel quality values indicate a better channel quality, while negative values suggest | ||
potential jamming or lower channel quality. | ||
Attributes: | ||
device (str): A PyTorch device ('cuda' or 'cpu') to run the model on. | ||
model (ResCNN): The pre-trained ResCNN model for channel quality estimation. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
""" | ||
Initializes the ChannelQualityEstimator object. | ||
Loads the pre-trained traced ResCNN model and sets the device to 'cuda' if available, 'cpu' otherwise. | ||
""" | ||
self.args = Options() | ||
# Model related attributes | ||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
try: | ||
self.model = torch.jit.load(self.args.traced_model_path) | ||
self.model = self.model.to(self.device) | ||
self.model.eval() | ||
except FileNotFoundError: | ||
raise FileNotFoundError("Model file not found") | ||
except Exception as e: | ||
print(f"Exception error in loading pretrained model: {e}") if self.args.debug else None | ||
|
||
def _forward(self, feat_array: np.ndarray) -> np.ndarray: | ||
""" | ||
Computes the class probabilities for a given feature array. | ||
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features. | ||
:return: A 2D NumPy array of shape (n_samples, n_classes) containing the model's class probabilities. | ||
""" | ||
with torch.no_grad(): | ||
# Create tensors | ||
inputs = torch.from_numpy(feat_array).float() | ||
inputs = inputs.to(self.device, non_blocking=True) | ||
# Feed inputs and compute jamming probability for each frequency | ||
out = self.model(inputs) | ||
# Compute the probabilities | ||
probs = F.softmax(out, dim=1) | ||
return probs.cpu().numpy() | ||
|
||
def _compute_channel_quality(self, probs: np.ndarray) -> np.ndarray: | ||
""" | ||
Computes the channel quality scores based on the model's predictions. | ||
The channel quality is calculated as the difference between the weighted sum of good state probabilities | ||
and the average probability of jamming states. Positive channel quality values indicate a better channel | ||
quality, while negative values suggest potential jamming or lower channel quality. | ||
:param probs: A 2D NumPy array of shape (n_channels, n_classes) containing the model's class probabilities. | ||
:return: A 1D NumPy array of shape (n_channels,) representing the channel quality scores for each channel. | ||
""" | ||
# Weights for good states (communication, floor, inter_mid, inter_high) | ||
good_weights = np.array([1.0, 0.6, 0.4]) | ||
# Compute good state probabilities | ||
good_probs = probs[:, :3] # Get the probabilities for the first 3 states | ||
# Compute a score based on the good state probabilities and their weights | ||
good_scores = np.dot(good_probs, good_weights) | ||
# Compute jamming state probabilities | ||
jamming_probs = probs[:, 3:] # Get the probabilities for the remaining jamming states | ||
# Compute a jamming score based on the sum of the jamming probabilities | ||
jamming_scores = np.sum(jamming_probs, axis=1) | ||
# Compute channel quality as the difference between good scores and jamming scores | ||
channel_quality = good_scores - jamming_scores | ||
|
||
return channel_quality | ||
|
||
def check_arrays(self, feat_array: np.ndarray, frequencies: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Check if numpy arrays contain NaN values or contain inf values. | ||
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features. | ||
:param frequencies: A 1D NumPy array of shape (n_channels,) containing the frequencies of the channels. | ||
""" | ||
# Check if frequencies contains NaN values | ||
if np.isnan(frequencies).any(): | ||
frequencies_nan_mask = np.isnan(frequencies) | ||
frequencies = frequencies[~frequencies_nan_mask] | ||
# Reshape feat_array to match the new size of frequencies | ||
feat_array = feat_array[~frequencies_nan_mask] | ||
|
||
# Check if feat_array contains NaN values | ||
elif np.isnan(feat_array).any(): | ||
feat_array_nan_mask = np.isnan(feat_array) | ||
feat_array = feat_array[~feat_array_nan_mask] | ||
|
||
# Check if frequencies contains infinite values | ||
if np.isinf(frequencies).any(): | ||
frequencies_inf_mask = np.isinf(frequencies) | ||
frequencies = frequencies[~frequencies_inf_mask] | ||
# Reshape feat_array to match the new size of frequencies | ||
feat_array = feat_array[~frequencies_inf_mask] | ||
|
||
# Check if feat_array contains infinite values | ||
elif np.isinf(feat_array).any(): | ||
feat_array_inf_mask = np.isinf(feat_array) | ||
feat_array = feat_array[~feat_array_inf_mask] | ||
|
||
return feat_array, frequencies | ||
|
||
def estimate(self, feat_array: np.ndarray, frequencies: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
""" | ||
Estimates the channel quality for a given feature array. | ||
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features. | ||
:param frequencies: A 1D NumPy array of shape (n_channels,) containing the frequencies of the channels. | ||
:return: A tuple containing the estimated channel quality scores (1D NumPy array) and the class probabilities | ||
(2D NumPy array). | ||
""" | ||
feat_array, frequencies = self.check_arrays(feat_array, frequencies) | ||
|
||
# Check if feat_array and frequencies arrays are empty | ||
if feat_array.size == 0 or frequencies.size == 0: | ||
return np.array([]), np.empty((0, 0)), np.array([]) | ||
|
||
try: | ||
# Compute class probabilities for the given features | ||
probs = self._forward(feat_array) | ||
# Compute the channel quality | ||
channel_quality = self._compute_channel_quality(probs) | ||
try: | ||
freq_list = [int(freq) for freq in frequencies.tolist()] | ||
channel_list = [map_freq_to_channel(freq) for freq in freq_list] | ||
mask = np.isin(channel_list, VALID_CHANNELS) | ||
# Normalize the quality values | ||
quality_normalized = (channel_quality[mask] - (-1)) / (1 - (-1)) | ||
except Exception: | ||
quality_normalized = np.array([]) | ||
|
||
return channel_quality, probs, quality_normalized | ||
except Exception: | ||
return np.array([]), np.empty((0, 0)), np.array([]) |
88 changes: 88 additions & 0 deletions
88
modules/sc-mesh-secure-deployment/src/2_0/features/jamming/feature_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import json | ||
import socket | ||
import threading | ||
import time | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class FeatureClient(ABC, threading.Thread): | ||
def __init__(self, node_id: str, host: str, port: int) -> None: | ||
""" | ||
Initialize the FeatureClient object. | ||
:param node_id: An integer representing the node ID. | ||
:param host: A string representing the host address. | ||
:param port: An integer representing the port number. | ||
""" | ||
super().__init__() | ||
self.node_id = node_id | ||
self.host = host | ||
self.port = port | ||
self.running = threading.Event() | ||
self.switching = threading.Event() | ||
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) | ||
|
||
@abstractmethod | ||
def run(self) -> None: | ||
self.socket.connect((self.host, self.port)) | ||
self.running.set() | ||
receive_thread = threading.Thread(target=self.receive_messages) | ||
receive_thread.start() | ||
|
||
@abstractmethod | ||
def run_client_fsm(self) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def receive_messages(self) -> None: | ||
""" | ||
Receive messages from the socket server. | ||
""" | ||
while True: | ||
try: | ||
message = self.socket.recv(1024).decode() | ||
if not message: | ||
print("No message... break") | ||
break | ||
|
||
# Split the received message into individual JSON objects | ||
json_objects = message.strip().split('\n') | ||
for json_object_str in json_objects: | ||
try: | ||
print(f"Received message: {message}") | ||
json_object = json.loads(json_object_str) | ||
action = json_object.get("action") | ||
|
||
# Jamming Policy | ||
if action == "broadcast": | ||
print(f"Broadcast message received...") | ||
|
||
except json.JSONDecodeError as e: | ||
print(f"Failed to decode JSON: {e}") | ||
|
||
except ConnectionResetError: | ||
print("Connection forcibly closed by the remote host") | ||
break | ||
|
||
|
||
@abstractmethod | ||
def send_messages(self, action) -> None: | ||
""" | ||
Send message to the server. | ||
:param action: Action to be taken by client. | ||
""" | ||
data = [{'action': 'broadcast', 'node_id': self.node_id}, | ||
{'action': 'broadcast', 'node_id': self.node_id}] | ||
|
||
for message in data: | ||
json_str = json.dumps(message) | ||
self.socket.send(json_str.encode()) | ||
print("Sent message to server") | ||
time.sleep(5) | ||
|
||
@abstractmethod | ||
def stop(self) -> None: | ||
self.running.clear() | ||
self.socket.close() | ||
self.join() |
142 changes: 142 additions & 0 deletions
142
modules/sc-mesh-secure-deployment/src/2_0/features/jamming/feature_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import json | ||
import logging | ||
import signal | ||
import socket | ||
import sys | ||
import threading | ||
from abc import ABC, abstractmethod | ||
from typing import List, Tuple, Dict, Any | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
class FeatureServer(ABC): | ||
def __init__(self, host: str, port: int): | ||
""" | ||
Initializes the FeatureServer object. | ||
:param host: The host address to bind the server to. | ||
:param port: The port number to bind the server to. | ||
""" | ||
self.host = host | ||
self.port = port | ||
self.clients: List[FeatureClientTwin] = [] | ||
self.serversocket = None | ||
|
||
@abstractmethod | ||
def start(self) -> None: | ||
""" | ||
Starts the server and listens for incoming client connections. | ||
""" | ||
signal.signal(signal.SIGINT, self.signal_handler) | ||
|
||
self.serversocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) | ||
self.serversocket.bind((self.host, self.port)) | ||
self.serversocket.listen(5) | ||
logging.info("server started and listening") | ||
|
||
while True: | ||
c_socket, c_address = self.serversocket.accept() | ||
client = FeatureClientTwin(c_socket, c_address, self.clients, self.host) | ||
print('New connection', client) | ||
self.clients.append(client) | ||
|
||
@abstractmethod | ||
def signal_handler(self, sig: signal.Signals, frame) -> None: | ||
""" | ||
Handles a signal interrupt (SIGINT) and stops the server gracefully. | ||
:param sig: The signal received by the handler. | ||
:param frame: The current execution frame. | ||
""" | ||
print("Attempting to close threads.") | ||
for client in self.clients: | ||
print("joining", client.address) | ||
client.stop() | ||
|
||
print("threads successfully closed") | ||
sys.exit(0) | ||
|
||
@abstractmethod | ||
def send_data_clients(self, message: Dict[str, Any]) -> None: | ||
""" | ||
Sends a message to all connected clients except for the sender. | ||
:param message: The message to broadcast. | ||
""" | ||
print("Broadcasting channel switch info to all nodes...") | ||
for client in self.clients: | ||
try: | ||
client.socket.sendall(json.dumps(message).encode()) | ||
except BrokenPipeError: | ||
print("Broken pipe error, client disconnected:", client.address) | ||
self.clients.remove(client) | ||
|
||
@abstractmethod | ||
def run_server_fsm(self) -> None: | ||
""" | ||
Runs the server finite state machine. | ||
""" | ||
pass | ||
|
||
|
||
class FeatureClientTwin(ABC, threading.Thread): | ||
def __init__(self, socket: socket.socket, address: Tuple[str, int], clients: List["FeatureClientTwin"], host: str): | ||
""" | ||
Initializes the FeatureClientTwin object. | ||
:param socket: The connected socket for the client. | ||
:param address: The address of the client. | ||
:param clients: A list of all connected clients. | ||
""" | ||
threading.Thread.__init__(self) | ||
self.socket = socket | ||
self.address = address | ||
self.running = True | ||
self.clients = clients | ||
self.host = host | ||
self.start() | ||
|
||
@abstractmethod | ||
def run(self): | ||
# Create threads for server operations | ||
listen_thread = threading.Thread(target=self.receive_messages) | ||
listen_thread.start() | ||
|
||
@abstractmethod | ||
def receive_messages(self) -> None: | ||
""" | ||
Handles incoming messages from the client. | ||
""" | ||
while self.running: | ||
try: | ||
message = self.socket.recv(1024).decode() | ||
if not message: | ||
break | ||
|
||
# Split the received message into individual JSON objects | ||
json_objects = message.strip().split('\n') | ||
for json_object_str in json_objects: | ||
try: | ||
json_object = json.loads(json_object_str) | ||
action = json_object.get("action") | ||
client_ip = json_object.get("node_id") | ||
|
||
# Policy | ||
if action == "broadcast": | ||
print(f"Broadcast message received...") | ||
|
||
except json.JSONDecodeError as e: | ||
print(f"Failed to decode JSON: {e}") | ||
|
||
except ConnectionResetError: | ||
logging.warning("Connection forcibly closed by the remote host") | ||
break | ||
|
||
@abstractmethod | ||
def stop(self) -> None: | ||
""" | ||
Stops the Client thread and closes the socket and database connections. | ||
""" | ||
self.running = False | ||
self.socket.close() |
Oops, something went wrong.