from torch.utils.tensorboard import SummaryWriter
import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
 
X, y = make_classification(n_samples=100, n_features=4, random_state=31)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=31)
 
# Define a simple neural network
class SimpleNN(torch.nn.Module):
	def __init__(self):
		super().__init__()
		self.fc1 = torch.nn.Linear(4, 8)
		self.fc2 = torch.nn.Linear(8, 2)
		self.counter = 0  # incremented by the fc1 gradient hook (diagnostic only)
	def forward(self, x):
		x = torch.relu(self.fc1(x))
		x = self.fc2(x)
		return x
 
model = SimpleNN()
 
criterion = torch.nn.CrossEntropyLoss()
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
def train(iteration_number: int = 0):
	# One full-batch training step
	optimizer.zero_grad()
	output = model(torch.tensor(X_train, dtype=torch.float32))
	loss = criterion(output, torch.tensor(y_train, dtype=torch.long))
	loss.backward()
	optimizer.step()
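	# Also log the scalar training loss; the tag name "loss/train" is an
	# arbitrary choice, not part of the original script
	summary_writer.add_scalar("loss/train", loss.item(), iteration_number)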
 
summary_writer = SummaryWriter(log_dir="tb_logs")
 
# Hook factories that log activations and gradients to TensorBoard
def get_activation(name, iteration):
	def hook(mdl, input, output):
		# Forward hook: log this layer's output activations at the given iteration
		summary_writer.add_histogram(f"activations_{name}", output.detach().numpy(), iteration)
	return hook
 
def get_gradient(name, iteration):
	def hook(mdl, grad_input, grad_output):
		if name == "fc1":
			model.counter += 1  # count backward passes through fc1
		# Backward hook: log the gradient flowing out of this layer
		summary_writer.add_histogram(f"gradients_{name}", grad_output[0].detach().numpy(), iteration)
	return hook
 
for i in range(10):
	# Log a histogram of every parameter tensor at this iteration
	for layer, weights in model.named_parameters():
		summary_writer.add_histogram(f"weights_{layer}", weights.detach().numpy(), i)
	handles = []
	for layername, module in model.named_children():
		handles.append(module.register_forward_hook(get_activation(layername, i)))
		# Gradient hooks must be backward hooks: a forward hook receives the
		# layer's input/output, not its gradients
		handles.append(module.register_full_backward_hook(get_gradient(layername, i)))
 
	train(i)
 
	for handle in handles:
		handle.remove()
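	# Optional sketch (not in the original): parameter-level gradients can also
	# be read directly from .grad after the backward pass inside train();
	# the tag prefix "param_grads_" is an arbitrary choice
	for layer, weights in model.named_parameters():
		if weights.grad is not None:
			summary_writer.add_histogram(f"param_grads_{layer}", weights.grad, i)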
	summary_writer.flush()
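
# Minimal held-out evaluation sketch (not in the original), using the test split from above
with torch.no_grad():
	test_logits = model(torch.tensor(X_test, dtype=torch.float32))
	test_acc = (test_logits.argmax(dim=1) == torch.tensor(y_test, dtype=torch.long)).float().mean().item()
print(f"Test accuracy: {test_acc:.3f}")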
 
# Check for NaNs in weights
for layer, weights in model.named_parameters():
	if torch.isnan(weights).any():
		print(f"NaNs detected in {layer}")