What do we use logistic regression for?

Logistic regression is one of those tools that no data scientist's toolbox should be missing. Whenever a binary classification task comes up, logistic regression can be your first choice thanks to its effectiveness and simplicity.

# Fetch the zenzo helper package (provides the dataset generator, plotting helpers
# and LogisticClassifier used below) and move into its notebooks folder —
# presumably so that `import zenzo` below can locate the package; TODO confirm.
!git clone https://github.com/lorebucs/zenzo.git
%cd zenzo/_notebooks
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import zenzo as z

Create dataset

def f(x):
    """Identity function passed to the dataset generator.

    Presumably the generator labels points by which side of y = f(x) they fall
    on — TODO confirm against z.binary_classification_dataset.
    """
    return x


# 1000 labelled 2-D points; X and y are tensors (converted with .numpy() below).
X, y = z.binary_classification_dataset(f, 1000, scale=2)

# Scatter the raw dataset before training.
fig = plt.figure(figsize=(6, 6))
ax1 = fig.add_subplot(111)
ax1.set_xlim(left=-1.05, right=1.05)
ax1.set_ylim(bottom=-1.05, top=1.05)
z.plot_2Dpoints(X.numpy(), y.numpy(), ax1)

# BCELoss expects probabilities, so LogisticClassifier is assumed to end with a
# sigmoid — TODO confirm in zenzo.
model = z.LogisticClassifier()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Train model

def _plot_decision_boundary(model, ax, color):
    """Draw the line where the model's linear score is zero and return its parameters.

    Assumes list(model.parameters()) is [weight, bias], with weight[0] holding
    the two input coefficients, and that ``z.plot_line(slope, intercept, ax, color=...)``
    draws y = slope * x + intercept — TODO confirm both in zenzo.

    Returns:
        (weights, bias) as detached NumPy values.
    """
    params = list(model.parameters())
    weights = params[0][0].detach().numpy()
    bias = params[1][0].detach().numpy()
    # w0*x + w1*y + b = 0  =>  y = -(w0/w1) * x - b/w1.
    # The intercept must be negated too; previously it was passed as +b/w1.
    z.plot_line(-weights[0] / weights[1], -bias / weights[1], ax, color=color)
    return weights, bias


n_epochs = 20000
# Draw an intermediate boundary about 100 times over training. max() avoids a
# ZeroDivisionError in the modulo below when n_epochs < 100.
plot_every = max(1, n_epochs // 100)
losses = []

fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0, 0, 1, 1])
ax.set_xlim(left=-1.05, right=1.05)
ax.set_ylim(bottom=-1.05, top=1.05)

for epoch in range(n_epochs):
    # Full-batch gradient step on the whole dataset.
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if epoch % plot_every == 0:
        _plot_decision_boundary(model, ax, color='g--')

# Final boundary in black on top of the intermediate green ones.
weights, bias = _plot_decision_boundary(model, ax, color='black')

z.plot_2Dpoints(X.detach().numpy(), y.detach().numpy(), ax)
plt.show()


# Plot the per-epoch training loss collected in `losses`.
plt.figure()
plt.title("Training Loss")
plt.xlabel('Number of epochs')
plt.ylabel('Loss')
# With a single sequence, matplotlib uses its index (0..len-1) as x,
# which equals the epoch number here.
plt.plot(losses)
plt.show()

print("Final training loss:", losses[-1])
Final training loss: 0.05441097542643547
print(float(list(model.parameters())[1]))  # second model parameter (used as the bias/w0 below)
-0.05139697715640068
# 3D view of the training points together with the model's linear score plane
# z = w0 + w1*x + w2*y (the decision boundary is where this plane crosses z = 0).
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111, projection='3d')

points = X.numpy()
labels = y.squeeze().numpy()
for color, class_ in [('blue', 0), ('red', 1)]:
    # A boolean mask keeps `xys` 2-D for any class size. The previous
    # np.argwhere(...).squeeze() collapsed a single match to a scalar index,
    # after which `xys[:, 0]` failed.
    xys = points[labels == class_]
    # NOTE(review): both classes are drawn at z = 1 — confirm whether z = class_ was intended.
    zs = np.ones(xys.shape[0])
    ax.scatter(xys[:, 0], xys[:, 1], zs, color=color, edgecolor='k', s=10)

# Model parameters: parameters()[1] is the bias, parameters()[0] the weight
# row whose [0, 0] / [0, 1] entries multiply x and y.
params = list(model.parameters())
w0 = float(params[1])
w1 = float(params[0][0, 0])
w2 = float(params[0][0, 1])

# plot_surface samples only ~50x50 points by default (rcount/ccount), so a
# 200x200 grid is ample; a 1000x1000 grid was mostly discarded.
grid = np.linspace(-1., 1., 200)
X_grid, Y_grid = np.meshgrid(grid, grid)
Z_surf = w0 + w1 * X_grid + w2 * Y_grid
# Blank out the part of the plane beyond the z-limits set below. mplot3d
# historically does not clip artists to the axis box, so NaNs are used instead.
Z_surf[np.abs(Z_surf) > 2.] = np.nan

# The NaNs are intentional; silence matplotlib's "Z contains NaN values" warning.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Z contains NaN values")
    ax.plot_surface(X_grid, Y_grid, Z_surf, alpha=0.2)

ax.view_init(10, 30)  # elevation, azimuth (degrees)
ax.set_axis_on()

# Reference axes from the origin: x in blue, y in red, z in green.
ax.plot([0, 1], [0, 0], [0, 0], c='b')
ax.plot([0, 0], [0, 1], [0, 0], c='r')
ax.plot([0, 0], [0, 0], [0, 2], c='g')

ax.set_xlim(left=-1.05, right=1.05)
ax.set_ylim(bottom=-1.05, top=1.05)
ax.set_zlim(bottom=-2, top=2)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

plt.show()
<ipython-input-111-2d5fd32a75dc>:24: UserWarning: Z contains NaN values. This may result in rendering artifacts.
  ax.plot_surface(X_grid, Y_grid, Z_surf, alpha=0.2)