ニューラルネットワークの図を描くためのPythonコード
下記のように全結合型ニューラルネットワークの図を描くためのPythonコードを書いた。
入力層~出力層にかけてのノード数を引数に入れて実行するだけ。
neural_network_img(6,4,8,5,6,4,8,6)
コードは次の通り。
from itertools import product import numpy as np import matplotlib.pyplot as plt %matplotlib inline plt.rcParams['figure.dpi'] = 200 def make_nodes(num): return list(np.arange(-(num-1)/2, (num-1)/2+1,1)) def newral_netwark_img(*nodes_nums): n = len(nodes_nums) nodes_list = [make_nodes(num) for num in nodes_nums] edge_list = [list(product(nodes_list[i],nodes_list[i+1])) for i in range(n-1)] plt.figure() for i, edges in enumerate(edge_list): for edge in edges: plt.plot([i,i+1], edge, color='gray', linewidth=0.1) for i, nodes in enumerate(nodes_list): plt.scatter(np.full_like(nodes,i),nodes,c='white',edgecolors='gray') plt.xticks(ticks=list(range(n)),labels=['input']+[f'layer{i}' for i in range(n-2)]+['output']) plt.tick_params(bottom=False, left=False, right=False, top=False, labelbottom=True, labelleft=False, labelright=False, labeltop=False) ax = plt.gca() ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.spines['left'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.set_title('Fully Coupled Neural Network') plt.show() neural_network_img(2,4,8,5)