Skip to content

Commit

Permalink
Added jamming avoidance files
Browse files Browse the repository at this point in the history
  • Loading branch information
dania-tii committed Oct 12, 2023
1 parent 9705de3 commit b91a26d
Show file tree
Hide file tree
Showing 43 changed files with 95,085 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F

from options import Options, VALID_CHANNELS
from util import map_freq_to_channel


class ChannelQualityEstimator:
"""
A class for estimating the quality of communication channels using a pre-trained ResCNN model.
This class utilizes a pre-trained ResCNN model to estimate the channel quality based on input features.
The channel quality is computed as a score that reflects the difference between good and jamming states.
Positive channel quality values indicate a better channel quality, while negative values suggest
potential jamming or lower channel quality.
Attributes:
device (str): A PyTorch device ('cuda' or 'cpu') to run the model on.
model (ResCNN): The pre-trained ResCNN model for channel quality estimation.
"""

def __init__(self) -> None:
"""
Initializes the ChannelQualityEstimator object.
Loads the pre-trained traced ResCNN model and sets the device to 'cuda' if available, 'cpu' otherwise.
"""
self.args = Options()
# Model related attributes
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
self.model = torch.jit.load(self.args.traced_model_path)
self.model = self.model.to(self.device)
self.model.eval()
except FileNotFoundError:
raise FileNotFoundError("Model file not found")
except Exception as e:
print(f"Exception error in loading pretrained model: {e}") if self.args.debug else None

def _forward(self, feat_array: np.ndarray) -> np.ndarray:
"""
Computes the class probabilities for a given feature array.
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features.
:return: A 2D NumPy array of shape (n_samples, n_classes) containing the model's class probabilities.
"""
with torch.no_grad():
# Create tensors
inputs = torch.from_numpy(feat_array).float()
inputs = inputs.to(self.device, non_blocking=True)
# Feed inputs and compute jamming probability for each frequency
out = self.model(inputs)
# Compute the probabilities
probs = F.softmax(out, dim=1)
return probs.cpu().numpy()

def _compute_channel_quality(self, probs: np.ndarray) -> np.ndarray:
"""
Computes the channel quality scores based on the model's predictions.
The channel quality is calculated as the difference between the weighted sum of good state probabilities
and the average probability of jamming states. Positive channel quality values indicate a better channel
quality, while negative values suggest potential jamming or lower channel quality.
:param probs: A 2D NumPy array of shape (n_channels, n_classes) containing the model's class probabilities.
:return: A 1D NumPy array of shape (n_channels,) representing the channel quality scores for each channel.
"""
# Weights for good states (communication, floor, inter_mid, inter_high)
good_weights = np.array([1.0, 0.6, 0.4])
# Compute good state probabilities
good_probs = probs[:, :3] # Get the probabilities for the first 3 states
# Compute a score based on the good state probabilities and their weights
good_scores = np.dot(good_probs, good_weights)
# Compute jamming state probabilities
jamming_probs = probs[:, 3:] # Get the probabilities for the remaining jamming states
# Compute a jamming score based on the sum of the jamming probabilities
jamming_scores = np.sum(jamming_probs, axis=1)
# Compute channel quality as the difference between good scores and jamming scores
channel_quality = good_scores - jamming_scores

return channel_quality

def check_arrays(self, feat_array: np.ndarray, frequencies: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Check if numpy arrays contain NaN values or contain inf values.
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features.
:param frequencies: A 1D NumPy array of shape (n_channels,) containing the frequencies of the channels.
"""
# Check if frequencies contains NaN values
if np.isnan(frequencies).any():
frequencies_nan_mask = np.isnan(frequencies)
frequencies = frequencies[~frequencies_nan_mask]
# Reshape feat_array to match the new size of frequencies
feat_array = feat_array[~frequencies_nan_mask]

# Check if feat_array contains NaN values
elif np.isnan(feat_array).any():
feat_array_nan_mask = np.isnan(feat_array)
feat_array = feat_array[~feat_array_nan_mask]

# Check if frequencies contains infinite values
if np.isinf(frequencies).any():
frequencies_inf_mask = np.isinf(frequencies)
frequencies = frequencies[~frequencies_inf_mask]
# Reshape feat_array to match the new size of frequencies
feat_array = feat_array[~frequencies_inf_mask]

# Check if feat_array contains infinite values
elif np.isinf(feat_array).any():
feat_array_inf_mask = np.isinf(feat_array)
feat_array = feat_array[~feat_array_inf_mask]

return feat_array, frequencies

def estimate(self, feat_array: np.ndarray, frequencies: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Estimates the channel quality for a given feature array.
:param feat_array: A 2D NumPy array of shape (n_samples, n_features) containing the input features.
:param frequencies: A 1D NumPy array of shape (n_channels,) containing the frequencies of the channels.
:return: A tuple containing the estimated channel quality scores (1D NumPy array) and the class probabilities
(2D NumPy array).
"""
feat_array, frequencies = self.check_arrays(feat_array, frequencies)

# Check if feat_array and frequencies arrays are empty
if feat_array.size == 0 or frequencies.size == 0:
return np.array([]), np.empty((0, 0)), np.array([])

try:
# Compute class probabilities for the given features
probs = self._forward(feat_array)
# Compute the channel quality
channel_quality = self._compute_channel_quality(probs)
try:
freq_list = [int(freq) for freq in frequencies.tolist()]
channel_list = [map_freq_to_channel(freq) for freq in freq_list]
mask = np.isin(channel_list, VALID_CHANNELS)
# Normalize the quality values
quality_normalized = (channel_quality[mask] - (-1)) / (1 - (-1))
except Exception:
quality_normalized = np.array([])

return channel_quality, probs, quality_normalized
except Exception:
return np.array([]), np.empty((0, 0)), np.array([])
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import socket
import threading
import time
from abc import ABC, abstractmethod


class FeatureClient(ABC, threading.Thread):
def __init__(self, node_id: str, host: str, port: int) -> None:
"""
Initialize the FeatureClient object.
:param node_id: An integer representing the node ID.
:param host: A string representing the host address.
:param port: An integer representing the port number.
"""
super().__init__()
self.node_id = node_id
self.host = host
self.port = port
self.running = threading.Event()
self.switching = threading.Event()
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)

@abstractmethod
def run(self) -> None:
self.socket.connect((self.host, self.port))
self.running.set()
receive_thread = threading.Thread(target=self.receive_messages)
receive_thread.start()

@abstractmethod
def run_client_fsm(self) -> None:
pass

@abstractmethod
def receive_messages(self) -> None:
"""
Receive messages from the socket server.
"""
while True:
try:
message = self.socket.recv(1024).decode()
if not message:
print("No message... break")
break

# Split the received message into individual JSON objects
json_objects = message.strip().split('\n')
for json_object_str in json_objects:
try:
print(f"Received message: {message}")
json_object = json.loads(json_object_str)
action = json_object.get("action")

# Jamming Policy
if action == "broadcast":
print(f"Broadcast message received...")

except json.JSONDecodeError as e:
print(f"Failed to decode JSON: {e}")

except ConnectionResetError:
print("Connection forcibly closed by the remote host")
break


@abstractmethod
def send_messages(self, action) -> None:
"""
Send message to the server.
:param action: Action to be taken by client.
"""
data = [{'action': 'broadcast', 'node_id': self.node_id},
{'action': 'broadcast', 'node_id': self.node_id}]

for message in data:
json_str = json.dumps(message)
self.socket.send(json_str.encode())
print("Sent message to server")
time.sleep(5)

@abstractmethod
def stop(self) -> None:
self.running.clear()
self.socket.close()
self.join()
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import json
import logging
import signal
import socket
import sys
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any

logging.basicConfig(level=logging.INFO)


class FeatureServer(ABC):
def __init__(self, host: str, port: int):
"""
Initializes the FeatureServer object.
:param host: The host address to bind the server to.
:param port: The port number to bind the server to.
"""
self.host = host
self.port = port
self.clients: List[FeatureClientTwin] = []
self.serversocket = None

@abstractmethod
def start(self) -> None:
"""
Starts the server and listens for incoming client connections.
"""
signal.signal(signal.SIGINT, self.signal_handler)

self.serversocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
self.serversocket.bind((self.host, self.port))
self.serversocket.listen(5)
logging.info("server started and listening")

while True:
c_socket, c_address = self.serversocket.accept()
client = FeatureClientTwin(c_socket, c_address, self.clients, self.host)
print('New connection', client)
self.clients.append(client)

@abstractmethod
def signal_handler(self, sig: signal.Signals, frame) -> None:
"""
Handles a signal interrupt (SIGINT) and stops the server gracefully.
:param sig: The signal received by the handler.
:param frame: The current execution frame.
"""
print("Attempting to close threads.")
for client in self.clients:
print("joining", client.address)
client.stop()

print("threads successfully closed")
sys.exit(0)

@abstractmethod
def send_data_clients(self, message: Dict[str, Any]) -> None:
"""
Sends a message to all connected clients except for the sender.
:param message: The message to broadcast.
"""
print("Broadcasting channel switch info to all nodes...")
for client in self.clients:
try:
client.socket.sendall(json.dumps(message).encode())
except BrokenPipeError:
print("Broken pipe error, client disconnected:", client.address)
self.clients.remove(client)

@abstractmethod
def run_server_fsm(self) -> None:
"""
Runs the server finite state machine.
"""
pass


class FeatureClientTwin(ABC, threading.Thread):
def __init__(self, socket: socket.socket, address: Tuple[str, int], clients: List["FeatureClientTwin"], host: str):
"""
Initializes the FeatureClientTwin object.
:param socket: The connected socket for the client.
:param address: The address of the client.
:param clients: A list of all connected clients.
"""
threading.Thread.__init__(self)
self.socket = socket
self.address = address
self.running = True
self.clients = clients
self.host = host
self.start()

@abstractmethod
def run(self):
# Create threads for server operations
listen_thread = threading.Thread(target=self.receive_messages)
listen_thread.start()

@abstractmethod
def receive_messages(self) -> None:
"""
Handles incoming messages from the client.
"""
while self.running:
try:
message = self.socket.recv(1024).decode()
if not message:
break

# Split the received message into individual JSON objects
json_objects = message.strip().split('\n')
for json_object_str in json_objects:
try:
json_object = json.loads(json_object_str)
action = json_object.get("action")
client_ip = json_object.get("node_id")

# Policy
if action == "broadcast":
print(f"Broadcast message received...")

except json.JSONDecodeError as e:
print(f"Failed to decode JSON: {e}")

except ConnectionResetError:
logging.warning("Connection forcibly closed by the remote host")
break

@abstractmethod
def stop(self) -> None:
"""
Stops the Client thread and closes the socket and database connections.
"""
self.running = False
self.socket.close()
Loading

0 comments on commit b91a26d

Please sign in to comment.