#!/usr/bin/python

import lib
from Tkinter import *
import tkMessageBox
import tkSimpleDialog
import tkFileDialog
import wave

## IEEE-754 infinity and not-a-number.  float('inf') / float('nan') are the
## documented spellings; the old forms relied on float-literal overflow
## (1e300000) and inf/inf evaluating to nan.
INF = float('inf')
NaN = float('nan')

class App:
	"""Tk editor for placing button notes on a beat grid synced to an 8-bit WAV.

	The window stacks three canvases that scroll horizontally together under
	one scrollbar: a time ruler, a note grid (one row per button), and the
	waveform.  Notes are kept per button as sets of grid-line indices and are
	handed to ``lib`` on export.
	"""

	def __init__(self, master):
		"""Build the menus, the three synced canvases and the shared scrollbar."""
		########################################
		## Data variables
		##

		self.SCROLLRATE = 4
		self.NOTEHEIGHT = 20

		self.framerate = 0   # sample rate of the loaded WAV; 0 = nothing loaded
		self.num_bytes = 0   # frame count of the loaded WAV
		self.bytes = []      # loaded samples, one element per (8-bit mono) frame

		self.bpm = 0
		self.npb = 1         # grid lines (notes) per beat
		self.offset = 0      # seconds; negative means the WAV is shifted right
		self.notes = [set(), set(), set(), set()]  # grid indices, one set per button row

		self.note_spacing = 0  # pixels between grid lines; 0 while no grid is drawn

		########################################
		## GUI stuff
		##

		self.master = master
		master.geometry("1000x500")

		frame = Frame(master, bd=2)
		frame.pack(expand=1, fill='both')

		frame.grid_rowconfigure(0, weight=1)
		frame.grid_rowconfigure(1, weight=1)
		frame.grid_rowconfigure(2, weight=1)
		frame.grid_columnconfigure(0, weight=1)

		########################################
		## Menus
		##

		menubar = Menu(frame)

		filemenu = Menu(menubar, tearoff=0)
		## TODO: an "Import" command is not implemented yet
		filemenu.add_command(label="Export", command=self.export)
		filemenu.add_separator()
		filemenu.add_command(label="Import WAV", command=self.open_wav)
		filemenu.add_separator()
		filemenu.add_command(label="Exit", command=frame.quit)
		menubar.add_cascade(label="File", menu=filemenu)

		editmenu = Menu(menubar, tearoff=0)
		editmenu.add_command(label="Set BPM", command=self.set_bpm)
		editmenu.add_command(label="Set Note Resolution", command=self.set_npb)
		editmenu.add_separator()
		editmenu.add_command(label="WAV Offset (unimplemented)", command=self.set_offset)
		menubar.add_cascade(label="Edit", menu=editmenu)

		menubar.add_command(label="+", command=self.zoom_in)
		menubar.add_command(label="-", command=self.zoom_out)

		master.config(menu=menubar)

		########################################
		## Scrollbar and canvases
		##

		self.DEFAULTWIDTH = 1000
		self.ZOOMFACTOR = 1.5
		self.MOUSEZOOM = 1.25

		self.canvaswidth = self.DEFAULTWIDTH
		## Visible fraction [scrollmin, scrollmax] of the canvases
		self.scrollmin = 0
		self.scrollmax = 1

		## Canvas x-range (pixels) already drawn on the wav canvas
		self.blittedmin = 0
		self.blittedmax = 1
		self.deltat = 0

		self.xscrollbar = Scrollbar(frame, orient=HORIZONTAL)
		self.xscrollbar.grid(row=3, column=0, sticky=EW)

		## Every canvas here is scrolled/zoomed together
		self.scrollables = []

		self.timecanvas = Canvas(frame, width=self.canvaswidth, height=25, xscrollcommand=self.xscroll)
		self.timecanvas.grid(row=0, column=0, sticky=NSEW)
		self.timecanvas.config(scrollregion=(0,0,self.canvaswidth,10))
		self.timecanvas.bind('<Button-4>', self.rollWheel)
		self.timecanvas.bind('<Button-5>', self.rollWheel)
		self.scrollables.append(self.timecanvas)

		self.notescanvas = Canvas(frame, width=self.canvaswidth, height=50, xscrollcommand=self.xscroll)
		self.notescanvas.grid(row=1, column=0, sticky=NSEW)
		self.notescanvas.config(scrollregion=(0,0,self.canvaswidth,500))
		self.notescanvas.bind('<Button-1>', self.click)
		self.notescanvas.bind('<Button-4>', self.rollWheel)
		self.notescanvas.bind('<Button-5>', self.rollWheel)
		self.scrollables.append(self.notescanvas)

		self.wavcanvas = Canvas(frame, width=self.canvaswidth, xscrollcommand=self.xscroll)
		self.wavcanvas.grid(row=2, column=0, sticky=NSEW)
		self.wavcanvas.config(scrollregion=(0,0,self.canvaswidth,500))
		self.wavcanvas.bind('<Button-4>', self.rollWheel)
		self.wavcanvas.bind('<Button-5>', self.rollWheel)
		self.scrollables.append(self.wavcanvas)

		self.master.bind('<MouseWheel>', self.rollWheel)

		self.xscrollbar.config(command=self.xview)

	def xscroll(self, *args):
		"""Canvas ``xscrollcommand``: record the visible range and blit new area.

		``args`` are the (first, last) visible fractions.  Any canvas area
		that becomes visible for the first time is drawn on demand by ``blit``.
		"""
		self.scrollmin = float(args[0])
		self.scrollmax = float(args[1])
		self.xscrollbar.set(*args)

		## Blit only the newly exposed part of the waveform
		if self.num_bytes != 0:
			left = self.scrollmin * self.canvaswidth
			right = self.scrollmax * self.canvaswidth
			if left < self.blittedmin:
				self.blit(left, self.blittedmin)
				self.blittedmin = left
			if right > self.blittedmax:
				self.blit(self.blittedmax, right)
				self.blittedmax = right

	def xview(self, *args):
		"""Scrollbar command: forward the scroll to every synced canvas."""
		for s in self.scrollables:
			s.xview(*args)

	def _shifted_bytes(self):
		"""Return the samples with ``self.offset`` applied, as they are exported.

		A negative offset (WAV shifted right) prepends that many seconds of
		0x80 padding; a positive offset drops that many seconds from the start.
		"""
		shift = int(self.offset * self.framerate)
		if shift < 0:
			## Pad with -shift samples.  The padding uses the same element type
			## as self.bytes (1-char strings, read with ord() in blit), so the
			## list handed to lib is homogeneous.
			return [chr(0x80)] * -shift + list(self.bytes)
		return self.bytes[shift:]

	def export(self):
		"""Ask for a base name and export the notes and (shifted) WAV via ``lib``."""
		name = tkSimpleDialog.askstring("Export",
			"""Enter a name to export as. Exporting will create two files:
<NAME>.notes.hex and <NAME>.music.hex.""",
			parent=self.master)
		if not name:
			## Dialog cancelled (None) or empty name: nothing to export
			return
		lib.set_notes(self.notes)
		lib.set_waveform(self._shifted_bytes(), 11025)
		## Grid lines per second = beats per second * notes per beat
		scrollrate = float(self.bpm) / 60 * self.npb
		lib.export(name=name, bpm=self.bpm, scrollrate=scrollrate, outrate=11025, split=1)

	def open_wav(self):
		"""Ask for a WAV file and load it if it is 11025 Hz, mono and 8-bit.

		Invalid files are reported in an error dialog and leave the currently
		loaded WAV untouched.  The chosen file is always closed.
		"""
		wav_file = tkFileDialog.askopenfile(parent=self.master, mode='rb', title='Choose a file')
		if wav_file is None:
			## Dialog was cancelled
			return
		try:
			try:
				f = wave.open(wav_file, 'rb')
			except (wave.Error, EOFError) as e:
				tkMessageBox.showerror("ERRORZ!", "Could not read WAV file.\n(%s)" % e)
				return
			try:
				framerate = f.getframerate()
				if framerate != 11025:
					tkMessageBox.showerror("ERRORZ!", "Sound framerate must be 11025 Hz.\n(Yours was %d)" % framerate)
					return
				if f.getnchannels() != 1:
					tkMessageBox.showerror("ERRORZ!", "Sound must be one channel (mono).\n(Yours was %d)" % f.getnchannels())
					return
				if f.getsampwidth() != 1:
					tkMessageBox.showerror("ERRORZ!", "Sound must be 8-bit audio.\n(Yours was %d)" % (f.getsampwidth() * 8))
					return
				num_bytes = f.getnframes()
				samples = list(f.readframes(num_bytes))
			finally:
				f.close()
		finally:
			wav_file.close()

		## Commit only after every check passed
		self.framerate = framerate
		self.num_bytes = num_bytes
		self.bytes = samples
		self.canvaswidth = self.DEFAULTWIDTH
		self.update()

	def click(self, event):
		"""Toggle the note under the mouse, snapping to the nearest grid line."""
		## No grid drawn yet (BPM unset or no WAV): nothing to snap to
		if not self.note_spacing:
			return
		button = int(self.notescanvas.canvasy(event.y) / self.NOTEHEIGHT)
		## Ignore clicks outside the note rows (would raise IndexError)
		if not 0 <= button < len(self.notes):
			return
		## Adding half a spacing rounds to the nearest grid line
		t = int((self.notescanvas.canvasx(event.x) + (self.note_spacing / 2)) / self.note_spacing)
		if t in self.notes[button]:
			self.notes[button].remove(t)
		else:
			self.notes[button].add(t)
		self.toggle_note(button, t)

	def rollWheel(self, event):
		"""Mouse-wheel handler: zoom in/out around the pointer position."""
		## Pointer position as a fraction of the whole canvas, then of the view
		targetloc = float(self.wavcanvas.canvasx(event.x)) / self.canvaswidth
		target = (targetloc - self.scrollmin) / (self.scrollmax - self.scrollmin)

		# ignore the scroll wheel when not over canvas
		if target < 0 or target > 1:
			return

		# win32 <MouseWheel>
		if event.type == '38':
			if event.delta > 0:
				self.zoom_in(target, self.MOUSEZOOM)
			if event.delta < 0:
				self.zoom_out(target, self.MOUSEZOOM)

		# X-Windows <Button-4> and <Button-5>
		if event.num == 4:
			self.zoom_in(target, self.MOUSEZOOM)
		if event.num == 5:
			self.zoom_out(target, self.MOUSEZOOM)

	def zoom_in(self, target = .5, zoom = 0):
		"""Widen the canvases by ``zoom`` (default ZOOMFACTOR), keeping ``target`` fixed."""
		if zoom == 0:
			zoom = self.ZOOMFACTOR
		self.zoom(target, zoom)

	def zoom_out(self, target=.5, zoom = 0):
		"""Narrow the canvases by ``zoom`` (default ZOOMFACTOR), keeping ``target`` fixed."""
		if zoom == 0:
			zoom = self.ZOOMFACTOR
		## 1.0 keeps this a float division even for an integer zoom
		self.zoom(target, 1.0 / zoom)

	def zoom(self, target, zoom):
		"""Scale the canvas width by ``zoom`` and re-scroll to keep a point in place.

		``target`` is the point's position as a fraction of the visible view.
		"""
		## figure out where we want to zoom in on
		targetloc = self.scrollmin + (self.scrollmax - self.scrollmin) * target

		## zoom
		self.canvaswidth *= zoom

		self.update_non_blit()

		## call this for the scrollbar size to get adjusted, so we can place it correctly
		self.master.update_idletasks()

		## scroll so targetloc sits at the same fraction of the new view
		newwidth = self.scrollmax - self.scrollmin
		newmin = targetloc - newwidth * target
		newmax = targetloc + newwidth * (1 - target)
		self.xscroll(str(newmin), str(newmax))
		self.xview('moveto', str(newmin))

		self.update_blit()

	def set_bpm(self):
		"""Ask for the song BPM and redraw."""
		newbpm = tkSimpleDialog.askfloat("Set scrollrate",
			"""Enter the song BPM:""",
			initialvalue=self.bpm, minvalue=0, parent=self.master)
		if newbpm is not None:
			self.bpm = newbpm
			self.update()

	def set_npb(self):
		"""Ask for the number of grid lines per beat (must be > 0) and redraw."""
		newnpb = tkSimpleDialog.askfloat("Set scrollrate",
			"""Enter the number of notes you want per beat (note that high numbers
make the screen scroll faster, making the game harder). Must be greater than 0:""",
			initialvalue=self.npb, parent=self.master)
		if newnpb is None:
			## Dialog was cancelled: keep the current value
			return
		if newnpb <= 0:
			tkMessageBox.showerror("ERRORZ!", "Notes per beat must be greater than 0.")
			return
		self.npb = newnpb
		self.update()

	def set_offset(self):
		"""Ask how many seconds to shift the WAV right and redraw."""
		newoffset = tkSimpleDialog.askfloat("Set WAV offset",
			"""Shift the WAV to the right by some number of seconds.
Use this to synchronize the beats with the button presses.""",
			initialvalue=-self.offset, parent=self.master)
		if newoffset is not None:
			## Stored negated: a right shift means reading samples earlier
			self.offset = -newoffset
			self.update()

	def display_note(self, button, time):
		"""Draw the note for ``button`` at grid index ``time``, tagged "button_time"."""
		self.notescanvas.create_oval((time * self.note_spacing - (self.note_spacing / 2), button * self.NOTEHEIGHT, time * self.note_spacing + (self.note_spacing / 2), (button + 1) * self.NOTEHEIGHT), tags="%d_%d" % (button, time), fill="green")

	def toggle_note(self, button, time):
		"""Remove the drawn note at (button, time) if present, otherwise draw it."""
		item = self.notescanvas.find_withtag("%d_%d" % (button, time))
		if item:
			self.notescanvas.delete("%d_%d" % (button, time))
		else:
			self.display_note(button, time)

	def blit(self, min, max):
		"""Draw the waveform for canvas x-coordinates in [min, max).

		Each pixel column covers ``deltat`` samples.  With more than two
		samples per column a vertical min-to-max bar is drawn; otherwise
		neighbouring samples are joined by line segments.  Samples outside the
		loaded data (e.g. because of ``self.offset``) are drawn as 0x80.
		"""
		## The parameter names shadow the builtins; copy them to clearer locals
		start, stop = int(min), int(max)
		offsetbytes = int(self.offset * self.framerate)
		samples = self.bytes
		count = len(samples)
		deltat = float(self.num_bytes) / self.canvaswidth
		if deltat > 2:
			for x in range(start, stop):
				lo = INF
				hi = -INF
				for index in range(int(x * deltat) + offsetbytes, int((x + 1) * deltat) + offsetbytes):
					if 0 <= index < count:
						value = ord(samples[index])
					else:
						value = 0x80
					if value < lo:
						lo = value
					if value > hi:
						hi = value
				self.wavcanvas.create_line(x, lo, x, hi)
		else:
			for x in range(start, stop):
				first = int(x * deltat) + offsetbytes
				second = int((x + 1) * deltat) + offsetbytes
				if second >= count or first < 0:
					byte1 = 0x80
					byte2 = 0x80
				else:
					byte1 = ord(samples[first])
					byte2 = ord(samples[second])
				self.wavcanvas.create_line(x, byte1, x + 1, byte2)

	def update(self):
		"""Redraw everything: ruler, grid, notes and the visible waveform."""
		self.update_non_blit()
		self.update_blit()

	def update_non_blit(self):
		"""Clear all canvases, resize scroll regions, redraw the ruler, grid and notes.

		The waveform is not drawn here; ``update_blit`` does that.
		"""
		for s in self.scrollables:
			s.config(width=self.canvaswidth)
			s.delete(ALL)

		## update scrollregions
		for s in self.scrollables:
			s.config(scrollregion=(0,0,int(self.canvaswidth),s.winfo_reqheight()))

		## let the scrollbars adjust, and pretend we've blitted everything, so the
		## adjustments don't try to blit anything -- we'll take care of it later
		self.blittedmin = 0
		self.blittedmax = self.canvaswidth
		self.master.update_idletasks()

		## WAV length in seconds.  Guarded because framerate is 0 until a WAV is
		## loaded (zooming or setting BPM first used to divide by zero).
		if self.framerate:
			secs = float(self.num_bytes) / self.framerate
		else:
			secs = 0.0

		## Both drawings divide by a count derived from secs, so need secs > 0
		if secs > 0:
			self._draw_time_ruler(secs)

		self.note_spacing = 0
		if self.bpm > 0 and self.npb > 0 and secs > 0:
			self._draw_grid(secs)

	def _draw_time_ruler(self, secs):
		"""Draw labelled tick marks for a WAV of ``secs`` seconds on the time canvas."""
		## At most one tick per 10 px; the tick interval is rounded up to a
		## power of ten seconds.
		num_marks = float(self.canvaswidth) / 10
		minspm = float(secs) / num_marks
		spm = 1e-6
		while spm < minspm:
			spm *= 10
		num_marks = secs / spm
		mark_spacing = float(self.canvaswidth) / num_marks
		for i in range(int(num_marks)):
			x = i * mark_spacing
			t = i * spm
			## Label every tick when there is room, else every 2nd or 4th
			if mark_spacing > 40 or (mark_spacing > 20 and not (i % 2)) or not (i % 4):
				self.timecanvas.create_line(x, 0, x, 10)
				self.timecanvas.create_text(x, 10, anchor=N, text="%.*fs" % (2, t))
			else:
				self.timecanvas.create_line(x, 0, x, 5)

	def _draw_grid(self, secs):
		"""Draw the beat grid and every placed note; sets ``note_spacing``."""
		num_lines = float(self.bpm) / 60 * secs * self.npb
		line_spacing = float(self.canvaswidth) / num_lines
		self.note_spacing = line_spacing
		for t in range(int(num_lines)):
			## Lines on a beat are black, subdivisions grey
			if t % self.npb == 0:
				fill = 'black'
			else:
				fill = 'grey'
			self.notescanvas.create_line(t * line_spacing, 0, t * line_spacing, 4 * self.NOTEHEIGHT, fill=fill)
		for button, times in enumerate(self.notes):
			for time in times:
				self.display_note(button, time)

	def update_blit(self):
		"""Draw the waveform for the currently visible range only."""
		## Let things update, just in case
		self.master.update_idletasks()

		## Set up wav blitting and blit the initial window
		self.blittedmin = self.scrollmin * self.canvaswidth
		self.blittedmax = self.scrollmax * self.canvaswidth
		self.blit(self.blittedmin, self.blittedmax)

if __name__ == '__main__':
	## Launch the editor only when run as a script, not when imported
	root = Tk()

	app = App(root)

	root.mainloop()

