| # |
| # Copyright (C) 2005 Fredrik Kuivinen |
| # |
| |
| import sys, re, os, traceback |
| from sets import Set |
| |
| def die(*args): |
| printList(args, sys.stderr) |
| sys.exit(2) |
| |
| def printList(list, file=sys.stdout): |
| for x in list: |
| file.write(str(x)) |
| file.write(' ') |
| file.write('\n') |
| |
| import subprocess |
| |
| # Debugging machinery |
| # ------------------- |
| |
| DEBUG = 0 |
| functionsToDebug = Set() |
| |
| def addDebug(func): |
| if type(func) == str: |
| functionsToDebug.add(func) |
| else: |
| functionsToDebug.add(func.func_name) |
| |
| def debug(*args): |
| if DEBUG: |
| funcName = traceback.extract_stack()[-2][2] |
| if funcName in functionsToDebug: |
| printList(args) |
| |
| # Program execution |
| # ----------------- |
| |
| class ProgramError(Exception): |
| def __init__(self, progStr, error): |
| self.progStr = progStr |
| self.error = error |
| |
| def __str__(self): |
| return self.progStr + ': ' + self.error |
| |
| addDebug('runProgram') |
| def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True): |
| debug('runProgram prog:', str(prog), 'input:', str(input)) |
| if type(prog) is str: |
| progStr = prog |
| else: |
| progStr = ' '.join(prog) |
| |
| try: |
| if pipeOutput: |
| stderr = subprocess.STDOUT |
| stdout = subprocess.PIPE |
| else: |
| stderr = None |
| stdout = None |
| pop = subprocess.Popen(prog, |
| shell = type(prog) is str, |
| stderr=stderr, |
| stdout=stdout, |
| stdin=subprocess.PIPE, |
| env=env) |
| except OSError, e: |
| debug('strerror:', e.strerror) |
| raise ProgramError(progStr, e.strerror) |
| |
| if input != None: |
| pop.stdin.write(input) |
| pop.stdin.close() |
| |
| if pipeOutput: |
| out = pop.stdout.read() |
| else: |
| out = '' |
| |
| code = pop.wait() |
| if returnCode: |
| ret = [out, code] |
| else: |
| ret = out |
| if code != 0 and not returnCode: |
| debug('error output:', out) |
| debug('prog:', prog) |
| raise ProgramError(progStr, out) |
| # debug('output:', out.replace('\0', '\n')) |
| return ret |
| |
| # Code for computing common ancestors |
| # ----------------------------------- |
| |
| currentId = 0 |
| def getUniqueId(): |
| global currentId |
| currentId += 1 |
| return currentId |
| |
| # The 'virtual' commit objects have SHAs which are integers |
| shaRE = re.compile('^[0-9a-f]{40}$') |
| def isSha(obj): |
| return (type(obj) is str and bool(shaRE.match(obj))) or \ |
| (type(obj) is int and obj >= 1) |
| |
| class Commit(object): |
| __slots__ = ['parents', 'firstLineMsg', 'children', '_tree', 'sha', |
| 'virtual'] |
| |
| def __init__(self, sha, parents, tree=None): |
| self.parents = parents |
| self.firstLineMsg = None |
| self.children = [] |
| |
| if tree: |
| tree = tree.rstrip() |
| assert(isSha(tree)) |
| self._tree = tree |
| |
| if not sha: |
| self.sha = getUniqueId() |
| self.virtual = True |
| self.firstLineMsg = 'virtual commit' |
| assert(isSha(tree)) |
| else: |
| self.virtual = False |
| self.sha = sha.rstrip() |
| assert(isSha(self.sha)) |
| |
| def tree(self): |
| self.getInfo() |
| assert(self._tree != None) |
| return self._tree |
| |
| def shortInfo(self): |
| self.getInfo() |
| return str(self.sha) + ' ' + self.firstLineMsg |
| |
| def __str__(self): |
| return self.shortInfo() |
| |
| def getInfo(self): |
| if self.virtual or self.firstLineMsg != None: |
| return |
| else: |
| info = runProgram(['git-cat-file', 'commit', self.sha]) |
| info = info.split('\n') |
| msg = False |
| for l in info: |
| if msg: |
| self.firstLineMsg = l |
| break |
| else: |
| if l.startswith('tree'): |
| self._tree = l[5:].rstrip() |
| elif l == '': |
| msg = True |
| |
| class Graph: |
| def __init__(self): |
| self.commits = [] |
| self.shaMap = {} |
| |
| def addNode(self, node): |
| assert(isinstance(node, Commit)) |
| self.shaMap[node.sha] = node |
| self.commits.append(node) |
| for p in node.parents: |
| p.children.append(node) |
| return node |
| |
| def reachableNodes(self, n1, n2): |
| res = {} |
| def traverse(n): |
| res[n] = True |
| for p in n.parents: |
| traverse(p) |
| |
| traverse(n1) |
| traverse(n2) |
| return res |
| |
| def fixParents(self, node): |
| for x in range(0, len(node.parents)): |
| node.parents[x] = self.shaMap[node.parents[x]] |
| |
| # addDebug('buildGraph') |
| def buildGraph(heads): |
| debug('buildGraph heads:', heads) |
| for h in heads: |
| assert(isSha(h)) |
| |
| g = Graph() |
| |
| out = runProgram(['git-rev-list', '--parents'] + heads) |
| for l in out.split('\n'): |
| if l == '': |
| continue |
| shas = l.split(' ') |
| |
| # This is a hack, we temporarily use the 'parents' attribute |
| # to contain a list of SHA1:s. They are later replaced by proper |
| # Commit objects. |
| c = Commit(shas[0], shas[1:]) |
| |
| g.commits.append(c) |
| g.shaMap[c.sha] = c |
| |
| for c in g.commits: |
| g.fixParents(c) |
| |
| for c in g.commits: |
| for p in c.parents: |
| p.children.append(c) |
| return g |
| |
| # Write the empty tree to the object database and return its SHA1 |
| def writeEmptyTree(): |
| tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index' |
| def delTmpIndex(): |
| try: |
| os.unlink(tmpIndex) |
| except OSError: |
| pass |
| delTmpIndex() |
| newEnv = os.environ.copy() |
| newEnv['GIT_INDEX_FILE'] = tmpIndex |
| res = runProgram(['git-write-tree'], env=newEnv).rstrip() |
| delTmpIndex() |
| return res |
| |
| def addCommonRoot(graph): |
| roots = [] |
| for c in graph.commits: |
| if len(c.parents) == 0: |
| roots.append(c) |
| |
| superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree()) |
| graph.addNode(superRoot) |
| for r in roots: |
| r.parents = [superRoot] |
| superRoot.children = roots |
| return superRoot |
| |
| def getCommonAncestors(graph, commit1, commit2): |
| '''Find the common ancestors for commit1 and commit2''' |
| assert(isinstance(commit1, Commit) and isinstance(commit2, Commit)) |
| |
| def traverse(start, set): |
| stack = [start] |
| while len(stack) > 0: |
| el = stack.pop() |
| set.add(el) |
| for p in el.parents: |
| if p not in set: |
| stack.append(p) |
| h1Set = Set() |
| h2Set = Set() |
| traverse(commit1, h1Set) |
| traverse(commit2, h2Set) |
| shared = h1Set.intersection(h2Set) |
| |
| if len(shared) == 0: |
| shared = [addCommonRoot(graph)] |
| |
| res = Set() |
| |
| for s in shared: |
| if len([c for c in s.children if c in shared]) == 0: |
| res.add(s) |
| return list(res) |