我一直没办法让嵌入在 Toplevel 窗口图形画布中的 matplotlib 滑块动画居中显示，为此苦恼了很久。我的函数接收三个参数：一个 C 乘 N 的矩阵 M（其中 N 相当大）、一个包含 C 个通道名称的列表 channels，以及一个远小于 N 的正整数 win_size。动画的每一帧显示一个 C 乘 win_size 矩阵的热图，具体显示哪些列取决于滑块的当前值。例如，如果 M 是 4 乘 300 的矩阵且 win_size=10，则第 0 帧是 M[:,0:10] 的热图，X 轴标签为 0 到 10；第 1 帧是 M[:,10:20] 的热图，X 轴标签为 10 到 20，以此类推。每一帧的 Y 轴标签都是通道名称。
Here's the function:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import random
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter import Tk,TOP,BOTH,Toplevel
import matplotlib
matplotlib.use('TkAgg',force=True)
from mpl_toolkits.axes_grid1 import make_axes_locatable
import mpl_toolkits
import matplotlib.ticker as ticker
def scrolling_matrix_viewer(M, channels, win_size, *, vmin=0, vmax=1, cmap='coolwarm'):
    """Show a C-by-N matrix as a heatmap that scrolls in win_size-column frames.

    Opens a new Toplevel window containing an embedded matplotlib figure.
    The figure shows a heatmap of ``M[:, start:start + win_size]`` with a
    colorbar on its right. A slider underneath steps ``start`` through
    0, win_size, 2*win_size, ...

    The x-axis is labelled with the absolute column indices
    ``start .. start + win_size``. The y-axis is labelled with ``channels``.
    Only full windows are shown. When N is not a multiple of win_size, the
    last ``N % win_size`` columns are therefore not displayed.

    Parameters
    ----------
    M : array_like, shape (C, N)
        Matrix to display; one row per channel.
    channels : sequence of str
        One label per row of ``M``, used as the y-axis tick labels.
    win_size : int
        Number of columns shown per frame, ``1 <= win_size <= N``.
    vmin, vmax : float, optional
        Colour-scale limits (default 0 and 1).
    cmap : str, optional
        Colormap name (default ``'coolwarm'``).

    Returns
    -------
    matplotlib.widgets.Slider
        The frame slider. A reference is also stored on the Toplevel window as
        ``plot_window.slider``. Matplotlib widgets stop responding once they
        are garbage-collected, so the reference must be kept somewhere.

    Raises
    ------
    ValueError
        If ``M`` is not 2-D, if ``len(channels) != M.shape[0]``, or if
        ``win_size`` is not in ``1..M.shape[1]``.
    """
    # A bare Figure (instead of plt.subplots) is the recommended way to embed
    # in Tk: pyplot then does not also own/manage this figure.
    from matplotlib.figure import Figure

    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError(f"M must be 2-D (C x N), got shape {M.shape}")
    n_channels, n_cols = M.shape
    if len(channels) != n_channels:
        raise ValueError(
            f"expected {n_channels} channel labels, got {len(channels)}")
    if not 1 <= win_size <= n_cols:
        raise ValueError(
            f"win_size must be between 1 and {n_cols}, got {win_size}")

    # Largest window start that still leaves a full window, aligned to win_size.
    last_start = (n_cols // win_size - 1) * win_size

    plot_window = Toplevel(bg="lightgray")  # Toplevel where the animation is embedded
    plot_window.geometry('1000x900')
    plot_window.wm_title('Top Level Window')
    plot_window.attributes('-topmost', 'true')

    fig = Figure()
    ax = fig.add_subplot(111)
    # Symmetric left/right margins keep the heatmap + colorbar horizontally
    # centred in the canvas. The slider below spans the same horizontal range.
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.9)
    ax_time = fig.add_axes([0.1, 0.08, 0.8, 0.03])  # axis for the slider

    canvas = FigureCanvasTkAgg(fig, master=plot_window)
    canvas.get_tk_widget().pack(side=TOP, fill=BOTH, expand=1)

    # Colorbar to the right of the heatmap. It is created once; the colour
    # limits are fixed, so it never needs to be rebuilt.
    cax = make_axes_locatable(ax).append_axes('right', size='7%', pad='2%')

    def extent_for(start):
        # x spans the absolute column indices [start, start + win_size], so the
        # tick labels are correct without overriding them.
        # y is centred on integer rows 0..C-1, with row 0 at the top, so
        # channel tick labels sit in the middle of their rows.
        return [start, start + win_size, n_channels - 0.5, -0.5]

    heatmap = ax.imshow(M[:, 0:win_size], vmin=vmin, vmax=vmax, cmap=cmap,
                        aspect='auto', extent=extent_for(0))
    fig.colorbar(heatmap, cax=cax, orientation='vertical')

    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=10, integer=True))
    ax.set_yticks(range(n_channels))
    ax.set_yticklabels(channels, fontsize=12)
    ax.set_ylabel('channel', fontsize=12)
    ax_time.set_xlabel('time (sec)', fontsize=12)

    spos = Slider(ax_time, '', valinit=0, valmin=0, valmax=last_start,
                  valstep=win_size)

    def update_graph(val):
        """Show the window of columns starting at the slider value."""
        start = int(round(val))  # the slider may report floats; slicing needs ints
        stop = start + win_size
        # Reuse the existing image rather than stacking a new one per frame.
        heatmap.set_data(M[:, start:stop])
        heatmap.set_extent(extent_for(start))
        ax.set_xlim(start, stop)
        canvas.draw_idle()

    spos.on_changed(update_graph)
    spos.set_val(0)

    # Keep the slider alive for as long as the window exists.
    plot_window.slider = spos
    canvas.draw()
    return spos
下面是一个示例：M 是 4 乘 300 的矩阵，win_size=10。
# Demo: a 4 x 300 random matrix viewed 10 columns at a time.
root = Tk()
M = np.random.rand(4, 300)
channels = ['A', 'B', 'C', 'D']
win_size = 10
scrolling_matrix_viewer(M, channels, win_size)
# Without a running Tk event loop, a plain `python script.py` run exits
# immediately and the windows disappear.
root.mainloop()