module Main where
import System.Random
import qualified Data.ByteString as B
import Data.Char
import Data.Bits
import System.Environment
-------------------------------------------------------------------------------
-- Project 1, a Haskell version of the PR-1 program of Gottfried Michael Koenig
-- interpretation and programming: W.G. Vree, 2007
-------------------------------------------------------------------------------
-- For each "musical parameter" this program calculates a sequence of Parts.
-- The musical parameters are:
-- chords, duration (called entry-delay), dynamics and tempo.
-- A Part can be a Row, a Group or a Balance (Par wraps a single value and is used internally for chord sizes).
-- Rows and Groups are sequences that obey certain rules of serial music.
-- A Balance is a (balanced) mixture of Rows and Groups.
-------------------------------------------------------------------------------

-- A serial structure over parameter values of type 'a'.
data Part a
    = Row [a]                -- a row-type sequence (e.g. a prefix of a permutation, see 'row')
    | Grp [a]                -- a group-type sequence (e.g. one value repeated, see 'group')
    | Bal [Part a] [Part a]  -- a balance: the "set" parts and the (shuffled) balancing parts
    | Par a                  -- a single parameter value (used for chord sizes)

major :: [[Int]] -- explicit Int signatures pin the element type; an unconstrained numeric
minor :: [[Int]] -- literal would default to Integer (arbitrary precision), Int is the machine integer
mapping :: [(String, Int, Int)] -- (entry-delay, minimum chord size, maximum chord size)

-----------------------------------------------------------------------------
-- LEVEL 1 (parameter definitions, to be chosen by the composer)-------------
-----------------------------------------------------------------------------

-- the twelve pitch classes in chromatic order, starting from c
notes = words "c c# d d# e f f# g g# a a# b"

-- interval-rows for computing chords: four interval triples each, given as
-- semitone offsets that are added to a start note (see trio_row / add_note)
major = [ [0, 4,  7]
        , [3, 7, 10]
        , [6, 10, 1]
        , [9, 2,  5]
        ]
minor = [ [0, 3,  7]
        , [2, 5,  9]
        , [5, 8,  0]
        , [7, 10, 2]
        ]

-- the possible entry-delays (chord duration) with corresponding minimum and maximum chord size
-- the possible entry-delays (chord duration) with the corresponding minimum and
-- maximum chord size; a delay listed several times appears that often in entry_list
mapping = ("1/2", 1, 4)
        : replicate 3 ("1/4", 1, 3)
       ++ replicate 4 ("1/8", 1, 2)
       ++ replicate 4 ("1/12", 1, 1)
       ++ replicate 4 ("1/16", 1, 1)

entry_list = [delay | (delay, _, _) <- mapping]   -- extracted list of possible chord durations

dyna_list = words "ppp pp p mp mf f ff fff"   -- possible dynamics, softest to loudest

-- possible tempo values (t60 == 60 1/4-beats per minute)
tempo_list = words "120 104 90 79 69 60"

-- repetition rate for chord rows and groups (1..2); also selects the bound tables
-- used by dynamics and entry_delay
rR :: Int
rR = 1

-- the average number of notes in a chord, calculated from the mapping above

-- mean of all minimum and maximum chord sizes listed in 'mapping'
average_chord_size :: Double
average_chord_size = fromIntegral (sum sizes) / fromIntegral (length sizes)
    where
    sizes = concat [[lo, hi] | (_, lo, hi) <- mapping]

-----------------------------------------------------------------------------
-- LEVEL 1 Section definition -----------------------------------------------
-----------------------------------------------------------------------------
-- computes 'nlines' of a section with the given process numbers and random seed.
-- the process-numbers specify the type of serial generation process (see: Level 4)
-- 1..3 == Rows are generated (with decreasing irregularity from 1 to 3)
-- 4 == Balance structures are generated (balances between irregular and regular)
-- 5..7 == Groups are generated (with increasing regularity from 5 to 7)
-----------------------------------------------------------------------------

-- section: builds the MIDI event bytes for one section.
--   nlines        : number of chord events that are output
--   dyna_process  : process number (1..7, see above) for the dynamics
--   entry_process : process number for the entry-delays (chord durations)
--   chord_process : process number for the chords
--   seed          : random seed; every random stream below is mkRndStr k seed
--                   with its own offset k (11..20)
-- Tempo and chord size have no process number of their own: they are derived
-- from the entry-delay stream.
section nlines dyna_process entry_process chord_process seed = output dyna_str entry_str tempo_str chord_str nlines
    where
    -- dynamics: rows are drawn from permutations of dyna_list, groups from
    -- the serial (concatenated permutations) stream
    dyna_str = d_section dynamics
                         (mk_row_str dyna_list rs1, mk_serial_str dyna_list rs2)
                         dyna_process
                         (length dyna_list)
                         rs3
               where
               rs1 = mkRndStr 11 seed
               rs2 = mkRndStr 12 seed
               rs3 = mkRndStr 13 seed
    
    -- entry-delays (chord durations), built the same way from entry_list
    entry_str = d_section entry_delay
                          (mk_row_str entry_list rs1, mk_serial_str entry_list rs2)
                          entry_process
                          (length entry_list)
                          rs3
                where
                rs1 = mkRndStr 14 seed
                rs2 = mkRndStr 15 seed
                rs3 = mkRndStr 16 seed

    -- one chord size (a Par) per entry-delay, drawn from the ranges in 'mapping'
    ch_size_str = d_section chord_size
                            (flatten entry_str, [])
                            0
                            mapping
                            rs1
                  where
                  rs1 = mkRndStr 17 seed

    -- tempo follows the row/group structure of the entry-delays (flat2 keeps
    -- each Row/Grp as one sub-list); the number of entry parts sharing one tempo
    -- value is between 1 and 4 and is chosen once for the whole section
    tempo_str = d_section tempo_grp
                          (flat2 entry_str, mk_serial_str tempo_list rs1)
                          0
                          (between 1 4 r1)
                          []
                where
                (r1:rs1) = mkRndStr 18 seed

    -- chords: the chord-size stream decides how notes of the serial note stream
    -- are grouped into chords
    chord_str = d_section chord
                          (ch_size_str, mk_serial_str notes rs1)
                          chord_process 
                          average_chord_size
                          rs2
                where
                rs1 = mkRndStr 19 seed
                rs2 = mkRndStr 20 seed

-----------------------------------------------------------------------------
-- LEVEL 2 --General stream creation patterns--------------------------------
-----------------------------------------------------------------------------

-- d_section: infinite stream of Parts, produced by calling par_function over
-- and over. Each call receives the remaining streams and random numbers of the
-- previous call; 'process' and 'arguments' are the same for every call.
d_section par_function streams process arguments rs = go streams rs
    where
    go strs gen = part : go strs' gen'
        where
        ((part, strs'), gen') = par_function strs process arguments gen

-- mk_row_str: infinite stream of random permutations of par_list, each drawn
-- with the random numbers left over by the previous one
mk_row_str par_list rs = map fst (drop 1 (iterate (perm par_list . snd) ([], rs)))

-- mk_serial_str: the row stream of par_list, concatenated into one serial stream
mk_serial_str par_list rs = concat (mk_row_str par_list rs)

-----------------------------------------------------------------------------
-- LEVEL 3 --Definition of the parameter functions---------------------------
-----------------------------------------------------------------------------

-- dynamics: one Part of the dynamics parameter, chosen by 'process':
--   1..3 : a row of round (par_len * (5 - process) / 4) dynamics (at least 1)
--   4    : a balance of one row/group pair, length between 2 and min 8 par_len
--   5..7 : a group repeating one dynamic; its length is drawn between bounds
--          picked by the repetition rate rR and the process number
dynamics str_tup process par_len rs
    | process <= 3 = (row str_tup rowLen, rs)
    | process == 4 = balance str_tup 2 (min 8 par_len) 1 rs
    | otherwise    = (group str_tup (between lo hi r1), rs1)
    where
    rowLen     = round (max 1 (fromIntegral (par_len * (5 - process)) / 4))
    lo         = [[4, 6,  8], [4,  6,  8]] !! (rR - 1) !! (process - 5)
    hi         = [[6, 9, 12], [10, 15, 20]] !! (rR - 1) !! (process - 5)
    (r1 : rs1) = rs
    
-- entry_delay: one Part of the entry-delay parameter, chosen by 'process':
--   1..3 : a row of round (par_len * (5 - process) / 4) delays (at least 1)
--   4    : a balance of 1..4 row/group pairs, lengths between 1 and par_len / 2
--   5..7 : a group repeating one delay; its length is drawn between bounds
--          picked by the repetition rate rR and the process number
entry_delay str_tup process par_len rs
    | process <= 3 = (row str_tup rowLen, rs)
    | process == 4 = balance str_tup 1 (quot par_len 2) (between 1 4 r1) rs1
    | otherwise    = (group str_tup (between lo hi r1), rs1)
    where
    rowLen     = round (max 1 (fromIntegral (par_len * (5 - process)) / 4))
    lo         = [[2, 5,  8], [3,  7, 11]] !! (rR - 1) !! (process - 5)
    hi         = [[4, 8, 12], [6, 11, 16]] !! (rR - 1) !! (process - 5)
    (r1 : rs1) = rs

-- chord_size: for the next entry-delay in the stream, draw a chord size between
-- the smallest minimum and the largest maximum listed for that delay in the
-- 'mapping' argument. The second stream component is passed through unchanged;
-- 'process' is not used.
chord_size (delay : rest_delays, ys) process mapping rs =
    ((Par size, (rest_delays, ys)), rs')
    where
    (r : rs') = rs
    matching  = [(lo, hi) | (del, lo, hi) <- mapping, del == delay]
    size      = between (minimum (map fst matching)) (maximum (map snd matching)) r

-- tempo_grp: a Grp holding the current tempo once for every entry-delay in the
-- next 'rtc' entry parts, after which the tempo stream advances by one value.
-- The random stream is passed through untouched; 'process' is not used.
tempo_grp str_tup process rtc rs = (tempo_grp' str_tup process rtc, rs)

tempo_grp' (e:entries, t:tempos) process rtc
    | rtc == 0  = (Grp [], (e:entries, tempos))
    | otherwise = (Grp (replicate (length e) t ++ later), rest)
    where
    (Grp later, rest) = tempo_grp' (entries, t:tempos) process (rtc - 1)

-- chord: one Part of the chord parameter, chosen by 'process':
--   1..3 : a chord row over 12, 9 or 6 notes of a trio row
--   4    : a balance of chord rows and chord groups (balance_chord)
--   5..  : (picked by r2) a tone group, chord groups of size 3 or double groups
--          of size 6; the counts grow with the process number and with rR
chord str_tup process average_ch_size rs
    | process <= 3 = row_chord str_tup (12 - 3 * (process - 1)) rs
    | process == 4 = balance_chord str_tup average_ch_size rs
    | otherwise    = one_of_3 tones groups dgroups r2
    where
    (r1 : r2 : rs1) = rs
    tones      = grp_tones str_tup (between min_tone   max_tone   r1)   rs1
    groups     = grp_chord str_tup (between min_group  max_group  r1) 3 rs1
    dgroups    = grp_chord str_tup (between min_dgroup max_dgroup r1) 6 rs1
    min_tone   = process - 3
    max_tone   = min_tone * (rR + 1)
    min_group  = round (2.0 * average_ch_size * fromIntegral min_tone / 2.25)
    max_group  = min_group * (rR + 1)
    min_dgroup = round (average_ch_size * fromIntegral min_tone / 2.25)
    max_dgroup = min_dgroup * (rR + 1)

-----------------------------------------------------------------------------
-- LEVEL 4: Row, Group and Balance of non-chord parameters-------------------
-----------------------------------------------------------------------------

-- row: the first m values of the next permutation in the row stream
row (current : row_rest, serial_str) m = (Row prefix, (row_rest, serial_str))
    where prefix = take m current

-- group: the next value of the serial stream, repeated m times
group (row_str, value : serial_rest) m = (Grp repeated, (row_str, serial_rest))
    where repeated = replicate m value

-- balance: a balance of rows and groups for a non-chord parameter.
--   str_tup : (row stream, serial stream), consumed by 'row' and 'group'
--   min/max : bounds for the length of every part (drawn with 'between');
--             min, max and repeat shadow the Prelude names inside this definition
--   repeat  : number of (set part, balance part) pairs
-- For each pair one random number decides whether the set part is a row and the
-- balance part a group, or the other way round. Both lists use the same lengths,
-- and the balance parts are then shuffled with 'perm'.
-- Returns ((Bal set_parts shuffled_balance_parts, remaining streams), remaining randoms).
balance str_tup min max repeat rs = ((Bal set_parts res1, rest_streams2), rs3)
    where
        (len_list, rs1)                = mk_len_list min max repeat rs
        (proc_list, rs2)               = mk_proc_list        repeat rs1
        set_proc_list                  = map fst proc_list
        bal_proc_list                  = map snd proc_list
        -- set parts are built first; the balance parts continue on the streams they leave
        (set_parts,     rest_streams)  = struc set_proc_list str_tup      len_list
        (balance_parts, rest_streams2) = struc bal_proc_list rest_streams len_list
        (res1, rs3)                    = perm balance_parts rs2

        -- apply each part builder with its length, threading the streams through
        struc    fs  str_tup []         = ([]            , str_tup)
        struc (f:fs) str_tup (len:lens) = ((part : parts), rest_streams2)
            where
            (part, rest_streams)   = f str_tup len
            (parts, rest_streams2) = struc fs rest_streams lens

        -- n random part lengths between min and max
        mk_len_list min max 0 rs = ([], rs)
        mk_len_list min max n rs = (between min max r1 : res1, rs2)
                                   where
                                   (r1 : rs1) = rs
                                   (res1, rs2) = mk_len_list min max (n - 1) rs1

        -- n random (set builder, balance builder) pairs: (row, group) or (group, row)
        mk_proc_list 0 rs = ([], rs)
        mk_proc_list n rs = (one_of (row, group) (group, row) r1 : res1, rs2)
                            where
                            (r1 : rs1) = rs
                            (res1, rs2) = mk_proc_list (n - 1) rs1

-----------------------------------------------------------------------------
-- LEVEL 4: row, group, double-group, tone and balance of the chord parameter
-----------------------------------------------------------------------------

-- row_chord: take row_size notes of a trio row that starts on the next note of
-- the note stream and cut them into chords with sizes from the chord-size stream
-- (notes that do not fill a whole chord are dropped, see fill_chord).
row_chord (ch_len_str, note_str) row_size rs =
    ((Row chords, (ch_len_rest, note_rest)), rs')
    where
    ((trio_notes, note_rest), rs') = trio_row note_str rs
    (chords, ch_len_rest)          = fill_chord ch_len_str (take row_size trio_notes)

-- grp_chord: take group_size notes of a trio row that starts on the next note of
-- the note stream, repeat them ngroups times and cut the result into chords with
-- sizes from the chord-size stream. ngroups == 0 yields an empty group.
-- Returns ((Grp chords, (remaining chord sizes, remaining notes)), remaining randoms).
grp_chord (ch_len_str, note_str) ngroups group_size rs = ((Grp chords, (rest_ch_len_str, rest_note_str)), rs1)
    where
    ((trio_notes, rest_note_str), rs1) = trio_row note_str rs
    one_group = take group_size trio_notes
    -- concat instead of foldr1 (++): ngroups can round down to 0 (e.g. in
    -- balance_chord with a large average chord size), and foldr1 fails on []
    groups = concat (replicate ngroups one_group)
    (chords, rest_ch_len_str) = fill_chord ch_len_str groups

-- grp_tones: a group of ntones chords that all start on the same note (the next
-- note of the note stream). Each chord is completed with the following notes of
-- a fresh trio row on that note; its size comes from the chord-size stream.
-- Returns ((Grp chords, (remaining chord sizes, remaining notes)), remaining randoms).
grp_tones (ch_len_str, note : rest_note_str) ntones rs =
    ((Grp chords, (rest_ch_len_str, rest_note_str)), rs1)
    where
    fixed = repeat note
    ((chords, rest_ch_len_str), rs1) = tones ch_len_str ntones rs

    tones lens 0 gen = (([], lens), gen)
    tones (Par len : lens) k gen = ((tone : more, lens'), gen'')
        where
        ((trio_notes, _), gen') = trio_row fixed gen
        tone                    = note : tail (take len trio_notes)
        ((more, lens'), gen'')  = tones lens (k - 1) gen'

-- balance_chord: a balance for the chord parameter; r1 picks one of two shapes.
--   exp1: 1..3 twelve-note chord rows, then (picked by r2) a tone group, chord
--         groups of size 3 or double groups of size 6, whose count grows with the
--         number of rows and shrinks with the average chord size.
--   exp2: first (picked by r2) a tone group / chord groups / double groups with
--         'nelems' elements (between minElem and maxElem), then chord rows, their
--         count being nelems divided by 5, 4 or 2 (rounded).
-- Returns ((Bal first_parts second_parts, remaining streams), remaining randoms).
-- NOTE(review): the single random number r2 drives several choices at once (in
-- exp1 the row count and the group kind; in exp2 minElem, maxElem, nelems, the
-- divisor and the group kind), so these choices are correlated. Confirm whether
-- that is intended; changing it changes the generated music for a given seed.
balance_chord str_tup avrage_ch_size rs = one_of exp1 exp2 r1
    where
    exp1 = ((Bal [Row chords] [chord_grp], rest_str_tup2), rs3) where
        -- nrows is fractional here (it is multiplied by 5.0 etc.); rep_chord_row
        -- only compares it with 0 and subtracts 1, which is exact for 1..3
        nrows                             = fromIntegral (between 1 3 r2)
        ntones                            = round ((nrows * 5.0 * 2.25) / avrage_ch_size)
        ngroups                           = round ((nrows * 4.0 * 2.25) / avrage_ch_size)
        ndgroups                          = round ((nrows * 2.0 * 2.25) / avrage_ch_size)
        ((chords, rest_str_tup1)   , rs2) = rep_chord_row str_tup nrows rs1
        ((chord_grp, rest_str_tup2), rs3) = one_of_3 (grp_tones rest_str_tup1 ntones     rs2)
                                                     (grp_chord rest_str_tup1 ngroups 3  rs2)
                                                     (grp_chord rest_str_tup1 ndgroups 6 rs2)
                                                     r2
    exp2 = ((Bal [chords_grp] [Row chords], rest_str_tup2), rs3) where
        minElem = one_of_3 6 4 2 r2
        maxElem = one_of_3 (18 * rR)
                           (round ((12.0 * fromIntegral (rR) * avrage_ch_size) / 2.25))
                           (round ((6.0  * fromIntegral (rR) * avrage_ch_size) / 2.25))
                           r2
        nelems = between minElem maxElem r2
        nrows = round (fromIntegral (nelems) / (one_of_3 5.0 4.0 2.0 r2))
        ((chords_grp, rest_str_tup1), rs2) = one_of_3 (grp_tones str_tup nelems   rs1)
                                                      (grp_chord str_tup nelems 3 rs1)
                                                      (grp_chord str_tup nelems 6 rs1)
                                                      r2
        ((chords, rest_str_tup2)    , rs3) = rep_chord_row rest_str_tup1 nrows rs2
    (r1:r2:rs1) = rs

-----------------------------------------------------------------------------
-- support-functions for the chord parameter functions of level 4 -----------
-----------------------------------------------------------------------------

-- trio_row: twelve notes built from one interval table. r1 picks 'major' or
-- 'minor'; its four interval triples are shuffled with 'perm' and concatenated,
-- r2 optionally reverses them, and the result is transposed onto the next note
-- of the note stream. Returns ((notes, remaining notes), remaining randoms).
trio_row (start_note : rest_note_str) rs = ((map (add_note start_note) ordered, rest_note_str), rs')
    where
    (r1 : r2 : rs1) = rs
    (shuffled, rs') = perm (one_of major minor r1) rs1
    forwards        = concat shuffled
    ordered         = one_of forwards (reverse forwards) r2

-- fill_chord: cut 'row' into consecutive chords whose sizes come from the
-- chord-size stream (Par values). It stops as soon as the remaining notes cannot
-- fill the next chord; those notes are dropped and that size stays unconsumed.
-- Returns (chords, remaining chord-size stream).
fill_chord ch_len_str row
    | length row >= chlen = (chord : more_chords, lens_left)
    | otherwise           = ([], ch_len_str)
    where
    (Par chlen : rest_lens)  = ch_len_str
    (chord, remaining)       = splitAt chlen row
    (more_chords, lens_left) = fill_chord rest_lens remaining

-- rep_chord_row: nrows consecutive chord rows of 12 notes each (via row_chord),
-- concatenated into one list of chords. nrows must be a non-negative whole number.
rep_chord_row str_tup nrows rs
    | nrows == 0 = (([], str_tup), rs)
    | otherwise  = ((chords ++ more, str_tup''), rs'')
    where
    ((Row chords, str_tup'), rs') = row_chord str_tup 12 rs
    ((more, str_tup''), rs'')     = rep_chord_row str_tup' (nrows - 1) rs'
  
-----------------------------------------------------------------------------
-- SMALL UTILITY FUNCTIONS --------------------------------------------------
-----------------------------------------------------------------------------

-- select: the element of xs at position r modulo the length of xs
select r xs = let i = r `mod` length xs in xs !! i

-- one_of / one_of_3: pick one of two / three alternatives; r mod 2 / r mod 3
-- is the position of the chosen alternative
one_of   x y   r = [x, y]    !! (r `mod` 2)
one_of_3 x y z r = [x, y, z] !! (r `mod` 3)

-- between: map the random number r onto the inclusive range left..right
between left right r = left + r `mod` (right - left + 1)

-- perm: pseudo-random permutation of a list. Each step removes the element at
-- position r mod (remaining length) and emits it next. One random number is used
-- per element except the last. Returns (permutation, unused random numbers).
-- An empty list gives an empty permutation and consumes no random numbers;
-- without that case 'mod r 0' would fail with a division by zero.
perm []     rs     = ([], rs)
perm [x]    rs     = ([x], rs)
perm xs     (r:rs) = ((v:ys), rs1)
    where
    (us, (v:vs)) = splitAt (mod r (length xs)) xs
    (ys, rs1)    = perm (us ++ vs) rs

-- mkRndStr: infinite stream of pseudo-random numbers in 0..65535 from the linear
-- congruential generator x' = (25173 * x + 13849) mod 65536, started at v + seed
-- (the start value itself is not part of the stream).
-- NOTE: the stream depends only on v + seed, so offset v with seed s equals offset
-- v-1 with seed s+1. Because 25173 and 13849 are both odd, the lowest bit of
-- successive numbers alternates, which affects 'r mod 2' choices such as 'one_of'.
mkRndStr v seed = drop 1 (iterate step (v + seed))
    where
    step x = mod (x * 25173 + 13849) 65536
{-
mkRndStr n = n : mkRndStr n
mkRndStr n = rndStr (mkStdGen n)
    where
    rndStr rg = r1 : rndStr rg1
                where (r1,rg1) = next rg
-}
-- index: position of the first occurrence of 'note' in the list, or -1 if absent
index note xs = maybe (-1) id (lookup note (zip xs (iterate (+ 1) 0)))

-- add_note: the pitch class 'interval' semitones above note1 (wrapping at 12)
add_note note1 interval = notes !! ((index note1 notes + interval) `rem` 12)

-- flatten: all parameter values of a Part stream, in order, with the Row/Grp/Bal
-- structure removed (a balance contributes its first list, then its second).
-- Par contributes its single value; without that case the function was partial.
flatten []               = []
flatten (Row x : xs)     = x ++ flatten xs
flatten (Grp x : xs)     = x ++ flatten xs
flatten (Bal xs ys : zs) = flatten xs ++ flatten ys ++ flatten zs
flatten (Par x : xs)     = x : flatten xs

-- flat2: like flatten, but each Row and Grp stays one sub-list; balances are
-- expanded in place. A Par becomes a one-element sub-list (previously unhandled).
flat2 []               = []
flat2 (Row x : xs)     = x : flat2 xs                     -- row structure is kept
flat2 (Grp x : xs)     = x : flat2 xs                     -- group structure is kept
flat2 (Bal xs ys : zs) = flat2 xs ++ flat2 ys ++ flat2 zs -- row/group structure is kept
flat2 (Par x : xs)     = [x] : flat2 xs                   -- a single value is its own part

-----------------------------------------------------------------------------
-- OUTPUT FUNCTIONS ---------------------------------------------------------
-- one section is output as a midi file, named "zmid.mid". This file contains 
-- one track. Chords are split into two channels (0 and 1). The first note
-- of a chord goes into channel 0, the rest into channel 1.
-----------------------------------------------------------------------------

-- tagPar: strip the constructors from a Part stream and flatten it into the plain
-- values (a balance contributes its first list, then its second). Par contributes
-- its single value, so the function is total over Part.
tagPar []               = []
tagPar (Row x     : xs) = x  ++ tagPar xs
tagPar (Grp x     : xs) = x  ++ tagPar xs
tagPar (Bal xs ys : zs) = tagPar xs ++ tagPar ys ++ tagPar zs
tagPar (Par x     : xs) = x  :  tagPar xs

-- output: MIDI event bytes for the first nlines chord events. The four parameter
-- streams are flattened and zipped; a set-tempo event is emitted before a chord
-- whenever its tempo differs from the previous one (the initial "0" matches no
-- tempo value, so the first chord always sets the tempo).
-- nlines == 0 yields no events, and output stops when any stream runs out.
output dyna_str entry_str tempo_str chord_str nlines = concat (take nlines events)
    where
    events = mkEvents "0" (tagPar tempo_str) (tagPar dyna_str) (tagPar chord_str) (tagPar entry_str)

    mkEvents pT (t:ts) (d:ds) (c:cs) (e:es)            -- pT is tempo in previous chord event
        | t /= pT   = (setTempo t ++ chordOut d c e) : rest
        | otherwise =               (chordOut d c e) : rest
        where rest = mkEvents t ts ds cs es
    mkEvents _ _ _ _ _ = []                             -- some stream is exhausted

-- composition: the complete composition. It sets the instrumentation on channels
-- 0 and 1, followed by the events of one section. The argument list must be
-- exactly [nlines, dyna_process, entry_process, chord_process, seed].
composition [n, dp, ep, cp, sd] = instrumentation ++ section n dp ep cp sd  -- and the score!!
    where
    instrumentation = concat
        [ setInstr 0 trombone,   setInstr 1 tenorSax
        , setContr 0 volume 100, setContr 1 volume  127
        , setContr 0 panning  0, setContr 1 panning 127
        , setContr 0 reverb 127, setContr 1 reverb  127
        ]
    (volume, panning, reverb, trombone, tenorSax) = (7,10,91,57,66) -- midi constants

-- main: read up to five integer arguments from the command line
-- (nlines dyna_process entry_process chord_process seed), with defaults when none
-- are given, and write one section to "zmid.mid". A wrong argument count is
-- reported with a usage message instead of a pattern-match failure in 'composition'.
main = do
    xs <- getArgs                                          -- get arguments from the command line
    args <- case xs of
        [] -> return [160, 1, 7, 6, 1]                     -- default section
        _ | length xs == 5 -> return (readInts xs)
          | otherwise      -> ioError (userError usage)
    let track = mkChunk "MTrk" (composition args)          -- The composition is converted to a track-chunk
        bytes = map fromIntegral (mkHead 1 ++ track)       -- make a one track midifile
    B.writeFile "zmid.mid" (B.pack bytes)
    putStr ("section: " ++ show args ++ " to zmid.mid\n")
  where
    usage = "usage: expected 5 integer arguments: nlines dyna_process entry_process chord_process seed"

-----------------------------------------------------------------------------
-- chordOut is the central function that outputs Midi events
-- a 'chord' sounds during 'delay' beats with a volume 'dyn'
-----------------------------------------------------------------------------

-- chordOut: MIDI events for one chord. All notes are switched on together with
-- the velocity of 'dyn'; the first note uses channel 0 and the others channel 1.
-- They are switched off after 'delay'; only the first note-off carries that delta time.
chordOut dyn chord delay = onEvents ++ offEvents
    where
    pitches   = map (\n -> index n notes + 60) chord  -- all notes are within the sixth octave
    onEvents  = chordOn 0 (convDyn dyn) pitches
    offEvents = chordOff (convDel delay) 0 pitches

    -- a chord starts playing
    chordOn _  _   []       = []
    chordOn ch vel (p : ps) = noteOn ch vel p ++ chordOn 1 vel ps

    -- the chord stops playing
    chordOff _   _  []       = []
    chordOff del ch (p : ps) = noteOff del ch p ++ chordOff 0 1 ps

    convDyn :: String -> Int
    convDyn d = case d of
        "ppp" -> 60
        "pp"  -> 70
        "p"   -> 80
        "mp"  -> 90
        "mf"  -> 100
        "f"   -> 110
        "ff"  -> 120
        "fff" -> 127

    convDel :: String -> Int -- all delay values below 127 (no variable length stuff needed)
    convDel d = quot wholeNote (denominator d)
        where
        wholeNote = 4 * ticksPerQuarterNote
        denominator "1/2"  = 2
        denominator "1/4"  = 4
        denominator "1/8"  = 8
        denominator "1/12" = 12
        denominator "1/16" = 16

-- readInts: parse every command-line argument as an Int
readInts :: [String] -> [Int]
readInts xs = [read s | s <- xs]

-----------------------------------------------------------------------------
-- Low level Midi functions, bits and bytes
-----------------------------------------------------------------------------

-- type of a midi file with one track
midiFileType :: Int
midiFileType = 0
-- this sets the midi clock rate (ticks per 1/4 note)
ticksPerQuarterNote :: Int
ticksPerQuarterNote = 12

-- mkHead: the header chunk of a midi file with nTrk tracks
-- (file type, track count and clock rate, each as a 16-bit word)
mkHead nTrk = mkChunk "MThd" (concatMap swrd [midiFileType, nTrk, ticksPerQuarterNote])

-- mkChunk: a midi file is built of chunks: a four-character type tag, the data
-- length as a 32-bit word, then the data bytes
mkChunk typ xs = concat [map ord typ, dwrd (length xs), xs]

-- noteOn: midi event for a key pressed (delta time 0)
noteOn chan vol pitch = [0x00, status, pitch, vol]
    where status = 0x90 + chan
-- noteOff: midi event for a key released, 'del' ticks after the previous event
noteOff del chan pitch = [del, status, pitch, 0]
    where status = 0x80 + chan

-- setTempo: meta event setting the tempo in micro seconds per 1/4 note; 't' is
-- the tempo in 1/4 beats per minute and only the low three bytes are emitted
setTempo t = [0x00, 0xFF, 0x51, 0x03] ++ drop 1 (dwrd usPerQuarter)
    where usPerQuarter = quot 60000000 (read t)

-- setInstr: program change, sets the instrument on a channel
setInstr ch inst = [0x00, status, inst]
    where status = 0xC0 + ch
-- setContr: control change, sets a controller on a channel
setContr ch contr value = [0x00, status, contr, value]
    where status = 0xB0 + ch

-- dwrd: a 32-bit word as 4 bytes, most significant byte first
dwrd i = [shiftR i n .&. 0xff | n <- [24, 16, 8, 0]]
-- swrd: a 16-bit word as 2 bytes, most significant byte first
swrd i = [shiftR i n .&. 0xff | n <- [8, 0]]

