#!/usr/bin/env python
#
# Copyright (C) 2011 Apple Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1.  Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
# 2.  Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the distribution.
# 3.  Neither the name of Apple Computer, Inc. ("Apple") nor the names of
#     its contributors may be used to endorse or promote products derived
#     from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY APPLE AND ITS CONTRIBUTORS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL APPLE OR ITS CONTRIBUTORS BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
import getopt
from optparse import OptionParser

# Binary size units used by byteString() when scaling byte counts.
oneK = 1 << 10
oneM = oneK * oneK
oneG = oneM * oneK

# Output settings. These are only the defaults; main() overwrites all three
# from the parsed command-line options.
hotspot = False     # emit HotSpotFinder-style lines instead of the tree layout
scaleSize = True    # show sizes with a K/M/G suffix rather than raw bytes
showBars = True     # draw '|' bars lining up siblings in the tree

def byteString(bytes):
    """Format a byte count for display according to the global settings.

    With scaleSize set, counts of at least oneK/oneM/oneG are divided by the
    largest unit that fits and shown with one decimal place and a K/M/G
    suffix; smaller counts are shown as a padded integer.  With scaleSize
    unset, the raw count is returned, unpadded in hotspot mode and
    right-aligned in a 12-character field otherwise.
    """
    if not scaleSize:
        if hotspot:
            return '%d' % bytes
        return '%12d' % bytes

    # Largest unit first, so the first threshold reached is the right one.
    for unit, suffix in ((oneG, 'G'), (oneM, 'M'), (oneK, 'K')):
        if bytes >= unit:
            return '%8.1f%s' % (float(bytes) / unit, suffix)

    return '  %4d   ' % bytes

class Node:
    """One entry in the allocation call tree.

    A node has a name, a nesting level, a running byte total and a dict of
    child nodes keyed by name.  Children created through getChild() sit one
    level below their parent.
    """

    def __init__(self, name, level = 0, bytes = 0):
        self.name = name
        self.level = level
        self.children = {}
        self.totalBytes = bytes

    def hasChildren(self):
        """Return True if this node has at least one child."""
        return len(self.children) > 0

    def getChild(self, name):
        """Return the child called name, creating it on first use."""
        if name not in self.children:
            self.children[name] = Node(name, self.level + 1)

        return self.children[name]

    def getBytes(self):
        """Return the byte total accumulated on this node."""
        return self.totalBytes

    def addBytes(self, bytes):
        """Add bytes to this node's running total."""
        self.totalBytes = self.totalBytes + bytes

    def processLine(self, bytes, line):
        """Add one '|'-separated call chain below this node.

        Each segment of line (whitespace-stripped) names a child of the
        previous segment's node; bytes is added to every node along the
        chain.  This node's own total is not changed.

        The chain is walked with a loop rather than one recursive call per
        segment, so very long chains cannot hit the recursion limit.
        """
        node = self
        remaining = line
        while True:
            # Without a '|', partition() returns the whole text as the head
            # and an empty remainder, which ends the walk after this segment.
            head, _, remaining = remaining.partition('|')
            node = node.getChild(head.strip())
            node.addBytes(bytes)
            if not remaining:
                break

    def printNode(self, prefix = ' '):
        """Print this node and all its descendants, one line per node.

        In hotspot mode a line is four spaces, one space per level, the size
        and the name.  Otherwise it is the size, a space, prefix without its
        last character, and the name.  Children are printed largest byte
        total first.  When showBars is set and a node has several children,
        every child except the last passes a '|' down to its own subtree.

        An explicit stack is used instead of recursion so that deep trees
        print completely instead of failing at the recursion limit.
        """
        pending = [(self, prefix)]
        while pending:
            node, nodePrefix = pending.pop()

            sizeText = byteString(node.totalBytes)
            if hotspot:
                print('    %s%s %s' % (node.level * ' ', sizeText, node.name))
            else:
                print('%s %s%s' % (sizeText, nodePrefix[:-1], node.name))

            if not node.hasChildren():
                continue

            sortedChildren = sorted(node.children.values(), key=sortKeyByBytes, reverse=True)

            lastPrefix = nodePrefix + ' '
            if showBars and len(sortedChildren) > 1:
                siblingPrefix = nodePrefix + '|'
            else:
                siblingPrefix = lastPrefix

            # The stack pops in reverse push order: push the last child
            # first and the largest child last so the largest prints first.
            pending.append((sortedChildren[-1], lastPrefix))
            for child in reversed(sortedChildren[:-1]):
                pending.append((child, siblingPrefix))

def sortKeyByBytes(node):
    """Sort key: the byte total reported by node.getBytes()."""
    byteTotal = node.getBytes()
    return byteTotal

def main():
    """Read malloc_history output and print it as a nested call tree.

    Parses the command-line options into the module-level hotspot, scaleSize
    and showBars settings, reads the file named by the first argument (or
    stdin when there is none), builds one tree per root name, and prints the
    trees largest byte total first.

    print is called as a function with a single argument throughout, so the
    script parses and behaves the same under Python 2 and Python 3.
    """
    global hotspot
    global scaleSize
    global showBars

    # parse command line options
    parser = OptionParser(usage='malloc-tree [options] [malloc_history-file]',
                          description='Format malloc_history output as a nested tree',
                          epilog='stdin used if malloc_history-file is missing')

    parser.add_option('-n', '--nobars', action='store_false', dest='showBars',
                      default=True, help='don\'t show bars lining up siblings in tree')
    parser.add_option('-b', '--size-in-bytes', action='store_false', dest='scaleSize',
                      default=None, help='show sizes in bytes')
    parser.add_option('-s', '--size-scale', action='store_true', dest='scaleSize',
                      default=None, help='show sizes with appropriate scale suffix [K,M,G]')
    parser.add_option('-t', '--hotspot', action='store_true', dest='hotspot',
                      default=False, help='output in HotSpotFinder format, implies -b')

    (options, args) = parser.parse_args()

    hotspot = options.hotspot
    if options.scaleSize is None:
        # Neither -b nor -s was given: hotspot output implies raw byte counts.
        scaleSize = not hotspot
    else:
        scaleSize = options.scaleSize
    showBars = options.showBars

    if len(args) < 1:
        inputFile = sys.stdin
    else:
        inputFile = open(args[0], "r")

    rootNodes = {}

    try:
        for line in inputFile:
            # Only lines with text before a '|' and a 'bytes:' marker in that
            # text are allocation records; everything else is skipped.
            firstSep = line.find('|')
            if firstSep <= 0:
                continue

            firstPart = line[:firstSep].strip()
            lineRemain = line[firstSep+1:]
            bytesSep = firstPart.find('bytes:')
            if bytesSep < 0:
                continue

            # The root name is whatever follows 'bytes:' and one more character.
            name = firstPart[bytesSep+7:]
            # The fourth space-separated field is read as the byte count, with
            # thousands separators removed.
            # NOTE(review): presumably records look like
            # 'N calls for B bytes: <name> | frame | ...' -- confirm against
            # real malloc_history output.
            stats = firstPart.split(' ')
            bytes = int(stats[3].replace(',', ''))

            if name not in rootNodes:
                node = Node(name, 0, bytes)
                rootNodes[name] = node
            else:
                node = rootNodes[name]
                node.addBytes(bytes)

            node.processLine(bytes, lineRemain)
    finally:
        # Close the file we opened ourselves; stdin is left alone.
        if inputFile is not sys.stdin:
            inputFile.close()

    sortedRootNodes = sorted(rootNodes.values(), key=sortKeyByBytes, reverse=True)

    print('Call graph:')
    try:
        for node in sortedRootNodes:
            node.printNode()
            print('')
    except IOError:
        # Presumably the original catch-all was meant for output errors such
        # as a closed pipe (e.g. piping into head): stop printing quietly.
        # Any other error is now allowed to surface instead of being hidden.
        pass

# Run the tool only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()