root/branches/mk/cheesecake/model.py

Revision 11, 15.1 kB (checked in by grig, 7 years ago)

Added ez_setup.py so that users are forced to upgrade to new version of setuptools.

Changed installation to alternate directory to use --root instead of --home.

Line 
1 """
2 -Code borrowed from Michael Hudson's docextractor package with the author's permission.
3 -The original code is available at <http://codespeak.net/svn/user/mwh/docextractor/>
4 """
5
6 from compiler import ast
7 import sys
8 import os
9 import cPickle as pickle
10 import __builtin__
11 import sets
12
13 import compiler
14 from compiler.transformer import parse, parseFile
15 from compiler.visitor import walk
16
class Documentable(object):
    """Common base for every node of the documentation tree.

    A node stores the owning ``system``, the ``prefix`` string that
    ``fullName()`` puts in front of ``name``, the ``docstring`` and the
    ``parent`` node (``None`` when there is no parent).
    """

    def __init__(self, system, prefix, name, docstring, parent=None):
        self.system, self.parent = system, parent
        self.prefix, self.name = prefix, name
        self.docstring = docstring
        self.setup()

    def setup(self):
        """Create the per-node containers; called once from ``__init__``."""
        self._name2fullname = {}
        self.orderedcontents = []
        self.contents = {}

    def fullName(self):
        """Return ``prefix`` and ``name`` concatenated."""
        return '%s%s' % (self.prefix, self.name)

    def shortdocstring(self):
        """Return the docstring, abbreviated when longer than 20 characters.

        Trailing whitespace is stripped first. Long docstrings are reduced
        to their first and last 8 characters joined by ``'...'``. A falsy
        docstring (``None`` or ``''``) is returned untouched.
        """
        text = self.docstring
        if not text:
            return text
        text = text.rstrip()
        if len(text) <= 20:
            return text
        return text[:8] + '...' + text[-8:]

    def __repr__(self):
        return "%s %r" % (type(self).__name__, self.fullName())

    def name2fullname(self, name):
        """Resolve ``name`` using this node's bindings, else ask the parent."""
        bindings = self._name2fullname
        if name not in bindings:
            return self.parent.name2fullname(name)
        return bindings[name]
45
46
class Package(Documentable):
    """Documentable node for a package."""

    def name2fullname(self, name):
        # A package holds no name bindings of its own, so every lookup
        # that reaches one fails unconditionally.
        raise NameError()
50
class Module(Documentable):
    """Documentable node for a module; name lookups stop here."""

    def name2fullname(self, name):
        """Resolve ``name`` inside this module.

        Names bound in the module map to their recorded full name. Anything
        else is returned unchanged; if it is not a builtin either, an
        "optimistic name resolution" warning is recorded on the system.
        """
        known = self._name2fullname
        if name in known:
            return known[name]
        if name not in vars(__builtin__):
            self.system.warning("optimistic name resolution", name)
        return name
60
class Class(Documentable):
    """Documentable node for a class; additionally tracks its base names."""

    def setup(self):
        super(Class, self).setup()
        # Names of the base classes, filled in by whoever builds the node.
        self.bases = []

    def __repr__(self):
        fields = (type(self).__name__, self.name,
                  self.shortdocstring(), self.bases)
        return "%s(%r, %r) # %r" % fields
69
class Function(Documentable):
    """Documentable node for a function definition (adds nothing of its own)."""
72
73
class System(object):
    """Registry of all documentable objects plus a cursor into the tree.

    ``current`` is the node that newly pushed objects are nested under.
    ``allobjects`` maps full names to nodes, ``orderedallobjects`` keeps
    them in creation order, ``rootobjects`` lists the parentless nodes and
    ``warnings`` maps a warning type to ``(full name, detail)`` pairs.
    """

    # Node classes are looked up through ``self`` by the push* methods.
    Class = Class
    Module = Module
    Package = Package
    Function = Function

    def __init__(self):
        self.current = None
        self._stack = []
        self.allobjects = {}
        self.orderedallobjects = []
        self.rootobjects = []
        self.warnings = {}
        # importstargraph contains edges {importedby:[imports]} but only
        # for import * statements
        self.importstargraph = {}
        self.func_called = {}

    def _push(self, cls, name, docstring):
        """Create a ``cls`` node under ``current`` and make it current.

        The node is registered in its parent (or in ``rootobjects`` when
        there is no parent) and in the global tables. Returns the node.
        """
        if self.current:
            prefix = self.current.fullName() + '.'
            parent = self.current
        else:
            prefix = ''
            parent = None
        obj = cls(self, prefix, name, docstring, parent)
        if parent:
            parent.orderedcontents.append(obj)
            parent.contents[name] = obj
            parent._name2fullname[name] = obj.fullName()
        else:
            self.rootobjects.append(obj)
        self.current = obj
        self.orderedallobjects.append(obj)
        fullName = obj.fullName()
        if fullName in self.allobjects:
            obj = self.handleDuplicate(obj)
        else:
            self.allobjects[fullName] = obj
        return obj

    def handleDuplicate(self, obj):
        """This is called when we see two objects with the same
        .fullName(), for example:

        class C:
            if something:
                def meth(self):
                    implementation 1
            else:
                def meth(self):
                    implementation 2

        The default is that the second definition "wins".
        """
        self.warning("duplicate", self.allobjects[obj.fullName()])
        self.allobjects[obj.fullName()] = obj
        return obj

    def _pop(self, cls):
        """Leave the current node (which must be a ``cls``) for its parent."""
        assert isinstance(self.current, cls)
        self.current = self.current.parent

    def push(self, obj, node=None):
        """Make an existing ``obj`` current, remembering the previous one."""
        self._stack.append(self.current)
        self.current = obj

    def pop(self, obj):
        """Undo the matching ``push(obj)``."""
        assert self.current is obj, "%r is not %r" % (self.current, obj)
        self.current = self._stack.pop()

    def pushClass(self, name, docstring):
        return self._push(self.Class, name, docstring)

    def popClass(self):
        self._pop(self.Class)

    def pushModule(self, name, docstring):
        return self._push(self.Module, name, docstring)

    def popModule(self):
        self._pop(self.Module)

    def pushFunction(self, name, docstring, func_called):
        """Push a function node and merge ``func_called`` into the system."""
        self.func_called.update(func_called)
        return self._push(self.Function, name, docstring)

    def popFunction(self):
        self._pop(self.Function)

    def pushPackage(self, name, docstring):
        return self._push(self.Package, name, docstring)

    def popPackage(self):
        self._pop(self.Package)

    def report(self):
        """Print the whole tree, one object per line, indented by depth."""
        for o in self.rootobjects:
            self._report(o, '')

    def _report(self, o, indent):
        # Single-argument form prints identically as a Python 2 statement.
        print("%s %s" % (indent, o))
        for o2 in o.orderedcontents:
            self._report(o2, indent + '  ')

    def resolveAlias(self, n):
        """Follow module-level bindings to find what dotted name ``n`` means.

        ``n`` is split at its LAST dot into a module full name and an
        attribute name, so names with several dots (``pkg.mod.Base``) are
        handled instead of raising ``ValueError``. The binding recorded in
        that module is followed repeatedly until a name registered in
        ``allobjects`` is found.

        Returns the resolved full name. Whenever resolution cannot proceed
        (no dot, unknown module, not a ``Module``, name not bound there, or
        an alias cycle) the most recent name is returned unchanged; the
        method never returns ``None``.
        """
        seen = []  # names already visited, to break alias cycles
        while '.' in n and n not in seen:
            seen.append(n)
            dot = n.rfind('.')
            mod, clsname = n[:dot], n[dot + 1:]
            if not mod or mod not in self.allobjects:
                return n
            m = self.allobjects[mod]
            if not isinstance(m, Module):
                return n
            if clsname not in m._name2fullname:
                return n
            newname = m.name2fullname(clsname)
            if newname in self.allobjects:
                return newname
            n = newname
        return n

    def resolveAliases(self):
        """Rewrite every class base that is not a known object via aliases."""
        for ob in self.orderedallobjects:
            if not isinstance(ob, Class):
                continue
            for i, b in enumerate(ob.bases):
                if b not in self.allobjects:
                    ob.bases[i] = self.resolveAlias(b)

    def warning(self, type, detail):
        """Record ``detail`` under ``type``, tagged with the current object."""
        fn = self.current.fullName()
        self.warnings.setdefault(type, []).append((fn, detail))

    def objectsOfType(self, cls):
        """Yield, in creation order, every object that is a ``cls``."""
        for o in self.orderedallobjects:
            if isinstance(o, cls):
                yield o
214
def expandModname(system, modname, givewarning=True):
    """Turn a possibly package-local module name into a full dotted name.

    Starting from ``system.current``, the nearest enclosing ``Package`` is
    located and then it and its ancestors are searched for an entry named
    like the first component of ``modname``. On a hit the entry's full name
    replaces that first component (and a "local import" warning is recorded
    unless ``givewarning`` is false); otherwise ``modname`` comes back as is.
    """
    pieces = modname.split('.', 1)
    prefix = pieces[0]
    suffix = ''
    if len(pieces) > 1:
        suffix = '.' + pieces[1]

    scope = system.current
    # Climb to the innermost package enclosing the current object.
    while not (scope is None or isinstance(scope, Package)):
        scope = scope.parent
    # Then keep climbing until some ancestor contains the first component.
    while scope is not None and prefix not in scope.contents:
        scope = scope.parent

    if scope is None:
        return prefix + suffix
    if givewarning:
        system.warning("local import", modname)
    return scope.contents[prefix].fullName() + suffix
234
class ImportStarFinder(object):
    """AST visitor that records ``from X import *`` edges for one module.

    ``modfullname`` is the full name of the module being scanned; every
    star-import found adds it to ``system.importstargraph`` under the
    expanded name of the module imported from.
    """

    def __init__(self, system, modfullname):
        self.system, self.modfullname = system, modfullname

    def visitFrom(self, node):
        first_name = node.names[0][0]
        if first_name != '*':
            return
        target = expandModname(self.system, node.modname, False)
        importers = self.system.importstargraph.setdefault(target, [])
        importers.append(self.modfullname)
245
class ModuleVistor(object):
    """AST visitor that feeds one module's definitions into a ``System``.

    Function bodies are not walked immediately; they are queued in
    ``morenodes`` via ``postpone`` for the caller to process afterwards.
    ``self.visit`` is not defined here and is expected to be supplied by
    the walker driving this visitor.
    """

    def __init__(self, system, modname):
        self.system, self.modname = system, modname
        self.morenodes = []

    def default(self, node):
        """Visit every child node of ``node``."""
        for kid in node.getChildNodes():
            self.visit(kid)

    def postpone(self, docable, node):
        """Queue ``node`` to be visited later in the context of ``docable``."""
        self.morenodes.append((docable, node))

    def visitModule(self, node):
        system = self.system
        known = system.current and self.modname in system.current.contents
        if not known:
            system.pushModule(self.modname, node.doc)
            self.default(node)
            system.popModule()
            return
        # The module object already exists: fill in its docstring and
        # walk the body with it as the current object.
        m = system.current.contents[self.modname]
        assert m.docstring is None
        m.docstring = node.doc
        system.push(m, node)
        self.default(node)
        system.pop(m)

    def visitClass(self, node):
        cls = self.system.pushClass(node.name, node.doc)
        for base in node.bases:
            if isinstance(base, ast.Name):
                cls.bases.append(cls.parent.name2fullname(base.name))
                continue
            if not isinstance(base, ast.Getattr):
                assert not base
                continue
            # Dotted base ``a.b.C``: collect attribute names from the
            # outside in, resolve the innermost name, then flip the order.
            attrs = []
            expr = base
            while isinstance(expr, ast.Getattr):
                attrs.append(expr.attrname)
                expr = expr.expr
            assert isinstance(expr, ast.Name)
            attrs.append(cls.parent.name2fullname(expr.name))
            attrs.reverse()
            assert None not in attrs, expr
            cls.bases.append('.'.join(attrs))
        self.default(node)
        self.system.popClass()

    def visitFrom(self, node):
        modname = expandModname(self.system, node.modname)
        bindings = self.system.current._name2fullname
        for fromname, asname in node.names:
            if fromname != '*':
                if asname is None:
                    asname = fromname
                bindings[asname] = modname + '.' + fromname
                continue
            self.system.warning("import *", modname)
            if modname not in self.system.allobjects:
                return
            mod = self.system.allobjects[modname]
            # NOTE: the original author left an ``assert mod.processed``
            # disabled at this point and records this warning instead.
            self.system.warning("mwh is an idiot", "")
            for n in mod.contents:
                bindings[n] = modname + '.' + n
            return

    def visitImport(self, node):
        bindings = self.system.current._name2fullname
        for fromname, asname in node.names:
            fullname = expandModname(self.system, fromname)
            if asname is not None:
                bindings[asname] = fullname
                continue
            # ``import a.b.c`` binds only ``a``: map it to the expanded
            # name cut off after the last component equal to ``a``.
            topname = fromname.split('.', 1)[0]
            parts = fullname.split('.')
            cut = len(parts)
            for idx in range(len(parts) - 1, -1, -1):
                if parts[idx] == topname:
                    cut = idx + 1
                    break
            bindings[topname] = '.'.join(parts[:cut])

    def _defaultValue(self, default):
        """Return a simple representation of one default-value AST node."""
        if isinstance(default, ast.Const):
            return default.value
        if isinstance(default, ast.Name):
            return default.name
        self.system.warning("unparseable default", repr(default))
        return '???'

    def visitFunction(self, node):
        calls = {}
        get_function_calls(node, calls)
        func = self.system.pushFunction(node.name, node.doc, calls)
        # Convert the AST's argument description into a getargspec-like
        # tuple: (argnames, *args name, **kwargs name, defaults).
        positional = node.argnames[:]
        kwname = None
        starargname = None
        if node.kwargs:
            kwname = positional.pop()
        if node.varargs:
            starargname = positional.pop()
        defaults = tuple([self._defaultValue(d) for d in node.defaults])
        # Unpacked (tuple) arguments become lists, matching what
        # getargspec produces and what the unit test compares against.
        argnames2 = []
        for arg in positional:
            if isinstance(arg, tuple):
                argnames2.append(list(arg))
            else:
                argnames2.append(arg)
        func.argspec = (argnames2, starargname, kwname, defaults)
        self.postpone(func, node.code)
        self.system.popFunction()
371
def get_function_calls(node, fc):
    """Recursively collect attribute-style calls found beneath ``node``.

    For every call whose callee is an attribute access, the key
    ``"<name>.<attr>"`` (or just ``"<attr>"`` when the object is not a
    plain name) is set to 1 in the dict ``fc``. Calls whose callee is not
    an attribute access are not recorded. Non-AST values are ignored.
    """
    if not isinstance(node, compiler.ast.Node):
        return
    for child in node.getChildren():
        if isinstance(child, compiler.ast.CallFunc):
            callee = child.node
            if isinstance(callee, compiler.ast.Getattr):
                owner = ""
                target = callee.expr
                if isinstance(target, compiler.ast.Name):
                    owner = target.name
                called = callee.attrname
                if owner:
                    called = owner + "." + called
                if called:
                    fc[called] = 1
        # Descend into every child, call nodes included.
        get_function_calls(child, fc)
394    
def processModuleAst(ast, name, system):
    """Walk a module's AST into ``system``, then drain the postponed nodes.

    Postponed nodes are visited in FIFO order with their owning object made
    current; visiting one may queue further nodes, which are handled too.
    """
    visitor = ModuleVistor(system, name)
    walk(ast, visitor)
    while visitor.morenodes:
        docable, body = visitor.morenodes.pop(0)
        system.push(docable, body)
        visitor.visit(body)
        system.pop(docable)
403
404
405 def fromText(src, modname='<test>', system=None):
406     if system is None:
407         system = System()
408     processModuleAst(parse(src), modname, system)
409     return system.rootobjects[0]
410
411
def preprocessDirectory(system, dirpath):
    """Register the package at ``dirpath`` and its modules, without parsing.

    A package node named after the directory is pushed. Each sub-directory
    containing an ``__init__.py`` is handled recursively as a nested
    package, and each ``*.py`` file becomes a module node carrying
    ``filepath`` and ``processed = False``.

    ``dirpath`` may end with a path separator. Entries are visited in
    sorted order so results do not depend on the order ``os.listdir``
    happens to return.
    """
    # basename('pkg/') is '', so normalise away any trailing separator first.
    pkgname = os.path.basename(os.path.normpath(dirpath))
    system.pushPackage(pkgname, None)
    entries = os.listdir(dirpath)
    entries.sort()
    for fname in entries:
        fullname = os.path.join(dirpath, fname)
        if os.path.isdir(fullname) and os.path.exists(os.path.join(fullname, '__init__.py')):
            preprocessDirectory(system, fullname)
        elif fname.endswith('.py'):
            modname = os.path.splitext(fname)[0]
            mod = system.pushModule(modname, None)
            mod.filepath = fullname
            mod.processed = False
            system.popModule()
    system.popPackage()
425
def processDirectory(system, dirpath):
    """Register and then fully process every module under ``dirpath``.

    After the directory has been pre-processed, each module file is parsed
    twice: first to record its ``import *`` statements, then to build its
    documentable objects. Both passes run with the module's parent current.
    """
    preprocessDirectory(system, dirpath)
    modules = list(system.objectsOfType(Module))

    # Pass 1: collect the import-star graph.
    for module in modules:
        system.push(module.parent)
        finder = ImportStarFinder(system, module.fullName())
        walk(parseFile(module.filepath), finder)
        system.pop(module.parent)

    # snarl; a toposort is meant to go here.
    ordered = modules

    # Pass 2: walk each module for real and mark it as processed.
    for module in ordered:
        system.push(module.parent)
        tree = parseFile(module.filepath)
        processModuleAst(tree, module.name, system)
        module.processed = True
        system.pop(module.parent)
443
def main(argv):
    """Command-line entry point.

    With ``-r`` plus exactly one directory: process the directory
    recursively, pickle the resulting system to ``da.out`` and print a
    per-type count of the recorded warnings. Otherwise every argument is
    treated as a module file; each is processed and the tree is printed.

    ``argv`` is the argument list without the program name; it is not
    modified.
    """
    args = list(argv)  # work on a copy so the caller's list stays intact
    if '-r' in args:
        args.remove('-r')
        assert len(args) == 1
        system = System()
        processDirectory(system, args[0])
        # Close the file explicitly so the pickle is flushed to disk.
        out = open('da.out', 'wb')
        try:
            pickle.dump(system, out, pickle.HIGHEST_PROTOCOL)
        finally:
            out.close()
        print("")
        print('warning summary:')
        for k, v in system.warnings.items():
            print("%s %s" % (k, len(v)))
    else:
        system = System()
        for fname in args:
            modname = os.path.splitext(os.path.basename(fname))[0] # XXX!
            processModuleAst(parseFile(fname), modname, system)
        system.report()
461
462
463
if __name__ == '__main__':
    # Hand over the command-line arguments minus the program name.
    cli_args = sys.argv[1:]
    main(cli_args)
Note: See TracBrowser for help on using the browser.