Fix the Ollama FIM completion (#848)
* Do not send the system message, only the user message to Ollama FIM

We blindly took the first message, which could have been a system message. The prompt is in the user message, so pass that instead.
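For illustration, a minimal sketch of the selection logic (the request payload is made up; the loop mirrors the change in completion_handler.py below):

```python
# Hypothetical FIM request: clients may prepend a system message, so
# blindly taking messages[0] would grab the wrong content.
request = {
    "model": "qwen2.5-coder:0.5b",
    "messages": [
        {"role": "system", "content": "You are an autocompletion engine."},
        {"role": "user", "content": "<|fim_prefix|>def add(a, b):<|fim_suffix|><|fim_middle|>"},
    ],
}

# Walk backwards and take the last user message as the FIM prompt.
prompt = ""
for i in reversed(range(len(request["messages"]))):
    if request["messages"][i]["role"] == "user":
        prompt = request["messages"][i]["content"]
        break
if not prompt:
    raise ValueError("No user message found in FIM request")

print(prompt)  # the FIM prompt, not the system message
```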

* Pass the suffix parameter if present to the FIM generate call

Some models don't use FIM markers in the code but rely on the top-level suffix attribute instead. If we don't pass it along, the FIM completion won't succeed.
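As a sketch of what that looks like against the Ollama API (assuming the `ollama` Python client, which is what the handler below calls; the model and code snippets are illustrative):

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    # For models that take the code after the cursor via the top-level
    # suffix parameter rather than inline FIM markers in the prompt.
    response = await client.generate(
        model="qwen2.5-coder:0.5b",
        prompt="def add(a, b):\n    return ",
        suffix="\n\nprint(add(1, 2))",
        stream=False,
    )
    print(response["response"])  # the completion for the gap


asyncio.run(main())
```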

* Run make format

Signed-off-by: Radoslav Dimitrov <[email protected]>

* Fix the unit tests for ollama FIM

Signed-off-by: Radoslav Dimitrov <[email protected]>

* Pass along the `raw` parameter

The `raw` parameter tells the LLM to skip natural language and reply only in the format of the message. We need to pass it to the generate call, or else we might get garbage back to the client.
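A sketch of the difference (again assuming the `ollama` Python client; the prompt is illustrative):

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    # raw=True makes Ollama skip its prompt template, so the FIM control
    # tokens below reach the model verbatim; with the default raw=False
    # they get wrapped in chat formatting and the reply can be garbage.
    response = await client.generate(
        model="qwen2.5-coder:0.5b",
        prompt="<|fim_prefix|>def add(a, b):\n    <|fim_suffix|>\n<|fim_middle|>",
        raw=True,
        stream=False,
    )
    print(response["response"])


asyncio.run(main())
```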

* Print the full reply as a debug message in integration tests

This is useful when debugging the tests.

* Adjust the Ollama FIM testcase to match the model we are using

The FIM format didn't seem to match the model, so I replaced it with a dump of a FIM message received from Continue.
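For context, a sketch of how Qwen2.5-Coder-style FIM prompts are assembled (the helper is illustrative, not part of the test suite):

```python
def qwen_fim_prompt(prefix: str, suffix: str) -> str:
    """Ask the model to fill in the code between prefix and suffix."""
    return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>"


print(qwen_fim_prompt("def call_api():\n    ", "\nresponse = call_api()"))
```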

---------

Signed-off-by: Radoslav Dimitrov <[email protected]>
Co-authored-by: Radoslav Dimitrov <[email protected]>
2 people authored and lukehinds committed Jan 31, 2025
1 parent 57259da commit 0ebb02c
Showing 4 changed files with 40 additions and 41 deletions.
16 changes: 14 additions & 2 deletions src/codegate/providers/ollama/completion_handler.py
@@ -89,9 +89,21 @@ async def execute_completion(
         """Stream response directly from Ollama API."""
         self.base_tool = base_tool
         if is_fim_request:
-            prompt = request["messages"][0].get("content", "")
+            prompt = ""
+            for i in reversed(range(len(request["messages"]))):
+                if request["messages"][i]["role"] == "user":
+                    prompt = request["messages"][i]["content"]  # type: ignore
+                    break
+            if not prompt:
+                raise ValueError("No user message found in FIM request")
+
             response = await self.client.generate(
-                model=request["model"], prompt=prompt, stream=stream, options=request["options"]  # type: ignore
+                model=request["model"],
+                prompt=prompt,
+                raw=request.get("raw", False),
+                suffix=request.get("suffix", ""),
+                stream=stream,
+                options=request["options"],  # type: ignore
             )
         else:
             response = await self.client.chat(
1 change: 1 addition & 0 deletions tests/integration/integration_tests.py
@@ -139,6 +139,7 @@ async def run_test(self, test: dict, test_headers: dict) -> bool:

         try:
             parsed_response = self.parse_response_message(response, streaming=streaming)
+            logger.debug(f"Response message: {parsed_response}")

             # Load appropriate checks for this test
             checks = CheckLoader.load(test)
62 changes: 23 additions & 39 deletions tests/integration/testcases.yaml
@@ -297,48 +297,32 @@ testcases:
     url: http://127.0.0.1:8989/ollama/api/generate
     data: |
       {
-        "model": "qwen2.5-coder:0.5b",
-        "max_tokens": 4096,
-        "temperature": 0,
-        "stream": true,
-        "stop": [
-          "<fim_prefix>",
-          "<fim_suffix>",
-          "<fim_middle>",
-          "<file_sep>",
-          "</fim_middle>",
-          "</code>",
-          "/src/",
-          "#- coding: utf-8",
-          "```",
-          ""
-        ],
+        "model": "qwen2.5-coder:0.5b",
+        "raw": true,
+        "options": {
+          "temperature": 0.01,
+          "num_predict": 4096,
+          "stop": [
+            "<|endoftext|>",
+            "<|fim_prefix|>",
+            "<|fim_middle|>",
+            "<|fim_suffix|>",
+            "<|fim_pad|>",
+            "<|repo_name|>",
+            "<|file_sep|>",
+            "<|im_start|>",
+            "<|im_end|>",
+            "/src/",
+            "#- coding: utf-8",
+            "```"
+          ],
+          "num_ctx": 8096
+        },
+        "prompt":"<|fim_prefix|>\n# codegate/test.py\nimport invokehttp\nimport requests\n\nkey = \"mysecret-key\"\n\ndef call_api():\n <|fim_suffix|>\n\n\ndata = {'key1': 'test1', 'key2': 'test2'}\nresponse = call_api('http://localhost:8080', method='post', data='data')\n<|fim_middle|>"
       }
     likes: |
       ```python
       import invokehttp
       import requests
       key = "mysecret-key"
       def call_api(url, method='get', data=None):
           headers = {
               'Authorization': f'Bearer {key}'
           }
           if method == 'get':
               response = requests.get(url, headers=headers)
           elif method == 'post':
               response = requests.post(url, headers=headers, json=data)
           else:
               raise ValueError("Unsupported HTTP method")
           return response
       data = {'key1': 'test1', 'key2': 'test2'}
       response = call_api('http://localhost:8080', method='post', data=data)
       print(response.status_code)
       print(response.json())
       if __name__ == '__main__':
           invokehttp.run(call_api)
       ```
2 changes: 2 additions & 0 deletions tests/providers/ollama/test_ollama_completion_handler.py
@@ -40,6 +40,8 @@ async def test_execute_completion_is_fim_request(handler, chat_request):
         prompt="FIM prompt",
         stream=False,
         options=chat_request["options"],
+        suffix="",
+        raw=False,
     )
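For reference, a minimal sketch of the assertion pattern used above, runnable on its own (the mock stands in for the handler's Ollama client; the values are illustrative):

```python
import asyncio
from unittest.mock import AsyncMock


async def main() -> None:
    client = AsyncMock()
    # Stand-in for the handler's FIM branch: forward raw and suffix too.
    await client.generate(
        model="qwen2.5-coder:0.5b",
        prompt="FIM prompt",
        raw=False,
        suffix="",
        stream=False,
        options={"temperature": 0.01},
    )
    # The updated unit test asserts the new keyword arguments are passed.
    client.generate.assert_called_once_with(
        model="qwen2.5-coder:0.5b",
        prompt="FIM prompt",
        raw=False,
        suffix="",
        stream=False,
        options={"temperature": 0.01},
    )


asyncio.run(main())
```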
