# Copyright (c) 2007 Andrew Choi.  All rights reserved.

# This is emphatically NOT free/GPL software.

# Permission for the use of this code is granted only for research,
# educational, and non-commercial purposes.

# Redistribution of this code or its parts in any form without
# permission, with or without modification, is prohibited.
# Modifications include, but are not limited to, translation to other
# programming languages and reuse of tables, constant definitions, and
# API's defined in it.

# There is no restriction on the use of the output generated by this
# software.

# THIS SOFTWARE IS PROVIDED BY ANDREW CHOI "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED.  IN NO EVENT SHALL ANDREW CHOI BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.


from itertools import izip
from random import random
from math import fmod

from toe.core.objects import Chord, Interval
from toe.midi.objects import MIDINote, Sequence
from toe.midi.utils import diatonic_above, diatonic_below, scale_notes_between

from mjb.utils import WeightedChoice, FlatteningIterator, LookaheadIterator
from mjb.utils import are_approximately_equal, fractional_part
from mjb.generators import MNG
from mjb.algorithms import analyze_tonality

# Note-range generators (mjb.generators.MNG) for the bass register.
# Presumably inclusive ranges between the two MIDI notes -- TODO confirm in MNG.
MNG_NORMAL = MNG(MIDINote('G1'), MIDINote('C3'))
MNG_EXTENDED = MNG(MIDINote('E1'), MIDINote('F3'))
# Combined range mixing the two registers 10:4, so that (assuming int `*`
# scales weights and `+` sums them -- confirm in MNG) notes inside the normal
# range are favoured over notes only reachable in the extended range.
# Used as the final range filter by most note generators below.
MNG_STANDARD = MNG_NORMAL * 10 + MNG_EXTENDED * 4

# Largest leap, in semitones, allowed between the first notes of consecutive
# chords in gen_first_notes(); the tighter limit applies when the previous
# chord lasts two beats or less.
MAX_SEMI_TONE_LEAP = 15
MAX_TWO_BEAT_SEMI_TONE_LEAP = 11

# Relative weights for picking the root, third or fifth as a chord's first
# note when the chord's bass is its root (see gen_first_notes()).
F_ROOT = 70
F_THIRD = 15
F_FIFTH = 15

# In gen_first_notes(), this share of the weight (x100) goes to the
# unfiltered candidate pool; the rest goes to the pool with the previous
# first note removed, which discourages immediate repeats.
REPEAT_FIRST_NOTE_REL_PROB = 0.25

# Relative weights of the approach-note candidates in approach_note_MNG():
# one semitone above/below the target, and the diatonic neighbours
# above/below it in the target's scale.
F_CHROMATIC_ABOVE = 40
F_CHROMATIC_BELOW = 20
F_DIATONIC_ABOVE = 20
F_DIATONIC_BELOW = 20

# Probability that gen4() walks through the scale notes between two first
# notes when there are exactly two / exactly three of them.
WALKING_LINE_2_SCALE_NOTE_PROB = 0.8
WALKING_LINE_3_SCALE_NOTE_PROB = 0.8

# Probabilities of inserting a chromatic passing note where two adjacent
# walking-line notes are 2 semitones apart: between the two scale notes (1),
# between the second scale note and the target (2) -- both in
# gen_walking_line_2() -- and between the last scale note and the target in
# gen_walking_line_3() (3).
WALKING_LINE_INSERT_1_PROB = 0.3
WALKING_LINE_INSERT_2_PROB = 0.3
WALKING_LINE_INSERT_3_PROB = 0.3

# gen_third_note(): when the first note and the target are at most
# CLOSE_TARGET_MAX_SEMI_TONE apart, the third note is taken from a window of
# up to CLOSE_TARGET_SEMI_TONE_RANGE semitones on the far side of the target.
# The CLOSE_1_3_* pair plays the same role in gen_second_note(), measured
# between the first and third notes.
CLOSE_TARGET_MAX_SEMI_TONE = 5
CLOSE_TARGET_SEMI_TONE_RANGE = 10
CLOSE_1_3_MAX_SEMI_TONE = 2
CLOSE_1_3_SEMI_TONE_RANGE = 5

# Third argument passed to MNG when building the range strictly between two
# notes in gen_second_note()/gen_third_note(); its meaning is defined by MNG
# (presumably a weighting/shape parameter -- TODO confirm).
Z = 2.0

# Probabilities that embellish() splits one beat of a two-beat / four-beat
# chord into two half-beat repeats of the same note.
EMBELLISH_2_PROB = 0.15
EMBELLISH_4_PROB = 0.35
# Which beat (0-3) of a four-beat chord gets split. Presumably (beat, weight)
# pairs, so beat 3 is never chosen -- confirm against WeightedChoice.
embellish_4_position = WeightedChoice([(0, 4), (1, 2), (2, 3), (3, 0)])

# MIDI velocities applied by perf_seq(): notes starting on even whole beats,
# on odd whole beats, and off the beat (embellishments).
STRONG_VELOCITY = 100
WEAK_VELOCITY = 85
EMBELLISHMENT_VELOCITY = 70


def gen_first_notes(chords, durations):
    """Pick the note played on the first beat of every chord.

    The first and last chords get their bass note in MNG_NORMAL. Every other
    chord draws from the following pool:
      * its bass note, when the bass is not the root;
      * otherwise a weighted mix of root, third and fifth.
    Repeats of the previous first note are discouraged, and the pool is kept
    within a maximum leap of that note. The leap limit is tighter when the
    previous chord lasted two beats or less.

    Returns a list with one note per chord.
    """
    picked = []
    # Weight for the pool without the previous note vs. the unfiltered pool.
    avoid_weight = int(100 * (1.0 - REPEAT_FIRST_NOTE_REL_PROB))
    repeat_weight = int(100 * REPEAT_FIRST_NOTE_REL_PROB)

    for step in LookaheadIterator(izip(chords, durations)):
        chord = step.curr[0]

        if not step.prev or not step.lookahead:
            # Opening and closing chords: plain bass note, normal register.
            picked.append((MNG(chord.bass) * MNG_NORMAL).gen())
            continue

        if chord.bass != chord.root:
            pool = MNG(chord.bass)
        else:
            pool = (MNG(chord.root) * F_ROOT
                    + MNG(chord.nth(3)) * F_THIRD
                    + MNG(chord.nth(5)) * F_FIFTH)

        last = picked[-1]
        pool = pool * ~MNG(last) * avoid_weight + pool * repeat_weight
        if pool.is_null():
            pool = MNG(chord.bass)

        if step.prev[1] <= 2:
            leap = MAX_TWO_BEAT_SEMI_TONE_LEAP
        else:
            leap = MAX_SEMI_TONE_LEAP

        pool = pool * MNG(last - leap, last + leap)
        picked.append((pool * MNG_STANDARD).gen())

    return picked

def approach_note_MNG(avoid_note, target_note, target_scale):
    """Return a weighted note generator for notes that lead into target_note.

    There are four candidates:
      * one semitone above target_note (weight F_CHROMATIC_ABOVE);
      * one semitone below it (weight F_CHROMATIC_BELOW);
      * its diatonic neighbour above in target_scale (weight F_DIATONIC_ABOVE);
      * its diatonic neighbour below in target_scale (weight F_DIATONIC_BELOW).

    The whole candidate set excludes avoid_note (the note played just before,
    so the line does not repeat it) and is restricted to MNG_STANDARD. The
    result may be null; callers check is_null() and fall back to chord tones.
    """
    candidates = (MNG(target_note + 1) * F_CHROMATIC_ABOVE
                  + MNG(target_note - 1) * F_CHROMATIC_BELOW
                  + MNG(diatonic_above(target_note, target_scale)) * F_DIATONIC_ABOVE
                  + MNG(diatonic_below(target_note, target_scale)) * F_DIATONIC_BELOW)
    # The sum is parenthesised on purpose: `*` binds tighter than `+`, so
    # without the grouping the exclusion and range filters would apply only
    # to the diatonic-below candidate.
    return candidates * ~MNG(avoid_note) * MNG_STANDARD

def gen2(note, next_note, chord, next_scale):
    """Return the second-beat note of a two-beat chord.

    If neither the current note nor the next first note is the chord root,
    the result is the root, constrained to within an octave of both notes.
    Otherwise the result is an approach note into next_note. If no approach
    note is available, a chord tone within 7 semitones of next_note is used.
    """
    root = chord.root
    if note.note != root and next_note.note != root:
        # Land on the root in the same octave as both neighbouring notes.
        candidates = (MNG(root)
                      * MNG(note - 12, note + 12, 2.0)
                      * MNG(next_note - 12, next_note + 12, 2.0)
                      * MNG_STANDARD)
    else:
        candidates = approach_note_MNG(note, next_note, next_scale)
        if candidates.is_null():
            candidates = MNG(chord) * MNG(next_note - 7, next_note + 7) * MNG_STANDARD
    return candidates.gen()

def gen_walking_line_2(note, next_note, next_scale, notes_between):
    """Return the three notes after `note` for a walk over two scale notes.

    notes_between holds the two scale notes lying between note and next_note.
    The result fills beats 2-4 in one of three ways, tried in order:
      1. When the two scale notes are a whole tone apart, a chromatic passing
         note is inserted between them (with WALKING_LINE_INSERT_1_PROB).
      2. Otherwise, when the second scale note is a whole tone from next_note,
         a chromatic passing note is appended (with
         WALKING_LINE_INSERT_2_PROB).
      3. Otherwise the line ends on the diatonic neighbour of next_note in
         next_scale: above it when ascending, below it otherwise.
    """
    first_step = notes_between[0]
    second_step = notes_between[1]

    def chromatic_step(start, toward):
        # One semitone from `start` in the direction of `toward`.
        if start < toward:
            return start + 1
        return start - 1

    if abs(first_step - second_step) == 2 and random() < WALKING_LINE_INSERT_1_PROB:
        return [first_step, chromatic_step(first_step, second_step), second_step]

    if abs(second_step - next_note) == 2 and random() < WALKING_LINE_INSERT_2_PROB:
        return [first_step, second_step, chromatic_step(second_step, next_note)]

    neighbour = diatonic_above if note < next_note else diatonic_below
    return [first_step, second_step, neighbour(next_note, next_scale)]

def gen_walking_line_3(note, next_note, notes_between):
    """Return the three notes after `note` for a walk over three scale notes.

    notes_between holds the three scale notes lying between note and
    next_note. Normally they are returned unchanged. When the last scale note
    is a whole tone from next_note, then with WALKING_LINE_INSERT_3_PROB the
    first scale note is dropped and a chromatic passing note leading into
    next_note is appended instead.
    """
    middle = notes_between[1]
    last = notes_between[2]

    use_passing_note = (abs(last - next_note) == 2
                        and random() < WALKING_LINE_INSERT_3_PROB)
    if not use_passing_note:
        return notes_between

    passing = next_note - 1 if last < next_note else next_note + 1
    return [middle, last, passing]

def gen_second_note(scale, chord, note, next_note):
    """Choose the note between `note` and the following note `next_note`.

    Callers pass the already-chosen third note as next_note.

    When the two notes are within CLOSE_1_3_MAX_SEMI_TONE, a scale or chord
    tone is taken from the far side of next_note: above it when ascending,
    below it otherwise, up to CLOSE_1_3_SEMI_TONE_RANGE away. When they are
    further apart, a scale note strictly between them is used. Results are
    limited to MNG_STANDARD. If nothing qualifies, scale notes on either side
    of next_note are allowed.
    """
    ascending = note < next_note
    if abs(note - next_note) <= CLOSE_1_3_MAX_SEMI_TONE:
        pool = MNG(scale) + MNG(chord)
        if ascending:
            window = MNG(next_note + 1, next_note + CLOSE_1_3_SEMI_TONE_RANGE)
        else:
            window = MNG(next_note - CLOSE_1_3_SEMI_TONE_RANGE, next_note - 1)
    else:
        pool = MNG(scale)
        if ascending:
            window = MNG(note + 1, next_note - 1, Z)
        else:
            window = MNG(next_note + 1, note - 1, Z)

    candidates = pool * window * MNG_STANDARD
    if candidates.is_null():
        candidates = (MNG(scale)
                      * (MNG(next_note + 1, next_note + CLOSE_1_3_SEMI_TONE_RANGE)
                         + MNG(next_note - CLOSE_1_3_SEMI_TONE_RANGE, next_note - 1))
                      * MNG_STANDARD)
    return candidates.gen()

def gen_third_note(chord, note, next_note):
    """Choose a chord tone to play just before moving from `note` to `next_note`.

    When the two notes are within CLOSE_TARGET_MAX_SEMI_TONE, the tone lies
    on the far side of next_note:
      * above it when ascending, from 1 to CLOSE_TARGET_SEMI_TONE_RANGE
        semitones away;
      * otherwise below it, from 2 to CLOSE_TARGET_SEMI_TONE_RANGE semitones
        away.
    When they are further apart, a chord tone strictly between them is used.
    Results are limited to MNG_STANDARD. If nothing qualifies, both sides of
    next_note are allowed.
    """
    pool = MNG(chord)
    ascending = note < next_note
    close_to_target = abs(note - next_note) <= CLOSE_TARGET_MAX_SEMI_TONE

    if close_to_target and ascending:
        window = MNG(next_note + 1, next_note + CLOSE_TARGET_SEMI_TONE_RANGE)
    elif close_to_target:
        window = MNG(next_note - CLOSE_TARGET_SEMI_TONE_RANGE, next_note - 2)
    elif ascending:
        window = MNG(note + 1, next_note - 1, Z)
    else:
        window = MNG(next_note + 1, note - 1, Z)

    candidates = pool * window * MNG_STANDARD
    if candidates.is_null():
        candidates = (MNG(chord)
                      * (MNG(next_note + 1, next_note + CLOSE_TARGET_SEMI_TONE_RANGE)
                         + MNG(next_note - CLOSE_TARGET_SEMI_TONE_RANGE, next_note - 2))
                      * MNG_STANDARD)
    return candidates.gen()

def gen3(note, next_note, chord, scale, next_scale):
    """Return the notes for beats 2 and 3 of a three-beat chord.

    The third-beat note is chosen first, because the second-beat note is
    placed relative to it. next_scale is accepted for signature parity with
    gen4() and is not used.
    """
    beat3 = gen_third_note(chord, note, next_note)
    beat2 = gen_second_note(scale, chord, note, beat3)
    return [beat2, beat3]

def gen4(note, next_note, chord, scale, next_scale):
    """Return the notes for beats 2-4 of a four-beat chord.

    When exactly two or three scale notes lie between note and next_note,
    a walking line over them is used with the corresponding
    WALKING_LINE_*_SCALE_NOTE_PROB. Otherwise the notes are built as
    follows:
      * a third-beat chord tone;
      * a second-beat note placed relative to that chord tone;
      * a fourth-beat approach note into next_note. If no approach note is
        available, a chord tone within 7 semitones of next_note is used.
    """
    steps = scale_notes_between(note, next_note, scale)
    count = len(steps)

    if count == 2 and random() <= WALKING_LINE_2_SCALE_NOTE_PROB:
        return gen_walking_line_2(note, next_note, next_scale, steps)
    if count == 3 and random() <= WALKING_LINE_3_SCALE_NOTE_PROB:
        return gen_walking_line_3(note, next_note, steps)

    beat3 = gen_third_note(chord, note, next_note)
    beat2 = gen_second_note(scale, chord, note, beat3)

    approach = approach_note_MNG(beat3, next_note, next_scale)
    if approach.is_null():
        approach = MNG(chord) * MNG(next_note - 7, next_note + 7) * MNG_STANDARD
    beat4 = approach.gen()

    return [beat2, beat3, beat4]
    
def gen_other_notes(chords, durations, first_notes, scales):
    """Expand the per-chord first notes into a one-note-per-beat bass line.

    For every chord except the last, the chord's first note is followed by
    duration - 1 further notes, produced by gen2()/gen3()/gen4() and leading
    towards the next chord's first note. The last chord contributes only its
    first note.

    Returns a flat list of notes, in the order embellish() consumes them.

    Raises:
        ValueError: if a chord other than the last lasts anything other than
            1, 2, 3 or 4 beats. embellish() takes one note per beat, so any
            other duration would desynchronise the bass line.
    """
    result = []
    for step in LookaheadIterator(izip(chords, durations, first_notes, scales)):
        chord, duration, note, scale = step.curr

        result.append(note)
        if not step.lookahead or duration == 1:
            continue

        next_note = step.lookahead[2]
        next_scale = step.lookahead[3]

        if duration == 2:
            result.append(gen2(note, next_note, chord, next_scale))
        elif duration == 3:
            result += gen3(note, next_note, chord, scale, next_scale)
        elif duration == 4:
            result += gen4(note, next_note, chord, scale, next_scale)
        else:
            # Previously every other duration was silently treated as 4.
            raise ValueError('unsupported chord duration %r; expected 1, 2, 3 '
                             'or 4 beats' % (duration,))

    return result

def embellish(bassline, durations):
    """Lay the bass line out in time, occasionally splitting a beat.

    Notes are taken from bassline in order, one per beat of each chord.
      * The last chord's single note is held for the chord's whole duration.
      * A two-beat chord, with EMBELLISH_2_PROB, plays its first note as two
        half-beat repeats.
      * A four-beat chord, with EMBELLISH_4_PROB, splits the beat chosen by
        embellish_4_position in the same way.
    Events are (note, length) pairs keyed by start beat in a new Sequence.
    """
    timeline = Sequence()
    notes = iter(bassline)
    beat = 0.0

    for entry in LookaheadIterator(durations):
        length = entry.curr

        if not entry.lookahead:
            # Final chord: one note sustained for its full duration.
            timeline.add_event(beat, (notes.next(), length))
            continue

        if length == 2 and random() < EMBELLISH_2_PROB:
            first = notes.next()
            second = notes.next()
            timeline.add_event(beat, (first, 0.5))
            timeline.add_event(beat + 0.5, (first, 0.5))
            timeline.add_event(beat + 1.0, (second, 1.0))
            beat += 2.0
        elif length == 4 and random() < EMBELLISH_4_PROB:
            split_beat = embellish_4_position.choose()
            for position in xrange(4):
                current = notes.next()
                if position != split_beat:
                    timeline.add_event(beat, (current, 1.0))
                else:
                    timeline.add_event(beat, (current, 0.5))
                    timeline.add_event(beat + 0.5, (current, 0.5))
                beat += 1.0
        else:
            for _ in xrange(length):
                timeline.add_event(beat, (notes.next(), 1.0))
                beat += 1.0

    return timeline

def perf_seq(seq):
    """Return a copy of seq with a MIDI velocity added to every event.

    Velocities depend on where the event starts:
      * on a whole beat that is not an odd beat number: STRONG_VELOCITY;
      * on an odd whole beat: WEAK_VELOCITY;
      * off the beat: EMBELLISHMENT_VELOCITY.
    Events become (note, length, velocity).
    """
    performed = Sequence()

    for time, (mn, dur) in seq:
        if not are_approximately_equal(fractional_part(time), 0.0):
            velocity = EMBELLISHMENT_VELOCITY
        elif are_approximately_equal(fmod(time, 2.0), 1.0):
            velocity = WEAK_VELOCITY
        else:
            velocity = STRONG_VELOCITY
        performed.add_event(time, (mn, dur, velocity))

    return performed

def gen_bass(chart):
    """Generate a performed bass-line Sequence for chart['chords'].

    chart['chords'] is flattened into (chord symbol, duration) pairs. The
    pipeline then runs in this order:
      1. choose the first note of each chord;
      2. analyse the tonality;
      3. fill in the remaining beats;
      4. embellish;
      5. assign velocities.
    """
    pairs = list(FlatteningIterator(chart['chords']))
    chords = [Chord(symbol) for symbol, _ in pairs]
    durations = [beats for _, beats in pairs]

    first_notes = gen_first_notes(chords, durations)
    scales = analyze_tonality(chords)
    bassline = gen_other_notes(chords, durations, first_notes, scales)

    return perf_seq(embellish(bassline, durations))
