auto_return

This section uses the standard-library `ast` (Abstract Syntax Tree) module to implement the `auto_return` function. `auto_return` decides whether the last statement of a given block of code should be returned and, if so, adds the `return` statement.

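The following are the main steps involved:

1. Parse the code with `ast` and locate its last statement node (`last_node`, `node_source`).
2. Decide whether that statement has a value that should be returned (`need_return`). Assignments, imports, `for`/`with`/`try` blocks and statements such as `del`, `assert`, `global` and `nonlocal` do not.
3. Check on the source string whether the last statement ends with a semicolon (`last_statement_has_semicolon`). If it does, the code is left unchanged.
4. Otherwise, keep the code up to the last statement (`code_till_node`) and re-emit that statement with a `return` in front of it (`auto_return`).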

Note that `auto_return` is the core component of testcell. Working on the `ast` lets it handle arbitrary code robustly and consistently.

NOTE ON REFACTORING FROM auto_display TO auto_return

Previously, this code was based on the auto_display function, which wrapped the last line of the code with a display statement if necessary. While this approach appeared to be correct in most situations, it didn’t accurately match the behavior of Jupyter cells.

A Jupyter cell optionally returns the value of its last statement. That value is, in general, an object that can implement `__repr__` and `__str__`. Jupyter/IPython uses these dunder methods to decide how to display it.

The current implementation, instead, only decides whether to add a `return` statement. It leaves the display of the returned value to the notebook infrastructure itself.


source

last_node

 last_node (code)
from fastcore.test import *

source

node_source

 node_source (node, code)
# Each pair is (cell source, expected source text of its last statement).
# None means the cell ends with nothing that could be returned.
for sample_code, expected_source in [
    ('''
a = 1
b = 2
c = a+b;
# test
''', 'c = a+b'),  # the trailing ";" and the comment line are not part of the node
    ('''
def my_function(x):
    print(aaa)
    return x
    
my_function(123)
''', 'my_function(123)'),
    ('', None),  # empty cell: nothing to return
    ('''
for i in [1,2,3]:i
''', None),  # values inside a loop body are never returned
    ('''
t=0 # sample assignment in the same cell
with open('test.txt') as f: 
    f.readlines()
''', None),  # a with-block ends the cell, so there is no implicit output
]:
    test_eq(node_source(last_node(sample_code), sample_code), expected_source)

source

is_assignment

 is_assignment (node)
# is_assignment is True only when the last statement of the cell is an assignment.
for snippet, expected in (
    ('a = 1\nb = 2\nc = a+b', True),
    ('a = 1\nb = 2\nc = a+b\nc', False),
    ('c', False),
    ('a=1', True),
    ('a = function_execution()', True),
    ('a;', False),
    ('a', False),
    ('a - function_execution()', False),
):
    test_eq(is_assignment(last_node(snippet)), expected)

source

extract_call

 extract_call (node)
# OK: a call, possibly through an attribute chain, yields the called name
for expr in ['fn()', 'x.fn()', 'x.y.fn()', 'x.y.z.fn()']:
    test_eq(extract_call(last_node(expr)), 'fn')

# KO: anything that is not a single call expression yields None
for expr in ['(fn)', 'fn', '(fn(),fn)', '(x.y.fn(),fn())']:
    test_eq(extract_call(last_node(expr)), None)

source

is_import_statement

 is_import_statement (node)
# Only real import statements count; the word "import" inside a comment does not.
non_import_cells = ['123', 'func(123)', '# test', '# import numpy']
import_cells = ['import numpy', 'from PIL import Image']
for cell in non_import_cells:
    test_eq(is_import_statement(last_node(cell)), False)
for cell in import_cells:
    test_eq(is_import_statement(last_node(cell)), True)

source

is_ast_node

 is_ast_node (x, ref)
# is_ast_node matches a node's type against an explicit list of ast classes.
del_node = last_node('del a')
test_eq(is_ast_node(del_node, [ast.Delete]), True)
test_eq(is_ast_node(del_node, [ast.Assert]), False)
test_eq(is_ast_node(last_node('a==1'), [ast.Assert]), False)

NOTE: I couldn't find any common property that marks statements like `del a` and `assert b==1`. The only approach I've found is to compare the node explicitly against a hard-coded list of these statement types.


source

need_return

 need_return (node)
# The following are a bunch of real use cases we want to test.
# NOTE: ast drops a trailing ";", so need_return cannot tell 'a;' apart from 'a'.
# The semicolon is checked separately on the source string
# (last_statement_has_semicolon), so no 'a;' case is tested here.

# Helper: parse `code` and run need_return on its last statement.
def test_need_return(code): return need_return(last_node(code))

# SHOULD BE TRUE: the last statement is an expression whose value is returned
test_eq( test_need_return('a') , True )
test_eq( test_need_return('func(a)') , True )
test_eq( test_need_return('{1:1,2:2}') , True )
test_eq( test_need_return('a in b') , True )
test_eq( test_need_return('1 if True else None'), True)
# display()/print() calls are plain expressions too. Since the switch from
# auto_display to auto_return, their value is returned like any other, and the
# notebook infrastructure decides how (or whether) to display it.
test_eq( test_need_return('display(a)') , True )
test_eq( test_need_return('print(a)') , True )

# SHOULD BE FALSE: comments and statements that have no value to return
test_eq( test_need_return('# xxx') , False )
test_eq( test_need_return('import xxx') , False )
test_eq( test_need_return('from xxx import yyy') , False )
test_eq( test_need_return('a=1') , False )
test_eq( test_need_return('for a in [1,2,3]: a') , False )
test_eq( test_need_return('del a') , False )
test_eq( test_need_return('a=1; del a') , False )
test_eq( test_need_return('assert a(b)==1') , False )
test_eq( test_need_return('try: a=0\nexcept: a=1') , False )
test_eq( test_need_return('from numpy import array') , False )
test_eq( test_need_return('global a') , False )
test_eq( test_need_return('nonlocal a') , False )

source

end_of_last_line_of_code

 end_of_last_line_of_code (code:str, node)
def do_test_end_of_last_line_of_code(sample_code):
    "Run end_of_last_line_of_code on the last statement of `sample_code`."
    last = last_node(sample_code)
    return end_of_last_line_of_code(sample_code, last)

# (code, text expected after the last statement on its line)
for cell, tail in [
    ('a=1', ''),                           # one line
    ('a=1\na', ''),                        # two lines
    ('a=1\na\n', ''),                      # two lines, trailing newline
    ('a=1 # comment', ' # comment'),       # one line and comment
    ('a=1; # comment', '; # comment'),     # one line and comment
    ('a=1\na; # comment', '; # comment'),  # two lines and comment
]:
    test_eq(do_test_end_of_last_line_of_code(cell), tail)

NOTE: the check for a trailing `;` has to be done on the source string, because `ast` discards semicolons.


source

last_statement_has_semicolon

 last_statement_has_semicolon (code)
# (code, does its last statement end with ";"?)
semicolon_cases = [
    ('a=1\nb=2', False),
    ('a=1\nb=2;', True),
    ('a=1\nb=2\n# test', False),
    ('a=1\nb=2;\n# test', True),
    # with comment in the end
    ('a=1\na # comment', False),
    ('a="# fake comment"\na; # real comment', True),  # "#" inside a string is not a comment
]
for cell, has_semicolon in semicolon_cases:
    test_eq(last_statement_has_semicolon(cell), has_semicolon)

Finally, we need a way to extract the code up to a given `ast` node, so that the `return` statement can be injected in the right place.


source

code_till_node

 code_till_node (code:str, node)
def do_test_code_till_node(sample_code):
    "Run code_till_node on the last statement of `sample_code`."
    last = last_node(sample_code)
    return code_till_node(sample_code, last)

# code -> expected source preceding its last statement
prefix_cases = {
    'a=1\na': 'a=1',                             # two lines
    'a=1;a': 'a=1;',                             # inlined
    'a=1;a\n#': 'a=1;',                          # with post-comment
    'a=1;print(\na)': 'a=1;',                    # multiline instruction
    'print(1,\n2);print(\na)': 'print(1,\n2);',  # all together
}
for cell, prefix in prefix_cases.items():
    test_eq(do_test_code_till_node(cell), prefix)

`auto_return` is the main function of this module. It determines whether the last statement of the given code should be returned. If so, it returns the code with the `return` statement injected; otherwise it returns the code unchanged.

NOTE: the comment `# %%testcell` is appended to mark the line that has been modified.


source

auto_return

 auto_return (code)
# A bare trailing expression gets `return` injected and tagged with "# %%testcell";
# with a trailing ";" the code comes back unchanged.
with_return = 'a=1\nreturn a # %%testcell'
test_eq(auto_return('a=1\na'), with_return)
test_eq(auto_return('a=1\na;'), 'a=1\na;')
print(auto_return('a=1\na'))
a=1
return a # %%testcell
# Semicolon-terminated cell: printed exactly as written, with no return added
semicolon_version = auto_return('a=3\na;')
print(semicolon_version)
a=3
a;