Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions scripts/microgenerator/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CodeAnalyzer(ast.NodeVisitor):
"""

def __init__(self):
self.structure: List[Dict[str, Any]] = []
self.analyzed_classes: List[Dict[str, Any]] = []
self.imports: set[str] = set()
self.types: set[str] = set()
self._current_class_info: Dict[str, Any] | None = None
Expand Down Expand Up @@ -106,13 +106,19 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:
if type_str:
self.types.add(type_str)
elif isinstance(node, ast.Subscript):
self._collect_types_from_node(node.value)
# Add the base type of the subscript (e.g., "List", "Dict")
if isinstance(node.value, ast.Name):
self.types.add(node.value.id)
self._collect_types_from_node(node.value) # Recurse on value just in case
self._collect_types_from_node(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for elt in node.elts:
self._collect_types_from_node(elt)
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
self.types.add(node.value)
elif isinstance(node, ast.Constant):
if isinstance(node.value, str): # Forward references
self.types.add(node.value)
elif node.value is None: # None type
self.types.add("None")
elif isinstance(node, ast.BinOp) and isinstance(
node.op, ast.BitOr
): # For | union type
Expand Down Expand Up @@ -164,7 +170,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
type_str = self._get_type_str(item.annotation)
class_info["attributes"].append({"name": attr_name, "type": type_str})

self.structure.append(class_info)
self.analyzed_classes.append(class_info)
self._current_class_info = class_info
self._depth += 1
self.generic_visit(node)
Expand Down Expand Up @@ -260,6 +266,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
# directly within the class body, not inside a method.
elif isinstance(target, ast.Name) and not self._is_in_method:
self._add_attribute(target.id, self._get_type_str(node.annotation))
self._collect_types_from_node(node.annotation)
self.generic_visit(node)


Expand All @@ -280,7 +287,7 @@ def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
tree = ast.parse(code)
analyzer = CodeAnalyzer()
analyzer.visit(tree)
return analyzer.structure, analyzer.imports, analyzer.types
return analyzer.analyzed_classes, analyzer.imports, analyzer.types


def parse_file(file_path: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
Expand Down Expand Up @@ -332,10 +339,10 @@ def list_code_objects(
all_class_keys = []

def process_structure(
structure: List[Dict[str, Any]], file_name: str | None = None
analyzed_classes: List[Dict[str, Any]], file_name: str | None = None
):
"""Populates the results dictionary from the parsed AST structure."""
for class_info in structure:
for class_info in analyzed_classes:
key = class_info["class_name"]
if file_name:
key = f"{key} (in {file_name})"
Expand All @@ -361,13 +368,13 @@ def process_structure(

# Determine if the path is a file or directory and process accordingly
if os.path.isfile(path) and path.endswith(".py"):
structure, _, _ = parse_file(path)
process_structure(structure)
analyzed_classes, _, _ = parse_file(path)
process_structure(analyzed_classes)
elif os.path.isdir(path):
# This assumes `utils.walk_codebase` is defined elsewhere.
for file_path in utils.walk_codebase(path):
structure, _, _ = parse_file(file_path)
process_structure(structure, file_name=os.path.basename(file_path))
analyzed_classes, _, _ = parse_file(file_path)
process_structure(analyzed_classes, file_name=os.path.basename(file_path))

# Return the data in the desired format based on the flags
if not show_methods and not show_attributes:
Expand Down Expand Up @@ -419,11 +426,11 @@ def _build_request_arg_schema(
module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, ".")

try:
structure, _, _ = parse_file(file_path)
if not structure:
analyzed_classes, _, _ = parse_file(file_path)
if not analyzed_classes:
continue

for class_info in structure:
for class_info in analyzed_classes:
class_name = class_info.get("class_name", "Unknown")
if class_name.endswith("Request"):
full_class_name = f"{module_name}.{class_name}"
Expand Down Expand Up @@ -451,11 +458,11 @@ def _process_service_clients(
if "/services/" not in file_path:
continue

structure, imports, types = parse_file(file_path)
analyzed_classes, imports, types = parse_file(file_path)
all_imports.update(imports)
all_types.update(types)

for class_info in structure:
for class_info in analyzed_classes:
class_name = class_info["class_name"]
if not _should_include_class(class_name, class_filters):
continue
Expand Down
6 changes: 2 additions & 4 deletions scripts/microgenerator/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from functools import wraps
import pathlib
import os
import nox
import time

Expand All @@ -26,7 +25,7 @@
BLACK_VERSION = "black==23.7.0"
BLACK_PATHS = (".",)

DEFAULT_PYTHON_VERSION = "3.9"
DEFAULT_PYTHON_VERSION = "3.13"
UNIT_TEST_PYTHON_VERSIONS = ["3.9", "3.11", "3.12", "3.13"]
CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()

Expand Down Expand Up @@ -190,9 +189,8 @@ def lint(session):
session.install("flake8", BLACK_VERSION)
session.install("-e", ".")
session.run("python", "-m", "pip", "freeze")
session.run("flake8", os.path.join("scripts"))
session.run("flake8", ".")
session.run("flake8", "tests")
session.run("flake8", "benchmark")
session.run("black", "--check", *BLACK_PATHS)


Expand Down
45 changes: 24 additions & 21 deletions scripts/microgenerator/tests/unit/test_generate_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_import_extraction(self, code_snippet, expected_imports):

class TestCodeAnalyzerAttributes:
@pytest.mark.parametrize(
"code_snippet, expected_structure",
"code_snippet, expected_analyzed_classes",
[
pytest.param(
"""
Expand Down Expand Up @@ -243,22 +243,24 @@ def __init__(self):
),
],
)
def test_attribute_extraction(self, code_snippet: str, expected_structure: list):
def test_attribute_extraction(
self, code_snippet: str, expected_analyzed_classes: list
):
"""Tests the extraction of class and instance attributes."""
analyzer = CodeAnalyzer()
tree = ast.parse(code_snippet)
analyzer.visit(tree)

extracted = analyzer.structure
extracted = analyzer.analyzed_classes
# Normalize attributes for order-independent comparison
for item in extracted:
if "attributes" in item:
item["attributes"].sort(key=lambda x: x["name"])
for item in expected_structure:
for item in expected_analyzed_classes:
if "attributes" in item:
item["attributes"].sort(key=lambda x: x["name"])

assert extracted == expected_structure
assert extracted == expected_analyzed_classes


# --- Mock Types ---
Expand All @@ -284,8 +286,8 @@ class MyClass:
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 1
assert analyzer.structure[0]["class_name"] == "MyClass"
assert len(analyzer.analyzed_classes) == 1
assert analyzer.analyzed_classes[0]["class_name"] == "MyClass"


def test_codeanalyzer_finds_multiple_classes():
Expand All @@ -302,8 +304,8 @@ class ClassB:
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 2
class_names = sorted([c["class_name"] for c in analyzer.structure])
assert len(analyzer.analyzed_classes) == 2
class_names = sorted([c["class_name"] for c in analyzer.analyzed_classes])
assert class_names == ["ClassA", "ClassB"]


Expand All @@ -318,9 +320,9 @@ def my_method(self):
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 1
assert len(analyzer.structure[0]["methods"]) == 1
assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method"
assert len(analyzer.analyzed_classes) == 1
assert len(analyzer.analyzed_classes[0]["methods"]) == 1
assert analyzer.analyzed_classes[0]["methods"][0]["method_name"] == "my_method"


def test_codeanalyzer_finds_multiple_methods():
Expand All @@ -337,8 +339,8 @@ def method_b(self):
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 1
method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]])
assert len(analyzer.analyzed_classes) == 1
method_names = sorted([m["method_name"] for m in analyzer.analyzed_classes[0]["methods"]])
assert method_names == ["method_a", "method_b"]


Expand All @@ -352,7 +354,7 @@ def top_level_function():
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 0
assert len(analyzer.analyzed_classes) == 0


def test_codeanalyzer_class_with_no_methods():
Expand All @@ -365,9 +367,9 @@ class MyClass:
analyzer = CodeAnalyzer()
tree = ast.parse(code)
analyzer.visit(tree)
assert len(analyzer.structure) == 1
assert analyzer.structure[0]["class_name"] == "MyClass"
assert len(analyzer.structure[0]["methods"]) == 0
assert len(analyzer.analyzed_classes) == 1
assert analyzer.analyzed_classes[0]["class_name"] == "MyClass"
assert len(analyzer.analyzed_classes[0]["methods"]) == 0


# --- Test Data for Parameterization ---
Expand Down Expand Up @@ -487,10 +489,10 @@ class TestCodeAnalyzerArgsReturns:
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES
)
def test_type_extraction(self, code_snippet, expected_args, expected_return):
structure, imports, types = parse_code(code_snippet)
analyzed_classes, imports, types = parse_code(code_snippet)

assert len(structure) == 1, "Should parse one class"
class_info = structure[0]
assert len(analyzed_classes) == 1, "Should parse one class"
class_info = analyzed_classes[0]
assert class_info["class_name"] == "TestClass"

assert len(class_info["methods"]) == 1, "Should find one method"
Expand All @@ -506,3 +508,4 @@ def test_type_extraction(self, code_snippet, expected_args, expected_return):

assert extracted_args == expected_args
assert method_info.get("return_type") == expected_return

Loading
Loading