menderbot.ingest.is_path_included()  (grade A)

Complexity
    Conditions: 1

Size
    Total lines: 23
    Code lines: 22

Duplication
    Duplicated lines: 0
    Duplication ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       1
eloc     22
nop      1
dl       0
loc      23
rs       9.352
c        0
b        0
f        0
import glob
import os
from os.path import splitext

from git import Repo
from llama_index.agent.openai import OpenAIAgent  # type: ignore[import-untyped]
from llama_index.core import (
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.llms.mock import MockLLM
from llama_index.core.tools import QueryEngineTool
from llama_index.embeddings.openai import (  # type: ignore[import-untyped]
    OpenAIEmbedding,
)
from llama_index.llms.openai import OpenAI  # type: ignore[import-untyped]

from menderbot.llm import is_test_override

PERSIST_DIR = ".menderbot/ingest"
INDEX_FILE_NAMES = [
    "docstore.json",
    "graph_store.json",
    "index_store.json",
    "vector_store.json",
]

def delete_index(persist_dir: str) -> None:
    if os.path.exists(persist_dir):
        # map() is lazy in Python 3, so the original
        # `map(os.remove, glob.glob(...))` built an iterator that was never
        # consumed and deleted nothing; iterate explicitly instead.
        for path in glob.glob(os.path.join(persist_dir, "*.json")):
            os.remove(path)

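# Aside (illustrative, not part of the original module): the pitfall the fix
# above addresses. map() returns a lazy iterator in Python 3, so its calls
# only happen when the iterator is consumed:
#
#     m = map(print, ["a", "b"])  # nothing is printed yet
#     list(m)                     # prints "a" then "b" as the iterator runs
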
def is_path_included(path: str) -> bool:
    excluded_paths = ["Pipfile.lock"]
    included_extensions = [
        ".md",
        ".java",
        ".c",
        ".cpp",
        ".cc",
        ".py",
        ".txt",
        ".yaml",
        ".yml",
        ".go",
        ".sh",
        ".js",
        ".ts",
        ".tsx",
        ".dart",
        ".test",
        ".bat",
    ]
    _, ext = splitext(path)
    return path not in excluded_paths and ext.lower() in included_extensions

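# Usage sketch (illustrative, not part of the original module): the filter
# keys on the literal path for exclusions and on the lowercased extension
# for inclusions, so for example:
#
#     is_path_included("src/app.py")    # True:  ".py" is an included extension
#     is_path_included("Pipfile.lock")  # False: excluded by name (and ".lock"
#                                       #        is not an included extension)
#     is_path_included("logo.PNG")      # False: ".png" is not included
#     is_path_included("RUN.BAT")       # True:  the extension check is
#                                       #        case-insensitive
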
def ingest_repo(replace=False) -> None:
    if replace:
        delete_index(PERSIST_DIR)

    repo = Repo(".")
    commit = repo.commit("HEAD")

    file_paths = [
        item.path  # type: ignore
        for item in commit.tree.traverse()
        if item.type == "blob" and is_path_included(item.path)  # type: ignore
    ]

    def filename_fn(filename: str) -> dict:
        return {"file_name": filename}

    documents = SimpleDirectoryReader(
        input_files=file_paths,
        recursive=True,
        file_metadata=filename_fn,
    ).load_data()
    index = VectorStoreIndex.from_documents(
        documents,
        show_progress=True,
    )
    index.storage_context.persist(persist_dir=PERSIST_DIR)

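# Usage sketch (assumes a git checkout in the current working directory):
#
#     ingest_repo()              # index the files at HEAD into .menderbot/ingest
#     ingest_repo(replace=True)  # delete the persisted *.json files first,
#                                # then rebuild the index from scratch
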
def index_exists() -> bool:
    return all(
        os.path.exists(os.path.join(PERSIST_DIR, filename))
        for filename in INDEX_FILE_NAMES
    )


def load_index():
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    return load_index_from_storage(storage_context)

def get_llm():
    if is_test_override():
        return MockLLM(max_tokens=5)
    # NB: This uses the default OPENAI_API_KEY env var.
    # TODO: Use the same API key as llm.py.
    return OpenAI(temperature=0, model="gpt-3.5-turbo")


def get_service_context() -> ServiceContext:
    return ServiceContext.from_defaults(
        llm=get_llm(), embed_model=OpenAIEmbedding(model="text-embedding-ada-002")
    )

def get_query_engine():
    if index_exists():
        return load_index().as_query_engine(
            similarity_top_k=5, service_context=get_service_context()
        )
    return VectorStoreIndex.from_documents([]).as_query_engine(
        service_context=get_service_context()
    )


def ask_index(query: str):
    return get_query_engine().query(query)

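# Usage sketch: a one-shot retrieval query against the persisted index
# (get_query_engine() falls back to an empty in-memory index when nothing
# has been ingested yet, so this never raises on a fresh checkout):
#
#     response = ask_index("Where are file extensions filtered?")
#     print(response)
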
def get_chat_engine(verbose=False) -> OpenAIAgent:
    system_prompt = """
You are a Menderbot chat agent discussing a legacy codebase.
"""
    tool_description = """Useful for running a natural language query
about the codebase and getting back a natural language response.
"""
    query_engine_tool = QueryEngineTool.from_defaults(
        query_engine=get_query_engine(), description=tool_description
    )
    service_context = get_service_context()
    llm = service_context.llm
    return OpenAIAgent.from_tools(
        tools=[query_engine_tool], llm=llm, verbose=verbose, system_prompt=system_prompt
    )
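

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Assumes a
    # git repository in the current directory and either OPENAI_API_KEY set
    # in the environment or the test override honored by get_llm().
    if not index_exists():
        ingest_repo()
    print(ask_index("What file types does this project index?"))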