From b3ae2625ac25b32f6315b71efb643b8024ffef8a Mon Sep 17 00:00:00 2001
From: Patrick_Pluto
Date: Mon, 16 Sep 2024 11:44:35 +0200
Subject: [PATCH] Added online model.

---
 .gitignore        |  1 +
 py/.idea/misc.xml |  2 +-
 py/.idea/py.iml   |  2 +-
 py/api.py         | 24 ++++++++++++++++++------
 py/venv.sh        |  5 +++++
 5 files changed, 26 insertions(+), 8 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 py/venv.sh

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f7275bb
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+venv/
diff --git a/py/.idea/misc.xml b/py/.idea/misc.xml
index db8786c..f5d7485 100644
--- a/py/.idea/misc.xml
+++ b/py/.idea/misc.xml
@@ -3,5 +3,5 @@
-
+
\ No newline at end of file
diff --git a/py/.idea/py.iml b/py/.idea/py.iml
index d0876a7..451946f 100644
--- a/py/.idea/py.iml
+++ b/py/.idea/py.iml
@@ -2,7 +2,7 @@
-
+
\ No newline at end of file
diff --git a/py/api.py b/py/api.py
index f1df05c..9e6b40d 100644
--- a/py/api.py
+++ b/py/api.py
@@ -1,10 +1,20 @@
 import requests
 import json
-
+from transformers import AutoTokenizer, LlamaForCausalLM
 
 class API:
     @staticmethod
-    def process_text(prompt, model):
+    def process_text_transformers(prompt, model):
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        model = LlamaForCausalLM.from_pretrained(model)
+
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        generate_ids = model.generate(inputs.input_ids, max_length=30)
+        return tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    @staticmethod
+    def process_text_local(prompt, model):
         ollama_url = "http://localhost:11434"
 
         response = requests.post(
@@ -26,13 +36,15 @@ class API:
 
     def send_message(self, message, model):
         if model == 1:
-            answer = self.process_text(message, "phi3.5")
+            answer = self.process_text_local(message, "phi3.5")
         elif model == 2:
-            answer = self.process_text(message, "gemma2:2b")
+            answer = self.process_text_local(message, "gemma2:2b")
         elif model == 3:
-            answer = self.process_text(message, "qwen2:0.5b")
+            answer = self.process_text_local(message, "qwen2:0.5b")
         elif model == 4:
-            answer = self.process_text(message, "codegemma:2b")
+            answer = self.process_text_local(message, "codegemma:2b")
+        elif model == 5:
+            answer = self.process_text_transformers(message, "meta-llama/Meta-Llama-3.1-8B")
         else:
             return "Invalid choice"
         return answer
diff --git a/py/venv.sh b/py/venv.sh
new file mode 100644
index 0000000..4a3be2f
--- /dev/null
+++ b/py/venv.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+virtualenv venv
+source venv/bin/activate
+pip install transformers
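
For reference, a minimal driver showing how the patched API class might be used; this is an illustrative sketch, not part of the commit. It assumes py/api.py as modified above is importable, that an Ollama server is listening on http://localhost:11434 for model choices 1-4, and that choice 5 can load meta-llama/Meta-Llama-3.1-8B through transformers (a gated checkpoint that needs a Hugging Face login and enough memory for an 8B model). The file name demo.py, the MODEL_CHOICES table, and the main() wrapper are hypothetical.

# demo.py -- hypothetical driver for the API class added in py/api.py (not part of this patch)
from api import API

# Model numbers as dispatched by API.send_message(): 1-4 go through the local Ollama
# endpoint, 5 goes through the transformers code path added in this patch.
MODEL_CHOICES = {
    1: "phi3.5 (Ollama)",
    2: "gemma2:2b (Ollama)",
    3: "qwen2:0.5b (Ollama)",
    4: "codegemma:2b (Ollama)",
    5: "meta-llama/Meta-Llama-3.1-8B (transformers, downloaded on first use)",
}


def main():
    api = API()
    for number, name in MODEL_CHOICES.items():
        print(f"{number}: {name}")
    choice = int(input("Model number: "))
    prompt = input("Prompt: ")
    # send_message() returns the generated text, or "Invalid choice" for unknown numbers.
    print(api.send_message(prompt, choice))


if __name__ == "__main__":
    main()

Note that LlamaForCausalLM needs a backend such as torch installed alongside transformers; py/venv.sh only installs transformers, so the choice-5 path will likely also need pip install torch in the same virtualenv.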