from fastcore.test import *
auto_return
This section utilizes the standard Python ast
(Abstract Syntax Tree) package to implement the auto_return
function. The purpose of this function is to determine whether or not the last line of a given block of code should be returned and to appropriately add the return
statement if necessary.
The following are the main steps involved:
- Extract the last statement in the cell code.
- Determine if it can (or should) be returned.
- Modify the code and add the return statement.
Please note that auto_return
is the core component of testcell
. By utilizing ast, we can handle arbitrary code in a robust and consistent manner.
NOTE ON REFACTORING FROM auto_display
TO auto_return
Previously, this code was based on the auto_display
function, which wrapped the last line of the code with a display
statement if necessary. While this approach appeared to be correct in most situations, it didn’t accurately match the behavior of Jupyter cells.
What a Jupyter cell actually does is optionally return the value of the last statement; in general this value is an object that can implement __repr__
and __str__
, and Jupyter/IPython uses these dunder methods to decide how to display that value.
The current implementation instead decides whether to add a return
statement and defers the display of that value to the actual notebook infrastructure.
last_node
last_node (code)
node_source
node_source (node, code)
# The last statement carries a trailing ';' and is followed by a comment line;
# the expected node source is the bare statement without either of them.
sample_code = '''
a = 1
b = 2
c = a+b;
# test
'''
test_eq(node_source(last_node(sample_code),sample_code), 'c = a+b')
# A function definition followed by a call: the call is the last statement.
# The function body must be indented for the sample to be parseable.
sample_code = '''
def my_function(x):
    print(aaa)
    return x
my_function(123)
'''
test_eq(node_source(last_node(sample_code),sample_code), 'my_function(123)')
# Empty cell: there is no last statement, so no source is expected.
sample_code = ''
test_eq(node_source(last_node(sample_code),sample_code), None) # No code should display nothing
# A bare expression inside a `for` body is not expected to be picked up.
sample_code = '''
for i in [1,2,3]:i
'''
test_eq(node_source(last_node(sample_code),sample_code), None) # should not display anything
# An expression inside a `with` body is not expected to be picked up either.
# The `with` body must be indented for the sample to be parseable.
sample_code = '''
t=0 # sample assignment in the same cell
with open('test.txt') as f:
    f.readlines()
'''
test_eq(node_source(last_node(sample_code),sample_code), None) # with block should catch implicit output
is_assignment
is_assignment (node)
# is_assignment is checked against the last node only: assignments are
# expected to yield True, bare names and other expressions False.
test_eq( is_assignment( last_node('a = 1\nb = 2\nc = a+b')), True )
test_eq( is_assignment( last_node('a = 1\nb = 2\nc = a+b\nc')), False )
test_eq( is_assignment( last_node('c')), False )
test_eq( is_assignment( last_node('a=1')),True)
test_eq( is_assignment( last_node('a = function_execution()')),True)
test_eq( is_assignment( last_node('a;')),False)
test_eq( is_assignment( last_node('a')),False)
test_eq( is_assignment( last_node('a - function_execution()')),False)
extract_call
extract_call (node)
# OK: the last statement is a single call; the expected result is the name of
# the called function, without any attribute prefix.
test_eq(extract_call(last_node('fn()')), 'fn')
test_eq(extract_call(last_node('x.fn()')), 'fn')
test_eq(extract_call(last_node('x.y.fn()')), 'fn')
test_eq(extract_call(last_node('x.y.z.fn()')), 'fn')

# KO: the last statement is not itself a call (bare name or tuple), so None is expected.
test_eq(extract_call(last_node('(fn)')), None)
test_eq(extract_call(last_node('fn')), None)
test_eq(extract_call(last_node('(fn(),fn)')), None)
test_eq(extract_call(last_node('(x.y.fn(),fn())')), None)
is_import_statement
is_import_statement (node)
# Only real `import` / `from ... import` statements are expected to be
# recognised; literals, calls and comments (even ones mentioning import) are not.
test_eq( is_import_statement(last_node('123')) , False )
test_eq( is_import_statement(last_node('func(123)')) , False )
test_eq( is_import_statement(last_node('# test')) , False )
test_eq( is_import_statement(last_node('# import numpy')) , False )
test_eq( is_import_statement(last_node('import numpy')) , True )
test_eq( is_import_statement(last_node('from PIL import Image')) , True )
is_ast_node
is_ast_node (x, ref)
# is_ast_node(x, ref) is expected to be True only when the node matches one of
# the ast classes listed in `ref`.
test_eq(is_ast_node(last_node('del a'),[ast.Delete]), True)
test_eq(is_ast_node(last_node('del a'),[ast.Assert]), False)
test_eq(is_ast_node(last_node('a==1'),[ast.Assert]), False)
NOTE: I couldn’t come up with any common property to mark statements like del a
and assert b==1
. The only way I’ve found is to hardcode a comparison against these language statements.
need_return
need_return (node)
# The following are a bunch of real use cases we want to test.
# NOTE: not considering ";"
# Let's define a test utility function
def test_need_return(code):
    """Return `need_return` applied to the last node of `code`."""
    final_stmt = last_node(code)
    return need_return(final_stmt)
# SHOULD BE TRUE
test_eq( test_need_return('a') , True )
#test_eq( test_need_return('a;') , False ) # This is not supported with ast: we should do it differently
test_eq( test_need_return('func(a)') , True )
test_eq( test_need_return('{1:1,2:2}') , True )
test_eq( test_need_return('a in b') , True )
test_eq( test_need_return('a in b') , True )
test_eq( test_need_return('1 if True else None'), True)

# SHOULD BE FALSE
# NOTE(review): display(a) and print(a) are asserted True despite the heading
# above; they are plain call expressions -- confirm this is the intended grouping.
test_eq( test_need_return('display(a)') , True )
test_eq( test_need_return('# xxx') , False )
test_eq( test_need_return('print(a)') , True )
test_eq( test_need_return('import xxx') , False )
test_eq( test_need_return('from xxx import yyy') , False )
test_eq( test_need_return('a=1') , False )
test_eq( test_need_return('for a in [1,2,3]: a') , False )
test_eq( test_need_return('del a') , False )
test_eq( test_need_return('a=1; del a') , False )
test_eq( test_need_return('assert a(b)==1') , False )
test_eq( test_need_return('try: a=0\nexcept: a=1') , False )
test_eq( test_need_return('from numpy import array') , False )
test_eq( test_need_return('global a') , False )
test_eq( test_need_return('nonlocal a') , False )
end_of_last_line_of_code
end_of_last_line_of_code (code:str, node)
def do_test_end_of_last_line_of_code(sample_code):
    """Call `end_of_last_line_of_code` on `sample_code` and its own last node.

    Test helper: saves repeating the `last_node(sample_code)` lookup in
    every assertion.
    """
    return end_of_last_line_of_code(sample_code, last_node(sample_code))
# The expected value is whatever text follows the last statement on its line
# (trailing ';' and/or comment); an empty string when nothing follows.
test_eq( do_test_end_of_last_line_of_code('a=1') , '' ) # one line
test_eq( do_test_end_of_last_line_of_code('a=1\na') , '' ) # two lines
test_eq( do_test_end_of_last_line_of_code('a=1\na\n') , '' ) # two lines
test_eq( do_test_end_of_last_line_of_code('a=1 # comment') , ' # comment' ) # one line and comment
test_eq( do_test_end_of_last_line_of_code('a=1; # comment') , '; # comment' ) # one line and comment
test_eq( do_test_end_of_last_line_of_code('a=1\na; # comment') , '; # comment' ) # two lines and comment
NOTE: the check for a trailing ;
(semicolon) has to be done on the source string, because ast
ignores it.
last_statement_has_semicolon
last_statement_has_semicolon (code)
# A trailing ';' on the last statement is expected to be detected even when
# comment lines follow it.
test_eq( last_statement_has_semicolon('a=1\nb=2') , False )
test_eq( last_statement_has_semicolon('a=1\nb=2;') , True )
test_eq( last_statement_has_semicolon('a=1\nb=2\n# test') , False )
test_eq( last_statement_has_semicolon('a=1\nb=2;\n# test') , True )

# with comment in the end
test_eq( last_statement_has_semicolon('a=1\na # comment'), False)
test_eq( last_statement_has_semicolon('a="# fake comment"\na; # real comment'), True)
Finally we need a way to grab the code till a given ast
node in order to properly inject the return
statement.
code_till_node
code_till_node (code:str, node)
def do_test_code_till_node(sample_code):
    """Call `code_till_node` on `sample_code` and its own last node.

    Test helper: saves repeating the `last_node(sample_code)` lookup in
    every assertion.
    """
    return code_till_node(sample_code, last_node(sample_code))
# The expected value is the source text that precedes the last statement.
test_eq( do_test_code_till_node('a=1\na') , 'a=1' ) # two lines
test_eq( do_test_code_till_node('a=1;a') , 'a=1;' ) # inlined
test_eq( do_test_code_till_node('a=1;a\n#') , 'a=1;' ) # with post-comment
test_eq( do_test_code_till_node('a=1;print(\na)') , 'a=1;' ) # multiline instruction
test_eq( do_test_code_till_node('print(1,\n2);print(\na)') , 'print(1,\n2);' ) # all together
auto_return
is the main function of this module, and it determines whether a given line of code should be returned or not. It returns the modified code with that modification applied if necessary.
NOTE: the comment # %%testcell
is added to signal the row that has been modified.
auto_return
auto_return (code)
# A bare trailing expression gets a `return` plus the `# %%testcell` marker;
# a trailing ';' leaves the code unchanged.
test_eq( auto_return('a=1\na'), 'a=1\nreturn a # %%testcell' )
test_eq( auto_return('a=1\na;'), 'a=1\na;' )
print(auto_return('a=1\na'))
a=1
return a # %%testcell
print(auto_return('a=3\na;'))
a=3
a;