#!/usr/bin/python
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script will take any number of trace files generated by strace(1)
# and output a system call filtering policy suitable for use with Minijail.

from collections import namedtuple
import sys

# License header emitted verbatim at the top of every generated policy
# (main() prints it before the rules).
NOTICE = """# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

# Rule template that allows a syscall without any argument filter;
# "%s" is replaced with the syscall name.
ALLOW = "%s: 1"

# Socket-related syscalls. For traces that main() identifies (by file name)
# as 32-bit x86, each of these is recorded under the single name
# "socketcall" instead of under its own name.
SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname",
               "getsockopt", "listen", "recv", "recvfrom", "recvmsg", "send",
               "sendmsg", "sendto", "setsockopt", "shutdown", "socket",
               "socketpair"]

# Reference only (not used by any code): protocol-family numbers, apparently
# copied from a C socket header. They may help when reading generated
# socket() rules, which filter on the domain argument (arg0).
# /* Protocol families.  */
# #define PF_UNSPEC     0       /* Unspecified.  */
# #define PF_LOCAL      1       /* Local to host (pipes and file-domain).  */
# #define PF_UNIX       PF_LOCAL /* POSIX name for PF_LOCAL.  */
# #define PF_FILE       PF_LOCAL /* Another non-standard name for PF_LOCAL.  */
# #define PF_INET       2       /* IP protocol family.  */
# #define PF_AX25       3       /* Amateur Radio AX.25.  */
# #define PF_IPX        4       /* Novell Internet Protocol.  */
# #define PF_APPLETALK  5       /* Appletalk DDP.  */
# #define PF_NETROM     6       /* Amateur radio NetROM.  */
# #define PF_BRIDGE     7       /* Multiprotocol bridge.  */
# #define PF_ATMPVC     8       /* ATM PVCs.  */
# #define PF_X25        9       /* Reserved for X.25 project.  */
# #define PF_INET6     10      /* IP version 6.  */
# #define PF_ROSE      11      /* Amateur Radio X.25 PLP.  */
# #define PF_DECnet    12      /* Reserved for DECnet project.  */
# #define PF_NETBEUI   13      /* Reserved for 802.2LLC project.  */
# #define PF_SECURITY  14      /* Security callback pseudo AF.  */
# #define PF_KEY       15      /* PF_KEY key management API.  */
# #define PF_NETLINK   16

# Argument-inspection record used by main():
#   arg_index -- zero-based position of the syscall argument to record;
#   value_set -- set of argument strings observed for it in the traces.
ArgInspectionEntry = namedtuple("ArgInspectionEntry", "arg_index value_set")


def usage(argv):
    """Print the command-line synopsis, using argv[0] as the program name."""
    # A parenthesized single-argument print works under Python 2 and 3.
    print("%s <trace file> [trace files...]" % argv[0])


def _is_socketcall_trace(trace_filename):
    """Return True if trace_filename looks like a 32-bit x86 trace.

    The check is purely name-based: "i386" anywhere in the name, or "x86"
    without "64". Socket syscalls from such traces are recorded under the
    single name "socketcall".
    """
    if "i386" in trace_filename:
        return True
    return "x86" in trace_filename and "64" not in trace_filename


def _parse_traces(traces, arg_inspection):
    """Count the syscalls found in strace(1) output files.

    Args:
        traces: iterable of trace file paths.
        arg_inspection: dict of syscall name -> ArgInspectionEntry. The
            observed value of each entry's argument is added to its
            value_set (mutated in place).

    Returns:
        dict mapping syscall name -> number of occurrences.
    """
    syscalls = {}
    for trace_filename in traces:
        # Decided per file so that one 32-bit trace does not change how the
        # trace files after it on the command line are interpreted.
        uses_socketcall = _is_socketcall_trace(trace_filename)
        with open(trace_filename) as trace_file:
            for line in trace_file:
                if "(" not in line:
                    continue
                syscall, args = line.strip().split("(", 1)
                # Only the text before the first "(" is checked, so a call
                # whose arguments contain "---" (e.g. write(1, "---", 3)) is
                # still counted, while strace's "--- SIG... ---" and
                # "+++ ... +++" status lines are skipped.
                if "---" in syscall or "+++" in syscall:
                    continue
                if uses_socketcall and syscall in SOCKETCALLS:
                    syscall = "socketcall"
                syscalls[syscall] = syscalls.get(syscall, 0) + 1

                entry = arg_inspection.get(syscall)
                if entry is None:
                    continue
                # Naive split: assumes no nested parentheses or commas at or
                # before the inspected argument.
                arg_list = [arg.strip()
                            for arg in args.split(")", 1)[0].split(",")]
                # Truncated lines (e.g. "<unfinished ...>") may have fewer
                # arguments than expected; record nothing for them.
                if entry.arg_index < len(arg_list):
                    entry.value_set.add(arg_list[entry.arg_index])
    return syscalls


def _order_syscalls(syscalls, frequent_set, basic_set):
    """Return syscall names in the order their rules should be emitted.

    Order: frequent_set first, then the observed syscalls by descending
    count. Any basic_set syscall not already listed is inserted just before
    the first observed syscall seen fewer than 2 times, or appended at the
    end if every observed syscall was seen at least twice.
    """
    # sorted() is stable even with reverse=True, so equal counts keep the
    # dictionary's iteration order.
    by_count = sorted(syscalls, key=syscalls.get, reverse=True)
    by_count = [name for name in by_count if name not in frequent_set]
    all_syscalls = frequent_set + by_count

    # Default to the end: when no syscall is rare, the basic set goes last.
    split_index = len(all_syscalls)
    for offset, name in enumerate(by_count):
        if syscalls[name] < 2:
            split_index = len(frequent_set) + offset
            break

    missing_basic = [name for name in basic_set if name not in all_syscalls]
    return (all_syscalls[:split_index] + missing_basic +
            all_syscalls[split_index:])


def _format_rule(syscall, arg_inspection):
    """Return the policy line for syscall.

    A syscall in arg_inspection with recorded values is restricted to those
    values, listed in sorted order so the output is deterministic. Every
    other syscall gets the unconditional ALLOW rule.
    """
    entry = arg_inspection.get(syscall)
    # No values captured (e.g. only truncated lines were seen): allow the
    # syscall instead of emitting an empty filter expression.
    if entry is None or not entry.value_set:
        return ALLOW % syscall
    arg_filter = " || ".join("arg%d == %s" % (entry.arg_index, value)
                             for value in sorted(entry.value_set))
    return syscall + ": " + arg_filter


def main(traces):
    """Print a Minijail syscall policy derived from strace(1) trace files.

    Args:
        traces: paths of trace files produced by strace(1).

    Writes the NOTICE header to stdout, followed by one rule per syscall.
    Syscalls listed in arg_inspection below get rules limited to the argument
    values seen in the traces; all other syscalls are allowed
    unconditionally. Errors opening a trace file propagate from open().
    """
    # Syscalls always present in the policy, even if absent from the traces.
    basic_set = ["restart_syscall", "exit", "exit_group", "rt_sigreturn"]
    # Syscalls to emit first regardless of their observed frequency.
    frequent_set = []

    arg_inspection = {
        "socket": ArgInspectionEntry(0, set()),   # int domain
        "ioctl": ArgInspectionEntry(1, set()),    # int request
        "prctl": ArgInspectionEntry(0, set()),    # int option
    }

    syscalls = _parse_traces(traces, arg_inspection)

    print(NOTICE)
    for syscall in _order_syscalls(syscalls, frequent_set, basic_set):
        print(_format_rule(syscall, arg_inspection))


if __name__ == "__main__":
    # Every positional argument is a trace file; at least one is required.
    trace_files = sys.argv[1:]
    if not trace_files:
        usage(sys.argv)
        sys.exit(1)
    main(trace_files)