-
Notifications
You must be signed in to change notification settings - Fork 21
/
fix_pydantic.py
106 lines (84 loc) · 3.32 KB
/
fix_pydantic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
""" Using ast transformer to fix issues with automatic pydantic generation
Currently the main issue is with the LangString representation in the pydantic
model, I'm changing it to Dict[str, str].
Perhaps in the future this script will not be required.
"""
import ast
import sys
import astor
class ClassRemover(ast.NodeTransformer):
    """Strip one class, and method calls made directly on it, from an AST.

    Two kinds of node are deleted:

    * every ``class <class_name>: ...`` definition the walk reaches, and
    * every expression statement shaped like ``<class_name>.<attr>(...)``,
      e.g. ``LangString.update_forward_refs()``.

    Classes with a different name are returned untouched; their bodies are
    not descended into.
    """

    def __init__(self, class_name):
        # Name of the class whose definition and direct method calls are removed.
        self.class_name = class_name

    def visit_ClassDef(self, node):
        # Returning None from a NodeTransformer visitor deletes the node.
        return None if node.name == self.class_name else node

    def visit_Expr(self, node):
        call = node.value
        targets_removed_class = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == self.class_name
        )
        if targets_removed_class:
            # Drop the whole statement, e.g. ``LangString.update_forward_refs()``.
            return None
        return self.generic_visit(node)
class TypeReplacer(ast.NodeTransformer):
    """Replace every read of one name with an arbitrary type expression.

    Each Load-context ``ast.Name`` whose id equals ``old_type`` is replaced by
    a freshly parsed copy of the ``new_type`` expression (for example
    ``"Dict[str, str]"`` becomes a real ``Subscript`` node). The resulting
    tree is therefore valid and can be compiled or unparsed. All other
    traversal, including every kind of function-argument and return
    annotation, subscripts and tuples, is handled by ``generic_visit``.

    Store/Del occurrences (e.g. ``LangString = ...``) are left unchanged,
    since an arbitrary expression cannot be an assignment target.
    String (forward-reference) annotations such as ``"LangString"`` are
    ``Constant`` nodes and are not rewritten.

    Args:
        old_type: Identifier to replace.
        new_type: Python expression source to substitute.

    Raises:
        SyntaxError: If ``new_type`` is not a valid Python expression.
    """

    def __init__(self, old_type, new_type):
        self.old_type = old_type
        self.new_type = new_type
        # Parse once up front so an invalid replacement fails immediately
        # instead of partway through a transformation.
        ast.parse(new_type, mode="eval")

    def _build_replacement(self, location):
        """Return a new expression tree for ``new_type`` placed at *location*."""
        # Re-parse per occurrence so no node object is shared between places
        # in the tree.
        replacement = ast.parse(self.new_type, mode="eval").body
        for child in ast.walk(replacement):
            ast.copy_location(child, location)
        return replacement

    def visit_Name(self, node):
        if node.id != self.old_type or not isinstance(node.ctx, ast.Load):
            return node
        return self._build_replacement(node)
def edit_pydantic(input_file, output_file):
    """Rewrite a generated pydantic module so it no longer uses ``LangString``.

    Removes the ``LangString`` class and any ``LangString.<method>(...)``
    statements, replaces every remaining ``LangString`` reference with
    ``Dict[str, str]``, and writes the regenerated source via ``astor``.

    Args:
        input_file: Path of the generated module to read.
        output_file: Path to write the result to; may be the same as
            ``input_file`` for an in-place rewrite.

    Raises:
        SyntaxError: If ``input_file`` does not contain valid Python.
        OSError: If either file cannot be read or written.
    """
    # Python source defaults to UTF-8 (PEP 3120); don't depend on the
    # platform's locale encoding when reading or writing it.
    with open(input_file, "r", encoding="utf-8") as file:
        source = file.read()
    # filename= makes SyntaxError messages name the offending file.
    tree = ast.parse(source, filename=input_file)

    tree = ClassRemover(class_name="LangString").visit(tree)
    tree = TypeReplacer(
        old_type="LangString", new_type="Dict[str, str]"
    ).visit(tree)

    # NOTE(review): the output refers to ``Dict``; this assumes the generated
    # module already imports it from ``typing`` -- confirm for the generator.
    # Generate the full text BEFORE opening the output: opening with "w"
    # truncates it, and with in-place rewriting (output == input) a failure
    # during generation would otherwise destroy the input file.
    new_source = astor.to_source(tree)
    with open(output_file, "w", encoding="utf-8") as file:
        file.write(new_source)
if __name__ == "__main__":
    # Usage: fix_pydantic.py INPUT_FILE [OUTPUT_FILE]
    # Without OUTPUT_FILE the input file is rewritten in place.
    if len(sys.argv) < 2:
        # Previously sys.argv[1] raised a bare IndexError here.
        print(f"usage: {sys.argv[0]} INPUT_FILE [OUTPUT_FILE]", file=sys.stderr)
        sys.exit(2)
    input_file = sys.argv[1]
    output_file = sys.argv[2] if len(sys.argv) > 2 else input_file
    print(
        f"Fixing automatically generated pydantic file {input_file} "
        f"and saving to {output_file}"
    )
    edit_pydantic(input_file, output_file)