#! /usr/bin/python -Es
# Copyright (C) 2005 Red Hat
# see file 'COPYING' for use and warranty information
#
#    chcat is a script that allows you to modify the Security label on a file
#
#    Author: Daniel Walsh <dwalsh@redhat.com>
#
#    This program is free software; you can redistribute it and/or
#    modify it under the terms of the GNU General Public License as
#    published by the Free Software Foundation; either version 2 of
#    the License, or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
#                                        02111-1307  USA
#
#
import subprocess
import sys
import os
import pwd
import string
import getopt
import selinux
import seobject
import gettext

# Make the translation function "_" globally available for the
# 'policycoreutils' message domain.
try:
    gettext.install('policycoreutils')
except IOError:
    # gettext could not be set up: install a pass-through "_" directly into
    # the builtins namespace so that calls to _() elsewhere keep working.
    try:
        import builtins
        builtins.__dict__['_'] = str
    except ImportError:
        # No 'builtins' module available; fall back to the '__builtin__'
        # module name and use 'unicode' as the pass-through.
        import __builtin__
        __builtin__.__dict__['_'] = unicode


def errorExit(error):
    """Write "<program>: <error>" to stderr and terminate with exit status 1.

    The program name is taken from sys.argv[0].  stderr is flushed before
    SystemExit is raised.
    """
    stream = sys.stderr
    prefix = "%s: " % sys.argv[0]
    stream.write(prefix)
    stream.write("%s\n" % error)
    stream.flush()
    raise SystemExit(1)


def verify_users(users):
    """Check that every name in users has an entry in the password database.

    Each name is looked up with pwd.getpwnam(); the first name that cannot
    be found is reported through error().  Returns None.
    """
    for login in users:
        try:
            pwd.getpwnam(login)
        except KeyError:
            error("User %s does not exist" % login)


def chcat_user_add(newcat, users):
    """Add categories to the range of the SELinux login mapping of each user.

    newcat is a list whose elements after the first are the categories to
    add (newcat[0] is not used here).  users is a list of login names; they
    are checked with verify_users() before anything is changed.

    For every user the current range is taken from the login records
    returned by seobject.loginRecords().get_all(), falling back to the
    "__default__" record when the user has no record of their own.  The
    categories are appended to the upper level of that range and the result
    is stored with "semanage login".

    Returns the number of semanage invocations that failed.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Choose add/modify per user.  The choice must not carry over from a
        # previous user: a user that already has a record needs "-m" even if
        # an earlier user in the list needed "-a".
        if u in seusers:
            action = "-m"
            user = seusers[u]
        else:
            action = "-a"
            user = seusers["__default__"]
        serange = user[1].split("-")
        cats = []
        top = ["s0"]
        if len(serange) > 1:
            # Upper level of the range: "<sensitivity>[:<categories>]".
            top = serange[1].split(":")
            if len(top) > 1:
                cats.append(top[1])
                cats = expandCats(cats)

        for i in newcat[1:]:
            if i not in cats:
                cats.append(i)

        if cats:
            new_serange = "%s-%s:%s" % (serange[0], top[0], ",".join(cats))
        else:
            new_serange = "%s-%s" % (serange[0], top[0])

        # Run semanage with an argument list (no shell) so that no argument
        # is ever re-parsed by a shell.
        cmd = ["semanage", "login", action, "-r", new_serange,
               "-s", user[0], u]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            print(exc.output.rstrip("\n"))
            errors += 1
        except OSError as exc:
            # The command could not be started at all.
            print(exc)
            errors += 1

    return errors


def chcat_add(orig, newcat, objects, login_ind):
    """Add a category to files, or to login mappings when login_ind is 1.

    orig is the category text as the caller supplied it; it is only used in
    messages.  newcat is a list of the form [sensitivity, category, ...].
    objects is a list of file paths, or of login names when login_ind == 1
    (in which case the work is delegated to chcat_user_add()).

    Raises ValueError when newcat contains no category.
    Returns the number of chcon invocations that failed.
    """
    if len(newcat) == 1:
        raise ValueError(_("Requires at least one category"))

    if login_ind == 1:
        return chcat_user_add(newcat, objects)

    errors = 0
    sensitivity = newcat[0]
    # NOTE(review): for files only the first category of newcat is applied;
    # any further entries in newcat are ignored here.
    cat = newcat[1]
    for f in objects:
        (rc, c) = selinux.getfilecon(f)
        # Everything after the third ":" of the context is the level.
        con = c.split(":")[3:]
        clist = translate(con)
        if sensitivity != clist[0]:
            print(_("Can not modify sensitivity levels using '+' on %s") % f)
            # Skip the file, as chcat_remove() does: running chcon below
            # would change the sensitivity that was just refused.
            continue

        if len(clist) > 1:
            if cat in clist[1:]:
                print(_("%s is already in %s") % (f, orig))
                continue
            cat_string = ",".join(sorted(clist[1:] + [cat]))
        else:
            cat_string = cat

        # Argument list, not a shell string: file names containing spaces
        # or shell metacharacters are passed to chcon unchanged.
        cmd = ["chcon", "-l", "%s:%s" % (sensitivity, cat_string), f]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            print(exc.output.rstrip("\n"))
            errors += 1
        except OSError as exc:
            # The command could not be started at all.
            print(exc)
            errors += 1
    return errors


def chcat_user_remove(newcat, users):
    """Remove categories from the range of each user's SELinux login mapping.

    newcat is a list whose elements after the first are the categories to
    remove (newcat[0] is not used here).  users is a list of login names;
    they are checked with verify_users() before anything is changed.

    For every user the current range is taken from the login records
    returned by seobject.loginRecords().get_all(), falling back to the
    "__default__" record when the user has no record of their own.  The
    categories are dropped from the upper level of that range and the
    result is stored with "semanage login".

    Returns the number of semanage invocations that failed.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Choose add/modify per user.  The choice must not carry over from a
        # previous user: a user that already has a record needs "-m" even if
        # an earlier user in the list needed "-a".
        if u in seusers:
            action = "-m"
            user = seusers[u]
        else:
            action = "-a"
            user = seusers["__default__"]
        serange = user[1].split("-")
        cats = []
        top = ["s0"]
        if len(serange) > 1:
            # Upper level of the range: "<sensitivity>[:<categories>]".
            top = serange[1].split(":")
            if len(top) > 1:
                cats.append(top[1])
                cats = expandCats(cats)

        for i in newcat[1:]:
            if i in cats:
                cats.remove(i)

        if cats:
            new_serange = "%s-%s:%s" % (serange[0], top[0], ",".join(cats))
        else:
            new_serange = "%s-%s" % (serange[0], top[0])

        # Run semanage with an argument list (no shell) so that no argument
        # is ever re-parsed by a shell.
        cmd = ["semanage", "login", action, "-r", new_serange,
               "-s", user[0], u]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            print(exc.output.rstrip("\n"))
            errors += 1
        except OSError as exc:
            # The command could not be started at all.
            print(exc)
            errors += 1
    return errors


def chcat_remove(orig, newcat, objects, login_ind):
    """Remove a category from files, or from login mappings when login_ind is 1.

    orig is the category text as the caller supplied it; it is only used in
    messages.  newcat is a list of the form [sensitivity, category, ...].
    objects is a list of file paths, or of login names when login_ind == 1
    (in which case the work is delegated to chcat_user_remove()).

    Files whose sensitivity differs from newcat[0], or that do not carry
    the category, are reported and skipped.

    Raises ValueError when newcat contains no category.
    Returns the number of chcon invocations that failed.
    """
    if len(newcat) == 1:
        raise ValueError(_("Requires at least one category"))

    if login_ind == 1:
        return chcat_user_remove(newcat, objects)

    errors = 0
    sensitivity = newcat[0]
    # NOTE(review): for files only the first category of newcat is applied;
    # any further entries in newcat are ignored here.
    cat = newcat[1]

    for f in objects:
        (rc, c) = selinux.getfilecon(f)
        # Everything after the third ":" of the context is the level.
        con = c.split(":")[3:]
        clist = translate(con)
        if sensitivity != clist[0]:
            print(_("Can not modify sensitivity levels using '+' on %s") % f)
            continue

        # Also covers a file that has no categories at all.
        if cat not in clist[1:]:
            print(_("%s is not in %s") % (f, orig))
            continue

        # Keep the category being removed ("cat") untouched and build the
        # remaining list in a separate variable, so that every file in
        # objects is checked against the same category.
        remaining = clist[1:]
        remaining.remove(cat)
        if remaining:
            level = "%s:%s" % (sensitivity, ",".join(remaining))
        else:
            level = sensitivity

        # Argument list, not a shell string: file names containing spaces
        # or shell metacharacters are passed to chcon unchanged.
        cmd = ["chcon", "-l", level, f]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            print(exc.output.rstrip("\n"))
            errors += 1
        except OSError as exc:
            # The command could not be started at all.
            print(exc)
            errors += 1
    return errors


def chcat_user_replace(newcat, users):
    """Replace the upper level of each user's SELinux login range.

    newcat is a list of the form [sensitivity, category, ...]; the new
    range keeps the lower level of the user's current range and uses
    "<sensitivity>:<categories>" (or just the sensitivity when there are no
    categories) as the upper level.  users is a list of login names; they
    are checked with verify_users() before anything is changed.

    The current range is taken from the login records returned by
    seobject.loginRecords().get_all(), falling back to the "__default__"
    record when the user has no record of their own.

    Returns the number of semanage invocations that failed.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Choose add/modify per user.  The choice must not carry over from a
        # previous user: a user that already has a record needs "-m" even if
        # an earlier user in the list needed "-a".
        if u in seusers:
            action = "-m"
            user = seusers[u]
        else:
            action = "-a"
            user = seusers["__default__"]
        serange = user[1].split("-")
        # str.join is used because the string module has no join() function
        # on Python 3.
        new_serange = "%s-%s:%s" % (serange[0], newcat[0], ",".join(newcat[1:]))
        if new_serange[-1:] == ":":
            # No categories were given: drop the dangling separator.
            new_serange = new_serange[:-1]

        # Run semanage with an argument list (no shell) so that no argument
        # is ever re-parsed by a shell.
        cmd = ["semanage", "login", action, "-r", new_serange,
               "-s", user[0], u]
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                    universal_newlines=True)
        except subprocess.CalledProcessError as exc:
            print(exc.output.rstrip("\n"))
            errors += 1
        except OSError as exc:
            # The command could not be started at all.
            print(exc)
            errors += 1
    return errors


def chcat_replace(newcat, objects, login_ind):
    """Set the level of files, or of login mappings when login_ind is 1.

    newcat is a list of the form [sensitivity, category, ...].  objects is
    a list of file paths, or of login names when login_ind == 1 (in which
    case the work is delegated to chcat_user_replace()).

    All files are passed to a single "chcon -l <level>" invocation, where
    <level> is the sensitivity followed by ":" and the comma-separated
    categories, or the sensitivity alone when there are no categories.

    Returns 0 when chcon succeeded and 1 when it failed.
    """
    if login_ind == 1:
        return chcat_user_replace(newcat, objects)
    errors = 0
    level = newcat[0]
    if len(newcat) > 1:
        level = "%s:%s" % (newcat[0], ",".join(newcat[1:]))

    # Argument list, not a shell string: file names containing spaces or
    # shell metacharacters are passed to chcon unchanged and cannot alter
    # the command or its exit status.
    cmd = ["chcon", "-l", level] + list(objects)
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                universal_newlines=True)
    except subprocess.CalledProcessError as exc:
        print(exc.output.rstrip("\n"))
        errors += 1
    except OSError as exc:
        # The command could not be started at all.
        print(exc)
        errors += 1

    return errors


def check_replace(cats):
    """Classify a category list as a replacement or as +/- modifications.

    An entry that starts with "+" or "-" is a modification; any other
    entry (including an empty string) is a plain replacement category.

    Returns 1 when the list contains plain categories, 0 otherwise.
    Raises ValueError as soon as both kinds have been seen.
    """
    seen_modifier = False
    seen_plain = False
    for entry in cats:
        if entry[:1] in ("+", "-"):
            seen_modifier = True
        else:
            seen_plain = True
        if seen_modifier and seen_plain:
            raise ValueError(_("Can not combine +/- with other types of categories"))
    return 1 if seen_plain else 0


def isSensitivity(sensitivity):
    """Return 1 if sensitivity names a level "s0" .. "s15", otherwise 0.

    The string must be the letter "s" followed only by digits whose value
    lies in the range 0-15.  Any other string, including the empty string,
    yields 0 rather than raising.
    """
    if not sensitivity.startswith("s"):
        return 0
    level = sensitivity[1:]
    if not level.isdigit():
        return 0
    try:
        number = int(level)
    except ValueError:
        # isdigit() accepts some characters (e.g. superscript digits) that
        # int() cannot convert.
        return 0
    return 1 if 0 <= number < 16 else 0


def expandCats(cats):
    """Expand comma lists and "cN.cM" ranges into individual category names.

    Every element of cats is split on ","; a piece containing "." is read
    as an inclusive range whose bounds are the numbers following the first
    character of each side, and is expanded to "c<number>" names.
    Duplicates are dropped while the order of first appearance is kept.

    If the expansion yields more than 25 names, the original cats object
    is returned unchanged instead.
    """
    expanded = []
    seen = set()

    def _remember(name):
        # Append name only the first time it is encountered.
        if name not in seen:
            seen.add(name)
            expanded.append(name)

    for group in cats:
        for token in group.split(","):
            if "." in token:
                bounds = token.split(".")
                first = int(bounds[0][1:])
                last = int(bounds[1][1:])
                for number in range(first, last + 1):
                    _remember("c%d" % number)
            else:
                _remember(token)

    return cats if len(expanded) > 25 else expanded


def translate(cats):
    """Convert category names into a list [sensitivity, category, ...].

    Each entry of cats is passed through
    selinux.selinux_trans_to_raw_context() as the level of a dummy context
    "a:b:c:<entry>".  If the first field of the resulting level is not a
    sensitivity (see isSensitivity()), "s0" is assumed and every field is
    treated as categories.  Categories are expanded with expandCats().

    An empty cats list yields ["s0"].
    Raises ValueError when the entries resolve to different sensitivities.
    """
    if len(cats) == 0:
        return ["s0"]

    result = []
    for name in cats:
        status, raw_context = selinux.selinux_trans_to_raw_context("a:b:c:%s" % name)
        level = raw_context.split(":")[3:]
        if isSensitivity(level[0]):
            sens, categories = level[0], expandCats(level[1:])
        else:
            sens, categories = "s0", expandCats(level)

        if not result:
            # The first entry fixes the sensitivity for the whole list.
            result.append(sens)
        elif result[0] != sens:
            raise ValueError(_("Can not have multiple sensitivities"))
        result.extend(categories)
    return result


def usage():
    """Print the command-line synopsis to stdout and exit with status 1."""
    prog = sys.argv[0]
    lines = [
        _("Usage %s CATEGORY File ...") % prog,
        _("Usage %s -l CATEGORY user ...") % prog,
        _("Usage %s [[+|-]CATEGORY],...]q File ...") % prog,
        _("Usage %s -l [[+|-]CATEGORY],...]q user ...") % prog,
        _("Usage %s -d File ...") % prog,
        _("Usage %s -l -d user ...") % prog,
        _("Usage %s -L") % prog,
        _("Usage %s -L -l user") % prog,
        _("Use -- to end option list.  For example"),
        _("chcat -- -CompanyConfidential /docs/businessplan.odt"),
        _("chcat -l +CompanyConfidential juser"),
    ]
    print("\n".join(lines))
    sys.exit(1)


def listcats():
    """Print the entries of the SELinux translations file as two columns.

    The file named by selinux.selinux_translations_path() is read; lines
    starting with "#" and lines without "=" are skipped.  For every other
    line the text before the first "=" is printed left-aligned in a
    30-character column, followed by the rest of the line.

    Returns 0.
    """
    # The with-block closes the file even if reading or printing fails.
    with open(selinux.selinux_translations_path()) as fd:
        for l in fd.read().split("\n"):
            if l.startswith("#"):
                continue
            if "=" in l:
                # Split on the first "=" only, so a value that itself
                # contains "=" still yields exactly two fields.
                rec = l.split("=", 1)
                print("%-30s %s" % tuple(rec))
    return 0


def listusercats(users):
    """Print the categories (or level) available to each login user.

    users is a list of login names.  When it is empty, the current login
    name is appended to it: os.getlogin() is tried first and, if that
    fails with OSError, the password database entry of the current uid is
    used instead.

    For each user the third element returned by
    selinux.getseuserbyname() is passed through seobject.translate() and
    split on "-"; the part after the "-" is printed unless it is missing or
    equal to "s0", in which case the first part is printed.
    """
    if len(users) == 0:
        try:
            users.append(os.getlogin())
        except OSError:
            # os.getlogin() reports failure with OSError; other exceptions
            # (e.g. KeyboardInterrupt) are deliberately not swallowed.
            users.append(pwd.getpwuid(os.getuid()).pw_name)

    verify_users(users)
    for u in users:
        cats = seobject.translate(selinux.getseuserbyname(u)[2])
        cats = cats.split("-")
        if len(cats) > 1 and cats[1] != "s0":
            print("%s: %s" % (u, cats[1]))
        else:
            print("%s: %s" % (u, cats[0]))


def error(msg):
    """Print "<program>: <msg>" to stdout and exit with status 1.

    The program name is taken from sys.argv[0].
    """
    print("%s:" % sys.argv[0], msg)
    raise SystemExit(1)

if __name__ == '__main__':
    # Script entry point: parse the options, then dispatch to the delete,
    # list, replace or add/remove code paths.  The process exit status is
    # the number of errors reported by the chosen operation.
    if selinux.is_selinux_mls_enabled() != 1:
        error("Requires a mls enabled system")

    if selinux.is_selinux_enabled() != 1:
        error("Requires an SELinux enabled system")

    delete_ind = 0
    list_ind = 0
    login_ind = 0
    try:
        gopts, cmds = getopt.getopt(sys.argv[1:],
                                    'dhlL',
                                    ['list',
                                     'login',
                                     'help',
                                     'delete'])

        for o, a in gopts:
            if o == "-h" or o == "--help":
                usage()
            if o == "-d" or o == "--delete":
                delete_ind = 1
            if o == "-L" or o == "--list":
                list_ind = 1
            if o == "-l" or o == "--login":
                login_ind = 1

        # Every mode except listing needs at least one positional argument.
        if list_ind == 0 and len(cmds) < 1:
            usage()

    # The exception is bound to "e" rather than "error" so that the
    # module-level error() function is not shadowed.
    except getopt.error as e:
        errorExit(_("Options Error %s ") % e.msg)

    except ValueError:
        usage()

    if delete_ind:
        # Deleting categories means resetting the level to plain "s0".
        sys.exit(chcat_replace(["s0"], cmds, login_ind))

    if list_ind:
        if login_ind:
            sys.exit(listusercats(cmds))
        else:
            if len(cmds) > 0:
                usage()
            sys.exit(listcats())

    # Remaining modes need a category specification plus at least one
    # file or user.
    if len(cmds) < 2:
        usage()

    cats = cmds[0].split(",")
    errors = 0
    objects = cmds[1:]
    try:
        if check_replace(cats):
            errors = chcat_replace(translate(cats), objects, login_ind)
        else:
            for c in cats:
                # Strip the leading "+"/"-" before translating the name.
                l = [c[1:]]
                if len(c) > 0 and c[0] == "+":
                    errors += chcat_add(c[1:], translate(l), objects, login_ind)
                    continue
                if len(c) > 0 and c[0] == "-":
                    errors += chcat_remove(c[1:], translate(l), objects, login_ind)
                    continue
    except ValueError as e:
        error(e)
    except OSError as e:
        error(e)

    sys.exit(errors)