めんまのつぶやき

ニューラルネットワークの図を描くためのPythonコード

下記のように全結合型ニューラルネットワークの図を描くためのPythonコードを書いた。
入力層~出力層にかけてのノード数を引数に入れて実行するだけ。

neural_network_img(6,4,8,5,6,4,8,6)

f:id:mennmabacon:20211130220042p:plain


コードは次の通り。

from itertools import product
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.dpi'] = 200

def make_nodes(num):
  return list(np.arange(-(num-1)/2, (num-1)/2+1,1))

def newral_netwark_img(*nodes_nums):
  n = len(nodes_nums)
  nodes_list = [make_nodes(num) for num in nodes_nums]
  edge_list = [list(product(nodes_list[i],nodes_list[i+1])) for i in range(n-1)]

  plt.figure()
  for i, edges in enumerate(edge_list):
    for edge in edges:
      plt.plot([i,i+1], edge, color='gray', linewidth=0.1)
  for i, nodes in enumerate(nodes_list):
    plt.scatter(np.full_like(nodes,i),nodes,c='white',edgecolors='gray')
  plt.xticks(ticks=list(range(n)),labels=['input']+[f'layer{i}' for i in range(n-2)]+['output'])
  plt.tick_params(bottom=False,
               left=False,
               right=False,
               top=False,
               labelbottom=True,
               labelleft=False,
               labelright=False,
               labeltop=False)
  ax = plt.gca()
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  ax.spines['left'].set_visible(False)
  ax.spines['bottom'].set_visible(False)
  ax.set_title('Fully Coupled Neural Network')
  plt.show()

neural_network_img(2,4,8,5)

f:id:mennmabacon:20211130221909p:plain