diff options
| author | Max Resnick <max@ofmax.li> | 2025-01-28 22:27:47 -0800 |
|---|---|---|
| committer | Max Resnick <max@ofmax.li> | 2025-01-31 22:45:23 -0800 |
| commit | c8162fc43a748a97aab4647be2f5cdf50bc739ea (patch) | |
| tree | 0099688a7a09af76a2ae978790c5e013b59bcb9c /py | |
| parent | 380d5cdd9538c2522dfc8d03a8a261760bb0439a (diff) | |
| download | vim-ai-c8162fc43a748a97aab4647be2f5cdf50bc739ea.tar.gz | |
chore: rebase and fix up conflicts
Diffstat (limited to 'py')
| -rw-r--r-- | py/chat.py | 4 | ||||
| -rw-r--r-- | py/complete.py | 5 | ||||
| -rw-r--r-- | py/types.py | 17 | ||||
| -rw-r--r-- | py/utils.py | 20 |
4 files changed, 44 insertions, 2 deletions
@@ -63,8 +63,10 @@ def run_ai_chat(context): print('Answering...') vim.command("redraw") + provider_class = load_provider(config['provider']) + provider = provider_class(config) + text_chunks = provider.request(messages) - text_chunks = make_chat_text_chunks(messages, options) render_text_chunks(text_chunks) vim.command("normal! a\n\n>>> user\n\n") diff --git a/py/complete.py b/py/complete.py index 8078dea..765d6e0 100644 --- a/py/complete.py +++ b/py/complete.py @@ -42,7 +42,10 @@ def run_ai_completition(context): if prompt: print('Completing...') vim.command("redraw") - text_chunks = engines[engine](prompt) + provider_class = load_provider(config['provider']) + provider = provider_class(config) + messages = parse_chat_messages(f">>> user\n\n{prompt}".strip()) + text_chunks = provider.request(messages) render_text_chunks(text_chunks) clear_echo_message() except BaseException as error: diff --git a/py/types.py b/py/types.py new file mode 100644 index 0000000..819afe5 --- /dev/null +++ b/py/types.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence, Mapping +from typing import TypedDict, Protocol + + +class AIProvider(Protocol): + + def __init__(self, config: Mapping[str, str]) -> None: + pass + + def request(self, messages: Sequence[Message]) -> Generator[str]: + pass + +class Message(TypedDict): + role: str + content: str + type: str + diff --git a/py/utils.py b/py/utils.py index 8cd204f..2147785 100644 --- a/py/utils.py +++ b/py/utils.py @@ -359,3 +359,23 @@ def save_b64_to_file(path, b64_data): f = open(path, "wb") f.write(base64.b64decode(b64_data)) f.close() + +def load_provider(provider): + provider_name, provider_module = provider["name"].split(".") + if provider_name != "openai": + provider_path = os.path.join(f"{plugin_root}", + "..", + f"vim-ai-{provider_name}", + "py", + f"{provider_module}.py") + else: + return openai_request + vim.command(f"py3file {provider_path}") + try: + provider_class = globals()[provider["class"]] + except KeyError as error: + printDebug("[load-provider] provider: {}", error) + raise KeyError(error.message, "provider not found") + return provider_class + + |