Added online model.

parent c4655fb49e
commit b3ae2625ac

5 changed files with 26 additions and 8 deletions
.gitignore (vendored, new file): 1 addition
@@ -0,0 +1 @@
+venv/
@@ -3,5 +3,5 @@
   <component name="Black">
     <option name="sdkName" value="Python 3.12" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (py)" project-jdk-type="Python SDK" />
 </project>
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
+    <orderEntry type="jdk" jdkName="Python 3.12 (py)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>
py/api.py: 24 changes
@@ -1,10 +1,20 @@
 import requests
 import json
+from transformers import AutoTokenizer, LlamaForCausalLM
 
 
 class API:
     @staticmethod
-    def process_text(prompt, model):
+    def process_text_transformers(prompt, model):
+        # Load the tokenizer from the model name before `model` is rebound
+        # to the loaded weights; AutoTokenizer.from_pretrained expects a
+        # name or path, not a model object.
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        model = LlamaForCausalLM.from_pretrained(model)
+
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        generate_ids = model.generate(inputs.input_ids, max_length=30)
+        return tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    @staticmethod
+    def process_text_local(prompt, model):
         ollama_url = "http://localhost:11434"
 
         response = requests.post(
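The hunk stops at the opening of the `requests.post(` call, so the request body is not shown. For orientation, a minimal non-streaming call against the standard Ollama /api/generate endpoint would look roughly like the sketch below; the field names follow the public Ollama REST API, and the exact payload this repo sends is an assumption.

# Hypothetical sketch, not the repo's actual code: a standard Ollama
# /api/generate request with streaming disabled.
import requests

def generate(ollama_url, model, prompt):
    response = requests.post(
        f"{ollama_url}/api/generate",
        json={"model": model, "prompt": prompt, "stream": False},
    )
    response.raise_for_status()
    # Non-streaming responses carry the full completion in "response".
    return response.json().get("response", "")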
@@ -26,13 +36,15 @@ class API:
 
     def send_message(self, message, model):
         if model == 1:
-            answer = self.process_text(message, "phi3.5")
+            answer = self.process_text_local(message, "phi3.5")
         elif model == 2:
-            answer = self.process_text(message, "gemma2:2b")
+            answer = self.process_text_local(message, "gemma2:2b")
         elif model == 3:
-            answer = self.process_text(message, "qwen2:0.5b")
+            answer = self.process_text_local(message, "qwen2:0.5b")
         elif model == 4:
-            answer = self.process_text(message, "codegemma:2b")
+            answer = self.process_text_local(message, "codegemma:2b")
+        elif model == 5:
+            answer = self.process_text_transformers(message, "meta-llama/Meta-Llama-3.1-8B")
         else:
             return "Invalid choice"
         return answer
 
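After this change, choices 1 through 4 are served by the local Ollama instance and choice 5 by the in-process transformers model. A minimal usage sketch, assuming `API` takes no constructor arguments, an Ollama server on localhost:11434, and locally available Llama 3.1 weights:

# Usage sketch under the assumptions above.
from api import API

api = API()
print(api.send_message("Hello!", 1))  # phi3.5 via the local Ollama server
print(api.send_message("Hello!", 5))  # Meta-Llama-3.1-8B via transformers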
py/venv.sh (new file): 5 additions
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+virtualenv venv
+source venv/bin/activate
+pip install transformers
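One usage caveat: because the script itself calls `source venv/bin/activate`, running it as `bash venv.sh` activates the environment only inside the script's subshell. To keep the virtualenv active in the current shell, source the script instead, e.g. `cd py && source venv.sh`.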