#!/usr/bin/env python3

from datetime import datetime
from glob import glob
from itertools import product
from os import system
from string import ascii_lowercase, digits

import dataset
from autocorrect import spell
from numpy import asarray
from sklearn.model_selection import KFold, cross_val_predict, cross_val_score
from sklearn.svm import SVC
from sqlalchemy.types import *

# PostgreSQL DSN: connect as user jsalsman to database 'sctest' over the
# local Unix socket in /var/opt/gitlab/postgresql (no TCP host, no password).
connect_string = 'postgresql://jsalsman@/sctest?host=/var/opt/gitlab/postgresql'

def create_database(dbname='sctest', clobber=False):
    """Create (or re-create) the speech-collection schema.

    Uses the dataset library, so each db['tablename'] access creates the
    table with an autoincrementing integer 'id' primary key.  Cross-table
    links are stored as plain Integer / ARRAY(Integer) columns ("foreign
    keys" by convention only -- no real constraints are declared).

    dbname  -- NOTE(review): currently unused; the module-level
               connect_string hardcodes the 'sctest' database -- confirm
               before relying on this parameter.
    clobber -- when True, drop each table before re-creating its columns.
    """

    db = dataset.connect(connect_string)

    # Learner / teacher / parent / admin accounts.
    t = db['users'] # makes primary key autoincrementing integer 'id'
    if clobber: t.drop()
    t.create_column('email', String(length=100))
    t.create_index('email')
    t.create_column('name', String(length=100))
    t.create_index('name')
    t.create_column('registered', DateTime)
    t.create_column('telephone', String(length=30))
    t.create_column('passhash', String(length=144)) # 512 bits as hex plus salt
    t.create_column('utterances', ARRAY(Integer)) # foreign keys: utterances/id
    t.create_column('native', String(length=20)) # native language
    t.create_column('learning', String(length=20)) # goal language
    t.create_column('currentlevel', Float(precision=1))
    t.create_column('goallevel', Float(precision=1))
    t.create_column('goalmonths', Float(precision=1))
    # TODO: billing/payments/methods
    t.create_column('categories', ARRAY(String(length=15)))
    t.create_column('disabled', Boolean)
    t.create_column('students', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('teachers', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('parents', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('schools', ARRAY(Integer)) # foreign keys: schools/id
    t.create_column('about', Text)

    # Dictionary of known words with their pronunciations.
    t = db['words']
    if clobber: t.drop()
    t.create_column('spelling', String(length=30))
    t.create_index('spelling')
    t.create_column('phonemes', ARRAY(String(length=3))) # CMUBET, no diphthongs
    t.create_column('homographs', ARRAY(Integer)) # f.keys: words/id TEMP_UNUSED
    t.create_column('homophones', ARRAY(Integer)) # f.keys: words/id TEMP_UNUSED
    t.create_column('speechrank', Integer) # nth most common spoken word TEMP_UNUSED
    t.create_column('speechpart', String(length=1)) # see below TEMP_UNUSED TODO
    #
    # TEMPORARY LITERAL HOMOPHONES:
    t.create_column('homops', ARRAY(Text))
#
# speechpart codes:
#
# q: quantifier
# n: noun
# v: verb
# x: negative
# w: adverb
# m: adjective
# o: pronoun
# s: possessive
# p: preposition
# c: conjunction
# a: article

    # Text (or media) the learner is asked to speak aloud.
    t = db['prompts']
    if clobber: t.drop()
    t.create_column('display', Text)
    t.create_index('display')
    t.create_column('language', String(length=20))
    t.create_column('words', ARRAY(Integer)) # foreign keys: words/id
    t.create_column('utterances', ARRAY(Integer)) # foreign keys: utterances/id
    t.create_column('mediatype', Text)
    t.create_column('media', LargeBinary)
    t.create_column('level', Float(precision=1))
    t.create_column('phonemes', ARRAY(String(length=3))) # CMUBET, no diphthongs
    t.create_column('freeform', Boolean)

    # Individual audio recordings of a prompt by a user.
    t = db['utterances']
    if clobber: t.drop()
    t.create_column('promptid', Integer) # foreign key: prompts/id
    t.create_index('promptid')
    t.create_column('userid', Integer) # foreign key: users/id
    t.create_index('userid')
    t.create_column('pcm', LargeBinary)
    t.create_column('mp3', LargeBinary)
    t.create_column('wav', LargeBinary)
    t.create_column('exemplar', Boolean)
    t.create_column('adult', Boolean)
    t.create_column('male', Boolean)
    t.create_column('phonemes', ARRAY(String(length=3))) # CMUBET, no diphthongs
    t.create_column('features', ARRAY(Float(precision=3))) # ten per phoneme +1
    t.create_column('transcriptions', ARRAY(Integer)) # fk.s: transcriptions/id
    t.create_column('alignment', ARRAY(Float(precision=3)))
    t.create_column('worstphonemeposn', Integer)
    t.create_column('worstdiphoneposn', Integer)
    t.create_column('worstphoneme', String(length=3))
    t.create_column('worstdiphone', String(length=6))
    t.create_column('score', Float(precision=2))
    t.create_column('at', DateTime)
    t.create_column('whose', String(length=100))

    # Human (or exemplar) transcriptions of an utterance.
    t = db['transcriptions']
    if clobber: t.drop()
    t.create_column('utteranceid', Integer) # foreign key: utterances/id
    t.create_column('promptid', Integer) # foreign key: prompts/id
    t.create_index('promptid')
    t.create_column('transcription', Text)
    t.create_column('intelligible', Boolean)
    t.create_column('userid', Integer) # foreign key: users/id
    t.create_column('transcribed', DateTime)
    t.create_column('source', String(length=1)) # (m)turk (p)eer (s)urvey (a)d (e)xemplar
    t.create_column('mturkuid', String(length=30)) #was length=14 until 21 seen
    t.create_column('mturkhit', String(length=35)) #was length=30
    # Application event log.
    t = db['log']
    if clobber: t.drop()
    t.create_column('userid', Integer) # foreign key: users/id
    t.create_index('userid')
    t.create_column('at', DateTime)
    t.create_index('at')
    t.create_column('event', String(length=20))
    t.create_index('event')
    t.create_column('address', String(length=300)) # IPv4 or v6["="domain name]
    t.create_index('address')
    t.create_column('details', Text)
    t.create_column('mediatype', Text)
    t.create_column('media', LargeBinary)

    # Multiple-choice questions in the lesson flow.
    t = db['choices'] # no index except 'id'
    if clobber: t.drop()
    t.create_column('question', Text)
    t.create_column('responses', ARRAY(Integer)) # foreign keys: responses/id
    t.create_column('activity', Text) #e.g. CMS assignment URL (w/type prefix?)
    t.create_column('mediatype', Text)
    t.create_column('media', LargeBinary)
    t.create_column('proofread', Boolean)
    t.create_column('disabled', Boolean)
    t.create_column('level', Float(precision=1))

    # Possible answers to a choice, linking to a follow-up choice/prompt.
    t = db['responses']
    if clobber: t.drop()
    t.create_column('choiceid', Integer) # foreign key: choices/id
    t.create_column('promptid', Integer) # foreign key: prompts/id
    t.create_column('resultid', Integer) # foreign key: choices/id
    t.create_column('terminal', Boolean)
    t.create_column('activity', Text) # e.g. assignment URL (with type prefix?)
    t.create_column('action', Text) # javascript?

    # Ordered groups of choices.
    t = db['lessons']
    if clobber: t.drop()
    t.create_column('name', Text)
    t.create_column('choices', ARRAY(Integer)) # foreign keys: choices/id
    t.create_column('level', Float(precision=1))

    # Ordered groups of lessons.
    t = db['topics']
    if clobber: t.drop()
    t.create_column('name', Text)
    t.create_index('name')
    t.create_column('lessons', ARRAY(Integer)) # foreign keys: lessons/id
    t.create_column('level', Float(precision=1))
    t.create_index('level')

    # Institutional accounts grouping students/teachers/admins.
    t = db['schools']
    if clobber: t.drop()
    t.create_column('name', String(length=100))
    t.create_index('name')
    t.create_column('students', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('teachers', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('admins', ARRAY(Integer)) # foreign keys: users/id
    t.create_column('telephone', String(length=30))
    t.create_column('about', Text)


def load_database(dbname='sctest', clobber=False):
    """Load prompts, words, utterances and transcriptions from flat files.

    Reads prompt texts from ../database/db/p??????.txt, the matching MP3
    recordings from ../database/*-mp3s/p<prompt>s??????.mp3, and the
    Mechanical Turk transcription files (n*.txt next to the prompt texts),
    inserting rows and cross-linking them via id arrays.

    dbname  -- NOTE(review): unused; connect_string hardcodes 'sctest'.
    clobber -- when True, delete all rows of the four tables first
               (schema is kept; this is row deletion, not a drop).
    """
    db = dataset.connect(connect_string)
    wt = db['words']
    pt = db['prompts']
    ut = db['utterances']
    tt = db['transcriptions']
    if clobber:
        print('deleting words, prompt, utterance, and transcription data...')
        wt.delete()
        pt.delete()
        ut.delete()
        tt.delete()
        print('...data deleted')
    pris = {}  # prompt file number -> display text
    for pr in sorted(glob('../database/db/p??????.txt')):
        # Normalize spacing before punctuation in the display text.
        with open(pr, 'r') as f:
            ps = f.read().strip().replace(' .','.').replace(' ,',','
                    ).replace(' ?','?').replace(' !','!')
        pri = pr.replace('../database/db/p','').replace('.txt','')
        pris[pri] = ps
        # Tokenize to bare lowercase words (punctuation stripped,
        # hyphens split into separate words).
        ws = ps.replace(',','').replace('-',' ').replace('.',''
                ).replace('!','').replace('?','').lower().split()
        wids = []
        for wd in ws:
            # Insert each distinct spelling once; reuse its id thereafter.
            w = wt.find_one(spelling=wd)
            if w is None:
                wid = wt.insert(dict(spelling=wd))
                print('word:', wd, 'id:', wid)
            else:
                wid = w['id']
            wids.append(wid)
        pid = pt.insert(dict(display=ps, words=wids, language='en',
                freeform=False)) # gets utterances= below
        print('prompt:', ps, ws, 'id:', pid)
        uids = []
        for us in sorted(glob('../database/*-mp3s/p'+pri+'s??????.mp3')):
            with open(us, 'rb') as f:
                mp3 = f.read()
            uid = ut.insert(dict(promptid=pid, exemplar=False, adult=False,
                    mp3=mp3, at=datetime(2017,5,1), whose=us.split('/')[-1]))
            # gets transcriptions=[id,...] below;
            print (us, 'len:', len(mp3), 'id:', uid)
            uids.append(uid)
            tids = []
            # Matching transcription files live under db/, suffixed n*.txt.
            for ts in sorted(glob(us.replace('.mp3', 'n*.txt'
                ).replace('shorter-mp3s', 'db').replace('longer-mp3s', 'db'))):
                with open(ts, 'r') as f:
                    # line 0: transcription text; lines 2-3: mturk uid/hit
                    tf = f.readlines()
                    tsn = tf[0].strip()
                    mturkuid = tf[2].strip().split()[1]
                    mturkhit = tf[3].strip().split()[1]
                if tsn.lower() == 'english': # skip defective transcriptions
                    continue
                tid = tt.insert(dict(utteranceid=uid, promptid=pid,
                        transcription=tsn, transcribed=datetime(2017,9,1),
                        source='m', mturkuid=mturkuid, mturkhit=mturkhit))
                print('transcription id', tid, 'was:', tsn)
                tids.append(tid)
            ut.update(dict(id=uid, transcriptions=tids), ['id'])
            print(len(tids), 'transcriptions')
        pt.update(dict(id=pid, utterances=uids), ['id'])
        print(len(uids), 'utterances')
    print(len(pris), 'prompts:', pris)


def load_homophones_and_phonemes(spelling_schema=True):
    """Load per-word phonemes and literal homophones from words.txt.

    Each words.txt line looks like
    "<spelling> <phoneme> ... [also hom1, hom2 ...] [# comment]".
    After the word rows are updated, every prompt's phoneme array is rebuilt
    by concatenating its words' phonemes, and then each utterance copies its
    prompt's phonemes.

    spelling_schema -- when True, (re)create the 'homops' ARRAY(Text) column.
    """
    db = dataset.connect(connect_string)
    wt = db['words']
    if spelling_schema:
        wt.create_column('homops', ARRAY(Text))
    pt = db['prompts']
    ut = db['utterances']
    with open('words.txt', 'r') as f:
        for l in f:
            if ' also ' in l:
                # NOTE(review): splits on the bare substring 'also', so any
                # entry containing "also" inside a word would break parsing
                # -- confirm words.txt never has such entries.
                g = l.strip().split('#')[0].split('also')[0].split()
                h = l.strip().split('#')[0].split('also')[1].strip().split(', ')
            else:
                g = l.strip().split('#')[0].split()
                h = []
            w = g[0] # spelling
            p = g[1:] # phoneme list
            # insert phoneme string in words
            wt.update(dict(spelling=w, phonemes=p, homops=h), ['spelling'])
            print('word:', w, p, 'homophones:', h)
    # loop over prompts constructing phonemes from words' phonemes
    for p in pt:
        ps = []
        for w in p['words']:
            # assumes every prompt word appeared in words.txt (non-NULL
            # phonemes) -- TODO confirm
            for pn in wt.find_one(id=w)['phonemes']:
                ps.append(pn)
        pt.update(dict(id=p['id'], phonemes=ps), ['id'])
        print('prompt', p['id'], p['display'], ps)
    # loop over utterances copying phonemes from prompts
    for u in ut:
        p = pt.find_one(id=u['promptid'])
        ut.update(dict(id=u['id'], phonemes=p['phonemes']), ['id'])
        print('utterance', u['id'], p['display'], p['phonemes'])


def load_exemplars():
    """Load exemplar WAV recordings from ../database/exemplars/.

    Filenames are <word>-<speaker>.wav; each is attached (at most once per
    speaker) as an exemplar utterance of the single-word prompt whose words
    array contains exactly that word's id.
    """
    db = dataset.connect(connect_string)
    pt = db['prompts']
    ut = db['utterances']
    wt = db['words']
    for e in sorted(glob('../database/exemplars/*/*.wav')):
        w = e.split('/')[-1].replace('.wav','').split('-')[0] # word spelling
        n = e.split('/')[-1].replace('.wav','').split('-')[1] # speaker name
        with open(e, 'rb') as f:
            wav = f.read()
        wd = wt.find_one(spelling=w) # l/c only
        # Match the one-word prompt via the Postgres array literal '{id}'.
        p = pt.find_one(words='{' + str(wd['id']) + '}') # mixed case display
        pid = p['id']
        ps = p['phonemes']
        put = p['utterances']
        # Only insert if this speaker's exemplar isn't already recorded.
        if ut.find_one(promptid=pid, whose=n) is None:
            uid = ut.insert(dict(promptid=pid, phonemes=ps, wav=wav,
                    exemplar=True, adult=True, at=datetime(2017,11,1), whose=n))
            if uid not in put:
                pt.update(dict(id=pid, utterances=put + [uid]), ['id'])
            print('exemplar', w, 'by', n)
        else:
            print('exemplar', w, 'by', n, 'already extant')

def load_exemplar_transcripts():
    """Insert a perfect transcription for every exemplar utterance.

    For each utterance flagged exemplar=True, records the prompt's display
    text as an intelligible, source='e' transcription, links it back onto
    the utterance, and ensures the utterance id appears (deduplicated,
    sorted) in its prompt's utterances array.
    """
    db = dataset.connect(connect_string)
    pt = db['prompts']
    ut = db['utterances']
    tt = db['transcriptions']
    for u in ut.find(exemplar=True, order_by='id'):
        p = pt.find_one(id=u['promptid'])
        tid = tt.insert(dict(utteranceid=u['id'], promptid=u['promptid'],
                intelligible=True, transcription=p['display'],
                transcribed=datetime(2017,11,1), source='e'))
        # Bug fix: dataset's Table.update(row, keys) needs the key-column
        # list; the original omitted ['id'] here (every other update call
        # in this file passes it), so the row filter was missing.
        ut.update(dict(id=u['id'], transcriptions=[tid]), ['id'])
        put = p['utterances']
        # De-duplicate while keeping the prompt's utterance list sorted.
        pt.update(dict(id=u['promptid'], utterances=sorted(list(
                set(put + [u['id']])))), ['id'])


def load_fvta10():
    """Extract featex-vta10 acoustic feature vectors for every utterance.

    Decodes each utterance's stored audio (mp3 or wav) to 16 kHz signed
    16-bit mono raw audio via mpg123/sox shell commands, runs
    ./ps/featex/featex-vta10 with the utterance's phoneme sequence, and
    stores the resulting float vector in 'features'.  Utterances that
    already have features are skipped, so reruns resume where they left off.
    Uses fixed scratch filenames in the working directory (not re-entrant).
    """
    db = dataset.connect(connect_string)
    ut = db['utterances']
    for u in ut:
        if u['features'] is not None: # skip if done
            continue
        uid = u['id']
        ps = u['phonemes']
        mp3 = u['mp3']
        # Write whichever encoding exists; transcode mp3 -> wav if needed.
        if mp3 is None:
            wav = u['wav']
            with open('featex.wav', 'wb') as f:
                f.write(wav)
        else:
            with open('featex.mp3', 'wb') as f:
                f.write(mp3)
                system('mpg123 -qw featex.wav featex.mp3 ; rm -f featex.mp3')
        # 16 kHz signed 16-bit mono raw; normalize, 50 ms fade-in.
        system('sox -q featex.wav -r16k -ts16 -c1 featex.raw norm fade t 0.05 0')
        system('rm -f featex.wav ; ./ps/featex/featex-vta10 ' + ' '.join(ps)
            + ' > featex.out 2> /dev/null ; rm -f featex.raw')
        with open('featex.out', 'r') as f:
            o = f.read().split()
        system('rm -f featex.out')
        # More than a few tokens means extraction produced a real vector;
        # anything shorter is treated as a failed run.
        if len(o) > 4:
            v = [float(n) for n in o]
            ut.update(dict(id=uid, features=v), ['id'])
            print('utterance', uid, v)
        else:
            print('utterance', uid, 'failed')


def align_exemplars():
    """Compute phoneme start/duration alignments for exemplar recordings.

    Converts each unaligned exemplar's WAV to 16 kHz raw audio via sox,
    runs ./ps/featex/featex-vta10, and parses its stderr "phoneme" lines
    into an alternating [start, duration, ...] float list stored in
    'alignment'.  Uses fixed scratch filenames (not re-entrant); exemplars
    already aligned are skipped.
    """
    db = dataset.connect(connect_string)
    ut = db['utterances']
    for u in ut.find(exemplar=True):
        if u['alignment'] is not None: # skip if done
            continue
        uid = u['id']
        ps = u['phonemes']
        wav = u['wav']
        a = [] # alternating start/duration strings parsed from stderr
        with open('featex.wav', 'wb') as f:
            f.write(wav)
        system('sox -q featex.wav -r16k -ts16 -c1 featex.raw norm fade t 0.05 0')
        system('rm -f featex.wav ; ./ps/featex/featex-vta10 ' + ' '.join(ps)
            + ' 2> featex.out > /dev/null ; rm -f featex.raw')
        with open('featex.out', 'r') as f:
            for l in f:
                if 'featex-vta10: phoneme ' in l:
                    # NOTE: .replace('s', '') strips every 's' character
                    # (presumably to drop the seconds-unit suffix); the token
                    # indices below assume that mangling -- TODO confirm the
                    # exact featex-vta10 output format.
                    o = l.strip().replace(':', '').replace('s', '').replace(
                        ',', '').split()
                    a.append(o[4]) # start
                    a.append(o[6]) # duration
        system('rm -f featex.out')
        # Bug fix: the original tested len(o) > 4, which raised NameError
        # when no phoneme line was parsed (o never bound) and was otherwise
        # always true; test the accumulated alignment list instead.
        if a:
            v = [float(n) for n in a]
            ut.update(dict(id=uid, alignment=v), ['id'])
            print('exemplar', uid, v)
        else:
            print('exemplar', uid, 'failed')


def load_intelligibility():
    """Mark each crowd transcription as intelligible (or not) for its prompt.

    A transcription counts as intelligible when it starts with the prompt's
    word sequence, allowing per-word homophone substitutions
    (itertools.product over each word's [spelling] + homops list),
    autocorrect spelling repair, and space-collapsed run-together variants.

    NOTE(review): tt.find(exemplar=False, ...) filters the transcriptions
    table on an 'exemplar' column that the schema in create_database() does
    not declare (exemplar lives on utterances; transcriptions use
    source='e') -- confirm the column exists or this query fails.
    """
    db = dataset.connect(connect_string)
    tt = db['transcriptions']
    pt = db['prompts']
    wt = db['words']
    okchars = ascii_lowercase + digits + " '"
    for t in tt.find(exemplar=False, order_by='id'):
        p = pt.find_one(id=t['promptid'])
        wl = p['words']
        ws = [] # two layer list of spellings plus homographs
        match = False
        for w in wl:
            tw = wt.find_one(id=w)
            # NOTE(review): assumes homops is non-NULL for every prompt word
            # (i.e. the word appeared in words.txt) -- TODO confirm.
            ws.append([tw['spelling']] + tw['homops']) # lowercase already
        ts = ''.join([c.lower() for c in t['transcription'] # keep only lower-
            if c.lower() in okchars]).strip() + ' ' # case, spaces & apostrophes
        tc = '' # autocorrect-ed alternative
        for w in ts.split():
            tc += spell(w) + ' ' # spelling correction
        # compare permutations.startswith(word(s)), appending a space to each
        for ps in [' '.join(s) + ' ' for s in product(*ws)]: # itertools.product
            if (tc.startswith(ps) or tc.startswith(ps.replace(' ', '') + ' ')
                    or (tc.replace(' ', '') + ' ').startswith(ps) or
                    ts.startswith(ps) or ts.startswith(ps.replace(' ', '') +' ')
                    or (ts.replace(' ', '') + ' ').startswith(ps)):
                match = True
                break
        tt.update(dict(id=t['id'], intelligible=match), ['id'])
        print('transcription ', ts, 'is', match, 'for:', p['display'])


def measure_models():
    """Report leave-one-out cross-validated intelligibility-model accuracy.

    For each prompt with enough samples, gathers (features, intelligible)
    pairs from its transcribed utterances, pads the training set with up to
    15 negative examples drawn from other prompts having the same phoneme
    count, fits an SVC per prompt, and prints per-prompt plus aggregate
    accuracy (all prompts / with-exemplar / non-void).
    """
    db = dataset.connect(connect_string)
    pt = db['prompts']
    tt = db['transcriptions']
    ut = db['utterances']
    # accumulators: all prompts (s/n), with exemplars (se/ne), non-void (sv/nv)
    s = 0.0; se = 0.0; sv = 0.0; n = 0; ne = 0; nv = 0
    for p in pt.all(order_by='display'):
        if p['utterances'] in [None, []]:
            continue # duplicate audio utterances got merged leaving []s
        pd = p['display']
        np = len(p['phonemes'])
        nu = len(p['utterances'])
        X = []; y = [] # one row per transcription: features, intelligibility
        for u in ut.find(promptid=p['id'], order_by='whose'):
            if u['transcriptions'] in [None, []] or u['features'] in [None, []]:
                continue
            for t in tt.find(utteranceid=u['id'], order_by='transcription'):
                X.append(u['features'])
                if t['intelligible']:
                    y.append(1.0)
                else:
                    y.append(0.0)
        if len(X) < 4:
            continue # too few samples to cross-validate
        # negative examples of other prompts features of the same length
        for fl in db.query('select distinct on (promptid) features from '
                + '(select * from (select distinct on (features) features, '
                + 'promptid from utterances where promptid <> :pid and '
                + 'cardinality(phonemes) = :pl) as o order by random()) as t '
                + 'limit 15', pid=p['id'], pl=np): # 30 bombs w/"NaN" TODO
            if fl['features'] is not None: # some were missing, not sure why
                X.append(fl['features'])
                y.append(0.0)
        m = SVC(gamma='auto') # was C=7, class_weight='balanced'
        # KFold(len(X)) folds == leave-one-out cross-validation.
        cvs = cross_val_score(m, X, y, cv=KFold(len(X)), n_jobs=-1)
        s1 = cvs.mean()
        # Bug fix: cross_val_predict was never imported; it is now included
        # in the sklearn.model_selection import at the top of the file.
        cvp = cross_val_predict(m, X, y, cv=KFold(len(X)), n_jobs=-1)
        print('Prompt:', pd, np, 'phonemes,', nu, 'utterances,',
                len(y), 'transcriptions, %.1f%% intelligible,'
                % (sum(y) * 100 / len(y)), '%.1f%% accurate'
                % (s1 * 100), end='')
        s += s1; n += 1
        if len(set(cvp)) == 1:
            print (' VOID', end='') # classifier only ever predicts one class
        else:
            sv += s1; nv += 1
        if ut.find_one(promptid=p['id'], exemplar=True) is not None:
            se += s1; ne += 1
            print(' with exemplar')
        else:
            print('')
    # Robustness: skip a summary line when its category is empty
    # (avoids ZeroDivisionError on an empty or tiny database).
    if n:
        print('Mean leave-one-out cross-validated accuracy: %.1f%%'
                % (s * 100 / n), 'for', n, 'prompts')
    if ne:
        print('Mean leave-one-out cross-validated accuracy: %.1f%%'
                % (se * 100 / ne), 'for', ne, 'with exemplars.')
    if nv:
        print('Mean leave-one-out cross-validated accuracy: %.1f%%'
                % (sv * 100 / nv), 'for', nv, 'non-void.')

# vs. 73.8% before autocorrect up to 74.1% after, non-cross
#
### before exemplar transcriptions:
# Mean 5-fold cross-validated accuracy: 53.2 ± 25.8% for 690 prompts
# Mean 5-fold cross-validated accuracy: 62.3 ± 31.1% for 40 with exemplars.
### after exemplar transcriptions:
# Mean 5-fold cross-validated accuracy: 53.3 ± 25.6% for 690 prompts
# Mean 5-fold cross-validated accuracy: 63.7 ± 27.7% for 40 with exemplars.
#
### leave-one-out:
# Mean leave-one-out cross-validated accuracy: 65.8 ± 93.0% for 694 prompts
# Mean leave-one-out cross-validated accuracy: 73.3 ± 87.0% for 40 with exemplars.
#
### negative examples:
# Mean leave-one-out cross-validated accuracy: 69.5 ± 89.6% for 694 prompts
# Mean leave-one-out cross-validated accuracy: 80.2 ± 77.7% for 40 with exemplars.
#
### vs DNNs:
# SVC: 69.2, DNN: 74.3; with exemplars: 80.0 85.4
#
### SVC(C=7)
# Mean leave-one-out cross-validated accuracy: 72.2 ± 87.9% for 694 prompts
# Mean leave-one-out cross-validated accuracy: 79.1 ± 80.4% for 40 with exemplars.
#
### negative examples
# 10 71.0, 20 73.5, 30 75.2, 40 76.8, 50 77.8
# with exemplars: 10 77.6, 20 79.9, 30 82.4, 40 84.0, 50 84.8

## regular SVC() with void detection:
# Mean leave-one-out cross-validated accuracy: 71.7% for 694 prompts
# Mean leave-one-out cross-validated accuracy: 83.1% for 40 with exemplars.
# Mean leave-one-out cross-validated accuracy: 70.3% for 549 non-void.
# jsalsman@gitlab-db-vm:~/src$ grep -c VOID\ with measurelog4.txt
# 8
# jsalsman@gitlab-db-vm:~/src$ grep -c VOID measurelog4.txt
# 145



# Per-prompt intelligibility classifiers, keyed by prompt display text:
# display -> [fitted SVC, X, y, phoneme count, has_exemplar]
model = {}
def train_models():
    """Fit and cache a probability-capable SVC per prompt in `model`.

    Mirrors measure_models(): gathers (features, intelligible) samples per
    prompt, pads with up to 15 same-phoneme-count negative examples from
    other prompts, and fits SVC(C=7, gamma='scale',
    class_weight='balanced', probability=True).  Prompts with fewer than 4
    samples are skipped.
    """
    global model
    db = dataset.connect(connect_string)
    pt = db['prompts']
    tt = db['transcriptions']
    ut = db['utterances']
    for p in pt.all(order_by='display'):
        if p['utterances'] in [None, []]:
            continue # duplicate audio utterances got merged leaving []s
        pd = p['display']
        np = len(p['phonemes'])
        nu = len(p['utterances'])
        X = []; y = [] # one row per transcription: features, intelligibility
        for u in ut.find(promptid=p['id'], order_by='whose'):
            if u['transcriptions'] in [None, []] or u['features'] in [None, []]:
                continue
            for t in tt.find(utteranceid=u['id'], order_by='transcription'):
                X.append(u['features'])
                if t['intelligible']:
                    y.append(1.0)
                else:
                    y.append(0.0)
        if len(X) < 4:
            continue # too few samples to fit
        # negative examples
        for fl in db.query('select distinct on (promptid) features from '
                + '(select * from (select distinct on (features) features, '
                + 'promptid from utterances where promptid <> :pid and '
                + 'cardinality(phonemes) = :pl) as o order by random()) as t '
                + 'limit 15', pid=p['id'], pl=np): # TODO: 10? 20?
            if fl['features'] is not None: # some were missing, not sure why
                X.append(fl['features'])
                y.append(0.0)
        m = SVC(C=7, gamma='scale', class_weight='balanced', probability=True
                ).fit(X, y) # and class_weight; weight exemplars
        if ut.find_one(promptid=p['id'], exemplar=True) is not None:
            ex = True
        else:
            ex = False
        model[p['display']] = [m, X, y, len(p['phonemes']), ex]
        print(len(model), 'trained ', p['display'], ' with',
                len(p['phonemes']), 'phonemes and', len(y), 'transcriptions',
                'with' if ex else 'without', 'exemplars')
    #for k, v in model.items():


# def diphone_remix()

# user input

# login

# learner status

# logout?

# sequencing


# multiple choices


## print('replacing schema...')
## create_database(clobber=True)
## print('...schema replaced')
##
## print('loading database...')
## load_database(clobber=True)
## print('...database loaded')
##
## print('loading homophones and phonemes...')
## load_homophones_and_phonemes()
## print('...homophones and phonemes loaded')
##
## print('loading exemplars...')
## load_exemplars()
## load_exemplar_transcripts()
## print('...exemplars loaded')
##
## print('extracting features...')
## load_fvta10()
## print('...features extracted')
##
## print('aligning exemplars...')
## align_exemplars()
## print('...exemplars aligned')
##
## print('determining intelligibility...')
## load_intelligibility()
## print('...transcriptions processed')
##
# Script entry point: the one-time pipeline stages above are commented out
# after each was run; currently the script just re-measures model accuracy.
print('measuring models...')
measure_models()
##
#print('training models...')
#train_models()
