summaryrefslogtreecommitdiff
path: root/py/chat.py
blob: 810fc468795078b934d016a6df9fcd442744d988 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import vim

# Module-level marker set to True once this script has been executed.
# NOTE(review): presumably read by the plugin's Vim script side to avoid
# loading this file more than once — confirm against the caller.
chat_py_imported = True

def run_ai_chat(context):
    """Run a single chat round inside the current Vim buffer.

    The buffer is first prepared so that it ends with a ``>>> user`` section
    (optionally preceded by a ``[chat-options]`` section), the optional prompt
    is typed in, and the buffer text is parsed into messages. Unless the last
    message part is an empty text, an ``<<< assistant`` section is opened, the
    configured provider is asked for a reply, the streamed chunks are rendered
    and a fresh ``>>> user`` section is appended.

    Args:
        context: mapping providing ``'prompt'`` (text to insert; may be falsy)
            and ``'config'`` (passed to ``make_config``, whose result must
            offer the ``'options'``, ``'ui'`` and ``'provider'`` keys).
    """
    prompt = context['prompt']
    config = make_config(context['config'])
    config_options = config['options']
    config_ui = config['ui']

    def read_buffer_text():
        # Whole buffer joined by newlines; Vim's trim() strips the outer whitespace.
        return vim.eval('trim(join(getline(1, "$"), "\n"))')

    def prepare_chat_window():
        buffer_lines = vim.eval('getline(1, "$")')
        if '>>> user' not in buffer_lines:
            # No user role header yet: the existing buffer content becomes
            # the user prompt, so the headers are written above it.
            vim.command("normal! gg")
            with_options = config_ui['populate_options'] == '1'
            if with_options:
                vim.command("normal! O[chat-options]")
                vim.command("normal! o")
                for name, setting in config_options.items():
                    # initial_prompt is a sequence of lines; keep it on one
                    # line using a literal backslash-n separator.
                    rendered = "\\n".join(setting) if name == 'initial_prompt' else setting
                    vim.command("normal! i" + name + "=" + rendered + "\n")
                vim.command("normal! o")
            else:
                vim.command("normal! O")
            vim.command("normal! i>>> user\n")

        vim.command("normal! G")
        vim_break_undo_sequence()
        vim.command("redraw")

        buffer_text = read_buffer_text()
        role_headers = re.findall(
            r'(^>>> user|^>>> system|^<<< assistant).*',
            buffer_text,
            flags=re.MULTILINE,
        )
        last_header = role_headers[-1]
        if not last_header.startswith(">>> user"):
            # The buffer does not end with a user section, most likely because
            # an earlier completion was cancelled; open a new user section.
            vim.command("normal! o")
            vim.command("normal! i\n>>> user\n\n")

        if prompt:
            vim.command("normal! i" + prompt)
            vim_break_undo_sequence()
            vim.command("redraw")

    prepare_chat_window()

    # Options found in the buffer header take precedence over configured ones.
    header_options = parse_chat_header_options()
    options = {**config_options, **header_options}

    initial_lines = options.get('initial_prompt', [])
    initial_messages = parse_chat_messages('\n'.join(initial_lines))

    chat_text = read_buffer_text()
    print_debug("[chat] text:\n" + chat_text)
    messages = initial_messages + parse_chat_messages(chat_text)

    try:
        final_part = messages[-1]["content"][-1]
        if final_part['type'] == 'text' and not final_part['text']:
            # The last message part is an empty text: nothing to answer.
            return

        vim.command("normal! Go\n<<< assistant\n\n")
        vim.command("redraw")

        print('Answering...')
        vim.command("redraw")
        provider_cls = load_provider(config['provider'])
        chunks = provider_cls(config).request(messages)

        render_text_chunks(chunks)

        vim.command("normal! a\n\n>>> user\n\n")
        vim.command("redraw")
        clear_echo_message()
    except BaseException as error:
        # BaseException (not Exception) so that every failure reaches the
        # handler. NOTE(review): presumably this is meant to include user
        # interrupts — confirm in handle_completion_error.
        handle_completion_error(error)
        print_debug("[chat] error: {}", traceback.format_exc())