
import argparse
import glob
import numpy
#import mlpy
#import time
#import scipy
import os
from pyAudioAnalysis import audioFeatureExtraction as aF
from pyAudioAnalysis import audioTrainTest as aT
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import audioSegmentation as aS
import subprocess
import wave

def mtFileClassification(inputFile, modelName, modelType, plotResults=False, gtFile=""):
    '''
    This function performs mid-term classification of an audio stream.
    Towards this end, supervised knowledge is used, i.e. a pre-trained classifier.
    ARGUMENTS:
        - inputFile:        path of the input WAV file
        - modelName:        path of the classification model file
        - modelType:        svm or knn depending on the classifier type
        - plotResults:      accepted for interface compatibility; not used by this function
        - gtFile:           optional path of a ground-truth segment file; ignored if it is not an existing file

    RETURNS:
        On success, a 4-tuple (flagsInd, flagsIndGT, classNames, mtStep):
          - flagsInd:       numpy array with one class index per mid-term window, after 1-window smoothing
          - flagsIndGT:     numpy array of ground-truth class indices expressed as positions in classNames
                            (-1 for labels missing from classNames); empty array when no ground truth is loaded
          - classNames:     class names stored in the model
          - mtStep:         mid-term step stored in the model (presumably seconds, since it is multiplied by Fs)
        On failure (model file missing, unknown modelType, model uses beat features,
        unreadable audio), the 4-tuple (-1, -1, -1, -1), so that callers can always
        unpack four values.
    '''
    # Same arity as the success return value.
    errorResult = (-1, -1, -1, -1)

    if not os.path.isfile(modelName):
        print("mtFileClassificationError: input modelName not found!")
        return errorResult

    # Pick the loader for the classifier type; anything else is rejected here
    # instead of failing later on unbound model variables.
    if modelType == 'svm':
        loadModel = aT.loadSVModel
    elif modelType == 'knn':
        loadModel = aT.loadKNNModel
    else:
        print("mtFileClassificationError: unknown modelType " + str(modelType))
        return errorResult
    [Classifier, MEAN, STD, classNames, mtWin, mtStep, stWin, stStep, computeBEAT] = loadModel(modelName)

    if computeBEAT:
        print("Model " + modelName + " contains long-term music features (beat etc) and cannot be used in segmentation")
        return errorResult

    [Fs, x] = audioBasicIO.readAudioFile(inputFile)        # load input file
    if Fs == -1:                                           # could not read file
        return errorResult
    x = audioBasicIO.stereo2mono(x)                        # convert stereo (if) to mono

    # mid-term feature extraction:
    [MidTermFeatures, _] = aF.mtFeatureExtraction(x, Fs, mtWin * Fs, mtStep * Fs, round(Fs * stWin), round(Fs * stStep))

    # classify each mid-term feature vector (one column per fix-sized segment):
    flagsInd = []
    for i in range(MidTermFeatures.shape[1]):
        curFV = (MidTermFeatures[:, i] - MEAN) / STD       # normalize current feature vector
        [Result, _] = aT.classifierWrapper(Classifier, modelType, curFV)
        flagsInd.append(Result)
    flagsInd = numpy.array(flagsInd)

    # 1-window smoothing: a window whose two neighbours agree takes their label.
    # Done in place and in order, so an updated label influences the next window.
    for i in range(1, len(flagsInd) - 1):
        if flagsInd[i - 1] == flagsInd[i + 1]:
            flagsInd[i] = flagsInd[i + 1]

    # Load ground-truth (if any) and "align" its labels with the model's class list:
    if os.path.isfile(gtFile):
        [segStartGT, segEndGT, segLabelsGT] = aS.readSegmentGT(gtFile)
        flagsGT, classNamesGT = aS.segs2flags(segStartGT, segEndGT, segLabelsGT, mtStep)
        flagsIndGT = []
        for fl in flagsGT:
            labelGT = classNamesGT[fl]
            if labelGT in classNames:
                flagsIndGT.append(classNames.index(labelGT))
            else:
                flagsIndGT.append(-1)
        flagsIndGT = numpy.array(flagsIndGT)
    else:
        flagsIndGT = numpy.array([])
    return flagsInd, flagsIndGT, classNames, mtStep

def compute_speech_score(inputFile):
    '''
    Classifies a file with the SVM model at the module-level global modelFilePath
    and reports, per class, the share of the file's duration assigned to it.
    ARGUMENTS:
        - inputFile:        path of the WAV file to score

    RETURNS:
        - classNames:       class names of the model
        - Percentages:      numpy array of shape (len(classNames), 1); Percentages[i] is the
                            percentage of the total duration labelled classNames[i]

    RAISES:
        - ValueError:       if mtFileClassification reports a failure

    One line per class (name, percentage, average segment duration) is printed.
    '''
    # modelFilePath is a module-level global assigned by the __main__ block.
    result = mtFileClassification(inputFile, modelFilePath, "svm", False)
    # A failure is signalled with a tuple of -1 values; fail with a clear message
    # instead of an obscure unpacking/iteration error.
    if len(result) != 4 or isinstance(result[0], int):
        raise ValueError("compute_speech_score: classification of " + inputFile + " failed")
    flagsInd, _, classNames, mtStep = result

    # Turn the per-window class indices back into names, then into segments.
    flags = [classNames[int(f)] for f in flagsInd]
    (segs, classes) = aS.flags2segs(flags, mtStep)

    Duration = segs[-1, 1]                                  # end point of the last segment
    SPercentages = numpy.zeros((len(classNames), 1))        # total duration per class
    Percentages = numpy.zeros((len(classNames), 1))
    AvDurations = numpy.zeros((len(classNames), 1))

    for iSeg in range(segs.shape[0]):
        SPercentages[classNames.index(classes[iSeg])] += (segs[iSeg, 1] - segs[iSeg, 0])

    for i in range(SPercentages.shape[0]):
        Percentages[i] = 100.0 * SPercentages[i] / Duration
        S = sum(1 for c in classes if c == classNames[i])   # number of segments of this class
        if S > 0:
            AvDurations[i] = SPercentages[i] / S
        else:
            AvDurations[i] = 0.0

    for i in range(Percentages.shape[0]):
        # Space-separated, exactly as a multi-argument print statement would emit it.
        print(" ".join(str(v) for v in (classNames[i], Percentages[i], AvDurations[i])))

    return classNames, Percentages
    
def check_file_size_and_duration(file_path):
    '''
    Checks that a file is large enough to be worth processing.
    Despite the name, only the size is checked here; no duration is computed.
    ARGUMENTS:
        - file_path:        path of the file to check

    RETURNS:
        - True if the file is larger than 50 bytes.
        - False otherwise; in that case the file is DELETED from disk before returning.
    '''
    file_size = os.path.getsize(file_path)
    # Process only files greater than 50 bytes
    if file_size <= 50:
        print("File " + file_path + " has a size of " + str(file_size) + ". Will be deleted")
        os.remove(file_path)
        return False
    return True
    
def parse_arguments():
    '''
    Builds the command-line parser with its single "GetFilespeechScore" subcommand
    and parses the process arguments.
    RETURNS:
        - the argparse namespace; the chosen subcommand is stored in its "task" attribute
    '''
    cli = argparse.ArgumentParser(
        description="A demonstration script for pyAudioAnalysis library. Extended to extract only speechiness percentage score")
    subcommands = cli.add_subparsers(
        dest="task", metavar="", title="subcommands", description="available tasks")
    scoreCmd = subcommands.add_parser(
        "GetFilespeechScore",
        help="SpeechScore - classification of a WAV file given a trained SVM or kNN to determine percentage of speech in file")

    # (short flag, long flag, extra keyword arguments); every option is mandatory.
    optionTable = (
        ("-i", "--folderToProcess",
         {"help": "Input folder to process"}),
        ("-s", "--successfulFilesFolder",
         {"help": "Folder to move successfully scored files to"}),
        ("-f", "--failedFilesFolder",
         {"help": "Folder to move files that have failed the score test to"}),
        ("-m", "--modelFilePath",
         {"help": "Path to speech model for classifying specified file"}),
        ("-p", "--speechnessPassMark",
         {"help": "Speech score percentage pass mark on a scale of 0-100"}),
        ("-d", "--deleteRawFiles",
         {"type": int, "choices": [0, 1],
          "help": "Upon downsampling of with sox, True [1]: Deletes raw wav files, and False [0]: leaves raw wav files within file system"}),
        ("-t", "--fileInUseShScriptPath",
         {"help": "Path to shell script for testing if file is in use"}),
    )
    for shortFlag, longFlag, extra in optionTable:
        scoreCmd.add_argument(shortFlag, longFlag, required=True, **extra)

    return cli.parse_args()

if __name__ == '__main__':
    # Batch driver: for every WAV file in the input folder, wait until it is no
    # longer in use, convert it with sox, score it for speech, and keep the
    # converted file only if the speech percentage reaches the pass mark.
    #
    # NOTE: keep this code at module level (do not wrap it in a function):
    # compute_speech_score() reads the module-level global modelFilePath set here.
    args = parse_arguments()
    folderToProcess = args.folderToProcess
    successfulFilesFolder = args.successfulFilesFolder
    failedFilesFolder = args.failedFilesFolder      # parsed but unused: failed files are deleted, not moved
    modelFilePath = args.modelFilePath
    fileInUseShScriptPath = args.fileInUseShScriptPath
    requiredScore = float(args.speechnessPassMark)

    # Shell-quoting helper that exists on both Python 3 (shlex) and Python 2 (pipes).
    try:
        from shlex import quote as shell_quote
    except ImportError:
        from pipes import quote as shell_quote

    filesList = glob.glob(folderToProcess + "/*.wav")
    print("Processing: " + str(len(filesList)) + " total files in " + folderToProcess)
    for inputFile in filesList:
        # Skip entries that vanished since the glob and leftovers of earlier conversions.
        if not os.path.isfile(inputFile) or "_tmp_" in inputFile:
            continue
        print("\n*** ***\nProcessing: " + inputFile)

        # Run the helper shell script that checks whether the file is still in use.
        # The script path is left unquoted as before; only the file name is
        # quoted, so names containing spaces or shell metacharacters are safe.
        file_in_use_cmd = '%s %s %s' % (
            os.path.join(fileInUseShScriptPath, 'file_in_use_check.sh'),
            shell_quote(inputFile),
            30)
        r = subprocess.call(file_in_use_cmd, shell=True)
        if r != 0:
            raise Exception('Error running command: %s' % (file_in_use_cmd))

        # Files of 50 bytes or less are deleted by the check and skipped.
        if not check_file_size_and_duration(inputFile):
            continue

        # "<name>_tmp_.wav" next to the input; only the extension is touched,
        # never a ".wav" that happens to occur elsewhere in the path.
        root, ext = os.path.splitext(inputFile)
        tmpfname = root + "_tmp_" + ext
        try:
            # Argument list (no shell), so file names need no escaping.
            sox_convert_command = ['sox', '-r', '16k', '-e', 'unsigned-integer', '-b', '16',
                                   '-c', '1', '-t', 'raw', inputFile, tmpfname]
            r = subprocess.call(sox_convert_command)
            if r != 0:
                raise Exception('Error running command: %s' % (' '.join(sox_convert_command)))

            # Only process files with duration of at least 10 seconds
            wav = wave.open(tmpfname, 'r')
            try:
                file_duration = wav.getnframes() / float(wav.getframerate())
            finally:
                wav.close()                 # release the handle before removing/renaming the file
            if file_duration < 10:
                print("File " + inputFile + " has a duration of " + str(file_duration) + ". Will be deleted")
                os.remove(inputFile)
                os.remove(tmpfname)
                continue

            classNames, Percentages = compute_speech_score(tmpfname)
            file_name = os.path.basename(inputFile)
            for className, percentage in zip(classNames, Percentages):
                if className != 'speech':
                    continue
                speechScore = percentage[0]
                if speechScore >= requiredScore:
                    # os.path.join works whether or not the folder ends with a separator.
                    os.rename(tmpfname, os.path.join(successfulFilesFolder, file_name))
                    print("File " + inputFile + " past score test with score of " + str(speechScore))
                else:
                    print("File " + inputFile + " failed score test with score of " + str(speechScore))
                    print("Deleting tmp file " + tmpfname + " from the file system")
                    os.remove(tmpfname)
                break                       # the temporary file has been consumed
        finally:
            # Remove a temporary file left behind by an error or by a model
            # without a 'speech' class (a successful run has already moved it).
            if os.path.exists(tmpfname):
                os.remove(tmpfname)

            # Delete raw files if specified at commandline (this also applies
            # when processing of the file failed).
            if args.deleteRawFiles == 1 and os.path.exists(inputFile):
                print("Deleting file " + inputFile + " from the file system")
                os.remove(inputFile)
    print('\n')
		
#print [flagsInd,flagsIndGT, classNames, mtStep]
#print "total_speech: "+str(total_speech)+"\ntotal_nonspeech: "+str(total_nonspeech)+"\ntotal all: "+str(total_speech+total_nonspeech)+"\ntotal i: "+str(icount)
