import torch
 
class SimpleModel(torch.nn.Module):
def __init__(self):
	super(SimpleModel, self).__init__()
	self.linear = torch.nn.Linear(10, 1)
 
def forward(self, x):
	return self.linear(x)
 
def test_model_overfit_single_sample():
# Create a simple model
model = SimpleModel()
 
# Generate random input data and target
input_data = torch.randn(1, 10)
target = torch.randn(1, 1)
 
# Define loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
# Train the model on the single sample
for _ in range(100):
	optimizer.zero_grad()
	output = model(input_data)
	loss = criterion(output, target)
	loss.backward()
	optimizer.step()
 
# Check if the model has overfitted (i.e., the output is close to the target)
final_output = model(input_data)
 
assert torch.allclose(final_output, target), "Model did not overfit on the single sample"