# 读取文件夹中 Python 代码的第三方依赖并生成 requirements.txt

# Posted by Shallow Dreamer on August 7, 2025
import argparse
import ast
import os
import sys
from importlib.metadata import PackageNotFoundError, distribution, packages_distributions

from stdlib_list import stdlib_list


# Directory names skipped by default: VCS metadata, caches and common
# virtualenv / install folders, whose .py files are not the project's own code.
_DEFAULT_EXCLUDED_DIRS = frozenset({
    ".git", ".hg", ".svn", ".tox", ".nox", ".venv", "venv",
    "__pycache__", "site-packages", "node_modules",
})


def find_py_files(directory, exclude_dirs=None):
    """Recursively collect the paths of ``.py`` files under *directory*.

    The walk does not descend into two kinds of sub-directory:
    those whose name is in *exclude_dirs*, and those that contain a
    ``pyvenv.cfg`` file (the marker of a virtual environment). Without this,
    the sources of installed packages inside a project-local venv would be
    scanned and their imports reported as project requirements.

    Args:
        directory: Root folder to walk. The root itself is never skipped.
        exclude_dirs: Directory names to skip. Defaults to
            ``_DEFAULT_EXCLUDED_DIRS``. Pass an empty collection to disable
            name-based skipping; venv detection via ``pyvenv.cfg`` still applies.

    Returns:
        List of paths built with ``os.path.join(root, filename)``, in a
        deterministic (sorted) walk order.
    """
    if exclude_dirs is None:
        exclude_dirs = _DEFAULT_EXCLUDED_DIRS

    py_files = []
    for root, dirs, files in os.walk(directory):
        # Rewriting ``dirs`` in place prunes the top-down walk: os.walk will
        # not enter the removed sub-directories.
        dirs[:] = sorted(
            d for d in dirs
            if d not in exclude_dirs
            and not os.path.isfile(os.path.join(root, d, "pyvenv.cfg"))
        )
        py_files.extend(
            os.path.join(root, name) for name in sorted(files) if name.endswith(".py")
        )
    return py_files


def extract_imports(file_path):
    """Return the top-level names of the absolute imports in a Python source file.

    Both ``import a.b`` and ``from a.b import c`` contribute ``"a"``. Relative
    imports (``from . import x``, ``from .pkg import y``) are skipped. They
    refer to the file's own package, not to an installed distribution.

    The file is read as bytes so that ``ast.parse`` honours a UTF-8 BOM or a
    PEP 263 coding declaration (e.g. ``# -*- coding: gbk -*-``), instead of
    forcing UTF-8.

    Args:
        file_path: Path of the ``.py`` file to parse.

    Returns:
        Set of top-level module names. If the file cannot be read or parsed,
        a warning is printed to stderr and an empty set is returned.
    """
    try:
        with open(file_path, "rb") as f:
            tree = ast.parse(f.read(), filename=file_path)
    except (OSError, SyntaxError, ValueError, RecursionError) as exc:
        # Best-effort: one bad file must not abort the whole scan, but report
        # which file was skipped instead of failing silently.
        print(f"⚠️ 跳过无法解析的文件 {file_path}: {exc}", file=sys.stderr)
        return set()

    imports = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imports.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # level > 0 means a relative import (leading dots); skip those.
            imports.add(node.module.split(".")[0])
    return imports


def is_third_party(module_name, stdlib_modules, installed_dists):
    """Decide whether *module_name* counts as an installed third-party module.

    A name that belongs to *stdlib_modules* is never third-party; *installed_dists*
    is only consulted (via ``in``) for names outside the standard library.

    Returns:
        ``True`` if the name is not stdlib and is contained in *installed_dists*.
    """
    return module_name not in stdlib_modules and module_name in installed_dists


def map_module_to_distribution(modules):
    dist_map = packages_distributions()
    result = {}
    for m in modules:
        dist_names = dist_map.get(m)
        if dist_names:
            result[m] = dist_names[0]  # 假设第一个是主 distribution 名
    return result


def get_distribution_versions(distributions):
    """Look up the installed version of each distribution name.

    Args:
        distributions: Iterable of distribution names. Duplicates are looked up once.

    Returns:
        Dict mapping each installed distribution name to its version. Names that
        are not installed, or whose metadata cannot be read, are left out.
    """
    versions = {}
    for dist in set(distributions):
        try:
            versions[dist] = distribution(dist).version
        except (PackageNotFoundError, KeyError, ValueError):
            # Best-effort: skip missing or malformed distributions. Unlike the
            # previous bare ``except:``, this no longer swallows
            # KeyboardInterrupt / SystemExit or unrelated programming errors.
            continue
    return versions


def _local_top_level_names(directory, py_files):
    """Return the top-level module/package names that *py_files* define under *directory*.

    Only the first path component relative to *directory* counts: ``pkg`` for
    ``pkg/sub/mod.py``, ``tool`` for ``tool.py``. These are the names a script
    run from *directory* (presumably the project root) would import directly.
    """
    names = set()
    for path in py_files:
        first = os.path.relpath(path, directory).split(os.sep)[0]
        names.add(first[:-3] if first.endswith(".py") else first)
    return names


def generate_requirements(directory, output_file="requirements.txt", python_version="3.9"):
    """Scan *directory* for imports and write the installed third-party ones to *output_file*.

    Each imported top-level module is kept only if it is neither part of the
    standard library for *python_version* (per ``stdlib_list``) nor a module
    defined inside *directory* itself. Kept modules are mapped to their installed
    distribution. One line per distribution is written, pinned as
    ``name==version`` when the version is known. Imports that no installed
    distribution provides are reported on stderr, because they cannot be written.

    Args:
        directory: Root folder to scan recursively for ``.py`` files.
        output_file: Path of the requirements file to (over)write.
        python_version: Version string passed to ``stdlib_list`` to decide which
            modules are standard library.

    Raises:
        NotADirectoryError: If *directory* is not an existing directory.
    """
    if not os.path.isdir(directory):
        # os.walk silently yields nothing for a bad path, which would overwrite
        # output_file with an empty requirements list.
        raise NotADirectoryError(f"不是有效的目录: {directory}")

    stdlib_modules = set(stdlib_list(python_version))
    py_files = find_py_files(directory)

    all_imports = set()
    for py_file in py_files:
        all_imports |= extract_imports(py_file)

    # Filter before mapping. packages_distributions() may still map some stdlib
    # names (e.g. when a backport is installed), and a project module such as
    # `utils` could otherwise be attributed to an unrelated installed distribution.
    local_modules = _local_top_level_names(directory, py_files)
    candidates = all_imports - stdlib_modules - local_modules

    module_to_dist = map_module_to_distribution(candidates)
    third_party_dists = set(module_to_dist.values())
    versions = get_distribution_versions(third_party_dists)

    with open(output_file, "w", encoding="utf-8") as f:
        for dist in sorted(third_party_dists):
            version = versions.get(dist)
            f.write(f"{dist}=={version}\n" if version else f"{dist}\n")

    # Imports with no installed provider cannot be pinned; tell the user the
    # generated file may be incomplete instead of dropping them silently.
    unresolved = sorted(candidates - module_to_dist.keys())
    if unresolved:
        print(f"⚠️ 以下模块未找到已安装的对应包,已跳过: {', '.join(unresolved)}", file=sys.stderr)

    print(f"✅ {output_file} 已生成,共 {len(third_party_dists)} 个包。")


if __name__ == "__main__":
    # Command-line entry point: directory to scan, plus optional output path and
    # the Python version used to decide which imports are standard library.
    cli = argparse.ArgumentParser(description="根据目录中 import 的模块生成 requirements.txt")
    cli.add_argument("directory", help="要扫描的目录")
    cli.add_argument("--output", default="requirements.txt", help="输出文件名")
    cli.add_argument("--python-version", default="3.9", help="Python 版本(用于标准库过滤)")

    opts = cli.parse_args()
    generate_requirements(
        opts.directory,
        output_file=opts.output,
        python_version=opts.python_version,
    )
# 用法示例: python test_reqs.py E:\Program\projects_python\lock_test\test_reqs --output custom_requirements.txt