summaryrefslogtreecommitdiff
path: root/py/context.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--py/context.py97
1 files changed, 68 insertions, 29 deletions
diff --git a/py/context.py b/py/context.py
index 15c5e8e..f1eacd6 100644
--- a/py/context.py
+++ b/py/context.py
@@ -24,6 +24,63 @@ def merge_deep(objects):
merge_deep_recursive(result, o)
return result
+def is_deprecated_role_syntax(roles, role):
+ deprecated_sections = [
+ 'options', 'options-complete', 'options-edit', 'options-chat',
+ 'ui', 'ui-complete', 'ui-edit', 'ui-chat',
+ ]
+ for section in deprecated_sections:
+ if f"{role}.{section}" in roles:
+ return True
+ return False
+
+def load_roles_with_deprecated_syntax(roles, role):
+ prompt = dict(roles[role]).get('prompt', '')
+ return {
+ 'role_default': {
+ 'prompt': prompt,
+ 'config': {
+ 'options': dict(roles.get(f"{role}.options", {})),
+ 'ui': dict(roles.get(f"{role}.ui", {})),
+ },
+ },
+ 'role_complete': {
+ 'prompt': prompt,
+ 'config': {
+ 'options': dict(roles.get(f"{role}.options-complete", {})),
+ 'ui': dict(roles.get(f"{role}.ui-complete", {})),
+ },
+ },
+ 'role_edit': {
+ 'prompt': prompt,
+ 'config': {
+ 'options': dict(roles.get(f"{role}.options-edit", {})),
+ 'ui': dict(roles.get(f"{role}.ui-edit", {})),
+ },
+ },
+ 'role_chat': {
+ 'prompt': prompt,
+ 'config': {
+ 'options': dict(roles.get(f"{role}.options-chat", {})),
+ 'ui': dict(roles.get(f"{role}.ui-chat", {})),
+ },
+ },
+ }
+
+def parse_role_section(role):
+ result = {}
+ for key in role.keys():
+ parts = key.split('.')
+ structure = parts[:-1]
+ primitive = parts[-1]
+ obj = result
+ for path in structure:
+ if not path in obj:
+ obj[path] = {}
+ obj = obj[path]
+ obj[primitive] = role.get(key)
+ return result
+
def load_role_config(role):
roles_config_path = os.path.expanduser(vim.eval("g:vim_ai_roles_config_file"))
if not os.path.exists(roles_config_path):
@@ -38,34 +95,14 @@ def load_role_config(role):
if not role in roles:
raise Exception(f"Role `{role}` not found")
- options = roles.get(f"{role}.options", {})
- options_complete = roles.get(f"{role}.options-complete", {})
- options_edit = roles.get(f"{role}.options-edit", {})
- options_chat = roles.get(f"{role}.options-chat", {})
-
- ui = roles.get(f"{role}.ui", {})
- ui_complete = roles.get(f"{role}.ui-complete", {})
- ui_edit = roles.get(f"{role}.ui-edit", {})
- ui_chat = roles.get(f"{role}.ui-chat", {})
+ if is_deprecated_role_syntax(roles, role):
+ return load_roles_with_deprecated_syntax(roles, role)
return {
- 'role': dict(roles[role]),
- 'config_default': {
- 'options': dict(options),
- 'ui': dict(ui),
- },
- 'config_complete': {
- 'options': dict(options_complete),
- 'ui': dict(ui_complete),
- },
- 'config_edit': {
- 'options': dict(options_edit),
- 'ui': dict(ui_edit),
- },
- 'config_chat': {
- 'options': dict(options_chat),
- 'ui': dict(ui_chat),
- },
+ 'role_default': parse_role_section(roles.get(role, {})),
+ 'role_complete': parse_role_section(roles.get(f"{role}.complete", {})),
+ 'role_edit': parse_role_section(roles.get(f"{role}.edit", {})),
+ 'role_chat': parse_role_section(roles.get(f"{role}.chat", {})),
}
def parse_role_names(prompt):
@@ -87,9 +124,11 @@ def parse_prompt_and_role_config(user_instruction, command_type):
last_role = roles[-1]
user_prompt = user_instruction[user_instruction.index(last_role) + len(last_role):].strip() # strip roles
- role_configs = merge_deep([load_role_config(role) for role in roles])
- config = merge_deep([role_configs['config_default'], role_configs['config_' + command_type]])
- role_prompt = role_configs['role'].get('prompt', '')
+ parsed_role = merge_deep([load_role_config(role) for role in roles])
+ role_default = parsed_role['role_default']
+ role_command = parsed_role['role_' + command_type]
+ config = merge_deep([role_default.get('config', {}), role_command.get('config', {})])
+ role_prompt = role_default.get('prompt') or role_command.get('prompt', '')
return user_prompt, role_prompt, config
def make_selection_prompt(user_selection, user_prompt, role_prompt, selection_boundary):