import os
import ast
import argparse
from stdlib_list import stdlib_list
from importlib.metadata import packages_distributions, distribution
def find_py_files(directory):
    """Return the paths of all ``.py`` files found under *directory*.

    The tree is traversed with ``os.walk``; every file whose name ends in
    ``.py`` is returned as ``os.path.join(<containing dir>, <file name>)``,
    in the order the walk yields them.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
        if filename.endswith(".py")
    ]
def extract_imports(file_path):
    """Return the top-level module names imported absolutely by a Python file.

    The file is read as UTF-8 and parsed with ``ast``; for every ``import a.b``
    and ``from a.b import c`` statement the first dotted component (``a``) is
    collected.

    Relative imports (``from . import x``, ``from .mod import y``) are
    skipped: they refer to modules inside the scanned project, so their names
    must not be treated as installable top-level packages.

    Args:
        file_path: Path of the Python source file to inspect.

    Returns:
        A set of top-level module names. The set is empty when the file
        cannot be read, decoded, or parsed.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            tree = ast.parse(f.read(), filename=file_path)
    except (OSError, SyntaxError, ValueError, RecursionError):
        # Best-effort scan: unreadable, undecodable (UnicodeDecodeError is a
        # ValueError) or unparsable files simply contribute no imports.
        return set()

    imports = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports.add(alias.name.split('.')[0])
        elif isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import; node.module is None for
            # a bare ``from . import x``.
            if node.level == 0 and node.module:
                imports.add(node.module.split('.')[0])
    return imports
def is_third_party(module_name, stdlib_modules, installed_dists):
    """Tell whether *module_name* is an installed, non-standard-library module.

    Returns ``True`` only when the name is absent from *stdlib_modules* and
    present in *installed_dists*; both arguments just need to support ``in``.
    """
    return module_name not in stdlib_modules and module_name in installed_dists
def map_module_to_distribution(modules):
dist_map = packages_distributions()
result = {}
for m in modules:
dist_names = dist_map.get(m)
if dist_names:
result[m] = dist_names[0] # 假设第一个是主 distribution 名
return result
def get_distribution_versions(distributions):
    """Look up the installed version of each distribution name.

    Args:
        distributions: Iterable of distribution names; duplicates are
            collapsed before lookup.

    Returns:
        A dict of ``{distribution_name: version}``. Names whose metadata
        cannot be found are omitted rather than raising.
    """
    # Imported locally so this function does not depend on a change to the
    # module-level import list.
    from importlib.metadata import PackageNotFoundError

    versions = {}
    for dist in set(distributions):
        try:
            versions[dist] = distribution(dist).version
        except (PackageNotFoundError, ValueError):
            # Not installed (or an empty/invalid name): skip it. Catching
            # only these keeps KeyboardInterrupt/SystemExit propagating,
            # which the previous bare ``except`` swallowed.
            continue
    return versions
def generate_requirements(directory, output_file="requirements.txt", python_version="3.9"):
    """Scan *directory* for imports and write a requirements file.

    Every ``.py`` file under *directory* is parsed for imports, standard
    library modules are filtered out, the remaining names are resolved to
    installed distributions, and one line per distribution is written to
    *output_file* (``name==version`` when a version is known, otherwise just
    ``name``). A summary line is printed when done.

    Args:
        directory: Root directory to scan.
        output_file: Path of the requirements file to write (overwritten).
        python_version: Version string passed to ``stdlib_list`` to obtain
            the set of standard-library module names to exclude.
    """
    stdlib_modules = set(stdlib_list(python_version))

    all_imports = set()
    for py_file in find_py_files(directory):
        all_imports |= extract_imports(py_file)

    # Drop standard-library names before resolving distributions; otherwise
    # an installed backport that shadows a stdlib module name would be
    # listed as a requirement. (The set was previously built but never used.)
    third_party_imports = all_imports - stdlib_modules

    module_to_dist = map_module_to_distribution(third_party_imports)
    third_party_dists = set(module_to_dist.values())
    versions = get_distribution_versions(third_party_dists)

    with open(output_file, "w", encoding="utf-8") as f:
        for dist in sorted(third_party_dists):
            version = versions.get(dist)
            if version:
                f.write(f"{dist}=={version}\n")
            else:
                f.write(f"{dist}\n")

    # Report the file actually written rather than a hard-coded name.
    print(f"✅ {output_file} 已生成,共 {len(third_party_dists)} 个包。")
if __name__ == "__main__":
    # Command-line entry point: parse the target directory and options, then
    # generate the requirements file.
    cli = argparse.ArgumentParser(description="根据目录中 import 的模块生成 requirements.txt")
    cli.add_argument("directory", help="要扫描的目录")
    cli.add_argument("--output", default="requirements.txt", help="输出文件名")
    cli.add_argument("--python-version", default="3.9", help="Python 版本(用于标准库过滤)")
    options = cli.parse_args()
    generate_requirements(
        directory=options.directory,
        output_file=options.output,
        python_version=options.python_version,
    )
# Example usage:
#   python test_reqs.py E:\Program\projects_python\lock_test\test_reqs --output custom_requirements.txt