from fastcore.test import *auto_return
This section uses the standard Python ast (Abstract Syntax Tree) module to implement the auto_return function. The purpose of this function is to determine whether the last statement of a given block of code should be returned and, if so, to add the return statement.
The following are the main steps involved:
- Extract the last statement in the cell code.
- Determine if it can (or should) be returned.
- Modify the code and add the return statement.
Please note that auto_return is the core component of testcell. By utilizing ast, we can handle arbitrary code in a robust and consistent manner.
NOTE ON REFACTORING FROM auto_display TO auto_return
Previously, this code was based on the auto_display function, which wrapped the last line of the code with a display statement if necessary. While this approach appeared to be correct in most situations, it didn’t accurately match the behavior of Jupyter cells.
A Jupyter cell optionally returns the value of its last statement. This value is, in general, an object that can implement __repr__ and __str__, and Jupyter/IPython uses these dunders to decide how to display it.
The current implementation, instead, only decides whether to add a return statement and defers the display of that value to the actual notebook infrastructure.
last_node
last_node (code)
node_source
node_source (node, code)
sample_code = '''
a = 1
b = 2
c = a+b;
# test
'''
test_eq(node_source(last_node(sample_code), sample_code), 'c = a+b')
sample_code = '''
def my_function(x):
    print(aaa)
    return x
    
my_function(123)
'''
test_eq(node_source(last_node(sample_code), sample_code), 'my_function(123)')
sample_code = ''
test_eq(node_source(last_node(sample_code), sample_code), None) # no code should display nothing
sample_code = '''
for i in [1,2,3]:i
'''
test_eq(node_source(last_node(sample_code), sample_code), None) # a trailing loop should not display anything
sample_code = '''
t=0 # sample assignment in the same cell
with open('test.txt') as f: 
    f.readlines()
'''
test_eq(node_source(last_node(sample_code), sample_code), None) # a trailing `with` block swallows the implicit output
is_assignment
is_assignment (node)
# assignments are detected ...
test_eq( is_assignment( last_node('a = 1\nb = 2\nc = a+b')), True )
test_eq( is_assignment( last_node('a=1')), True )
test_eq( is_assignment( last_node('a = function_execution()')), True )
# ... while bare expressions are not
test_eq( is_assignment( last_node('a = 1\nb = 2\nc = a+b\nc')), False )
test_eq( is_assignment( last_node('c')), False )
test_eq( is_assignment( last_node('a;')), False )
test_eq( is_assignment( last_node('a')), False )
test_eq( is_assignment( last_node('a - function_execution()')), False )
extract_call
extract_call (node)
# OK: a call, possibly reached through attribute access, yields the called name
test_eq(extract_call(last_node('fn()')), 'fn')
test_eq(extract_call(last_node('x.fn()')), 'fn')
test_eq(extract_call(last_node('x.y.fn()')), 'fn')
test_eq(extract_call(last_node('x.y.z.fn()')), 'fn')
# KO: bare names and tuples (even tuples of calls) yield None
test_eq(extract_call(last_node('(fn)')), None)
test_eq(extract_call(last_node('fn')), None)
test_eq(extract_call(last_node('(fn(),fn)')), None)
test_eq(extract_call(last_node('(x.y.fn(),fn())')), None)
is_import_statement
is_import_statement (node)
# expressions and comments (even ones mentioning import) are not imports
test_eq( is_import_statement(last_node('123')), False )
test_eq( is_import_statement(last_node('func(123)')), False )
test_eq( is_import_statement(last_node('# test')), False )
test_eq( is_import_statement(last_node('# import numpy')), False )
# both `import x` and `from x import y` forms are recognised
test_eq( is_import_statement(last_node('import numpy')), True )
test_eq( is_import_statement(last_node('from PIL import Image')), True )
is_ast_node
is_ast_node (x, ref)
test_eq(is_ast_node(last_node('del a'),[ast.Delete]), True)
test_eq(is_ast_node(last_node('del a'),[ast.Assert]), False)
test_eq(is_ast_node(last_node('a==1'),[ast.Assert]), False)NOTE: I can’t came around with any common propety to mark statements like del a and assert b==1. The only way I’ve found is to hardcode a comparison against these language statements.
need_return
need_return (node)
# The following are a bunch of real use cases we want to test.
# NOTE: ";" is not considered here: ast drops it, so 'a;' parses exactly like 'a'.
#       That case is checked on the raw source string instead (see last_statement_has_semicolon).
# Let's define a test utility function
def test_need_return(code):
    "Parse `code` and tell whether its last statement should be returned."
    return need_return(last_node(code))
# SHOULD BE TRUE: expressions whose value Jupyter would display
test_eq( test_need_return('a') , True )
test_eq( test_need_return('func(a)') , True )
test_eq( test_need_return('{1:1,2:2}') , True )
test_eq( test_need_return('a in b') , True )
test_eq( test_need_return('1 if True else None'), True )
# Calls are returned too, even side-effect-only ones such as display/print:
# their result (None) is simply not shown by Jupyter.
test_eq( test_need_return('display(a)') , True )
test_eq( test_need_return('print(a)') , True )
# SHOULD BE FALSE: comments and statements that produce no value
test_eq( test_need_return('# xxx') , False )
test_eq( test_need_return('import xxx') , False )
test_eq( test_need_return('from xxx import yyy') , False )
test_eq( test_need_return('a=1') , False )
test_eq( test_need_return('for a in [1,2,3]: a') , False )
test_eq( test_need_return('del a') , False )
test_eq( test_need_return('a=1; del a') , False )
test_eq( test_need_return('assert a(b)==1') , False )
test_eq( test_need_return('try: a=0\nexcept: a=1') , False )
test_eq( test_need_return('from numpy import array') , False )
test_eq( test_need_return('global a') , False )
test_eq( test_need_return('nonlocal a') , False )
end_of_last_line_of_code
end_of_last_line_of_code (code:str, node)
def do_test_end_of_last_line_of_code(sample_code):
    return end_of_last_line_of_code(sample_code, last_node(sample_code))
test_eq( do_test_end_of_last_line_of_code('a=1') , '' ) # one line
test_eq( do_test_end_of_last_line_of_code('a=1\na') , '' ) # two lines
test_eq( do_test_end_of_last_line_of_code('a=1\na\n') , '' ) # two lines
test_eq( do_test_end_of_last_line_of_code('a=1 # comment') , ' # comment' ) # one line and comment
test_eq( do_test_end_of_last_line_of_code('a=1; # comment') , '; # comment' ) # one line and comment
test_eq( do_test_end_of_last_line_of_code('a=1\na; # comment') , '; # comment' ) # two lines and commentNOTE: we need to make the check on ; semicolon using string because ast ignores it.
last_statement_has_semicolon
last_statement_has_semicolon (code)
test_eq( last_statement_has_semicolon('a=1\nb=2') , False )
test_eq( last_statement_has_semicolon('a=1\nb=2;') , True )
test_eq( last_statement_has_semicolon('a=1\nb=2\n# test') , False )
test_eq( last_statement_has_semicolon('a=1\nb=2;\n# test') , True )
# with comment in the end
test_eq( last_statement_has_semicolon('a=1\na # comment'), False)
test_eq( last_statement_has_semicolon('a="# fake comment"\na; # real comment'), True)Finally we need a way to grab the code till a given ast node in order to properly inject the return statement.
code_till_node
code_till_node (code:str, node)
def do_test_code_till_node(sample_code):
    return code_till_node(sample_code, last_node(sample_code))
    
test_eq( do_test_code_till_node('a=1\na') , 'a=1' ) # two lines
test_eq( do_test_code_till_node('a=1;a') , 'a=1;' ) # inlined
test_eq( do_test_code_till_node('a=1;a\n#') , 'a=1;' ) # with post-comment
test_eq( do_test_code_till_node('a=1;print(\na)') , 'a=1;' ) # multiline instruction
test_eq( do_test_code_till_node('print(1,\n2);print(\na)') , 'print(1,\n2);' ) # all togetherauto_return is the main function of this module, and it determines whether a given line of code should be returned or not. It returns the modified code with that modification applied if necessary.
NOTE: the comment # %%testcell is appended to mark the line that has been modified.
auto_return
auto_return (code)
test_eq( auto_return('a=1\na'), 'a=1\nreturn a # %%testcell' )
test_eq( auto_return('a=1\na;'), 'a=1\na;' )print(auto_return('a=1\na'))a=1
return a # %%testcellprint(auto_return('a=3\na;'))a=3
a;