Initial commit.
Commit: 3f979d2bbd

.editorconfig (new file, 27 lines)
@@ -0,0 +1,27 @@
root = true

[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[*.py]
indent_style = space
indent_size = 4
max_line_length = 100

[*.{yml,yaml}]
indent_style = space
indent_size = 2

[*.{json,toml}]
indent_style = space
indent_size = 2

[*.md]
trim_trailing_whitespace = false
max_line_length = off

[Makefile]
indent_style = tab
.github/workflows/lint.yml (new file, vendored, 37 lines)
@@ -0,0 +1,37 @@
name: Lint

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main, develop ]

jobs:
  lint:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Run Black
        run: |
          black --check pr tests

      - name: Run Flake8
        run: |
          flake8 pr tests --max-line-length=100 --ignore=E203,W503

      - name: Run MyPy
        run: |
          mypy pr --ignore-missing-imports
        continue-on-error: true
.github/workflows/test.yml (new file, vendored, 40 lines)
@@ -0,0 +1,40 @@
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main, develop ]

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Run tests with pytest
        run: |
          pytest --cov=pr --cov-report=xml --cov-report=term-missing

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
          fail_ci_if_error: false
.gitignore (new file, vendored, 162 lines)
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
CLAUDE.md
.claude
__pycache__/
*.py[cod]
*$py.class
ab
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# poetry
poetry.lock

# pdm
.pdm.toml

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Project specific
.rcontext.txt
.prrc
.assistant_*
*.db
*.sqlite
*.old
imploded.py

# Logs
*.log

# Temporary files
tmp/
temp/
.pre-commit-config.yaml (new file, 67 lines)
@@ -0,0 +1,67 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: check-json
      - id: check-toml
      - id: check-merge-conflict
      - id: check-case-conflict
      - id: detect-private-key
      - id: mixed-line-ending

  - repo: https://github.com/psf/black
    rev: 23.12.1
    hooks:
      - id: black
        language_version: python3
        args: ['--line-length=100']

  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        args: ['--max-line-length=100', '--ignore=E203,W503,E501']
        additional_dependencies: [flake8-docstrings]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        args: ['--ignore-missing-imports', '--check-untyped-defs']
        additional_dependencies: [types-all]
        exclude: ^tests/

  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: ['--profile', 'black', '--line-length', '100']

  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.6
    hooks:
      - id: bandit
        args: ['-c', 'pyproject.toml']
        additional_dependencies: ['bandit[toml]']
        exclude: ^tests/

  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v4.0.0-alpha.8
    hooks:
      - id: prettier
        types_or: [yaml, markdown, json]

  - repo: local
    hooks:
      - id: pytest-check
        name: pytest-check
        entry: pytest
        language: system
        pass_filenames: false
        always_run: true
        args: ['--maxfail=1', '-q']
CHANGELOG.md (new file, 133 lines)
@@ -0,0 +1,133 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- Comprehensive test suite with pytest
- Structured logging system with rotating file handlers
- Configuration file support (`.prrc`)
- Token and cost tracking across sessions
- Session persistence (save/load/export)
- CI/CD pipeline with GitHub Actions
- Custom exception hierarchy
- Plugin system for custom tools
- Multiple output formats (text, JSON, structured)
- Input validation system
- Progress indicators for long operations
- Pre-commit hooks configuration
- Comprehensive documentation (README, CONTRIBUTING, CHANGELOG)
- Version flag (`--version`)
- Better help text with examples
- Session management commands
- Usage statistics tracking

### Changed
- Improved argument parser with better help text
- Enhanced error handling throughout the codebase
- Modular architecture with clear separation of concerns

### Fixed
- N/A (initial professional release)

## [1.0.0] - 2025-01-XX

### Added
- Initial release with core functionality
- Interactive and single-message modes
- Autonomous execution mode
- 16 built-in tools (filesystem, command, web, database, Python)
- OpenRouter API integration
- Context window management with summarization
- Markdown rendering with syntax highlighting
- Tool call visualization
- Command history with readline
- File versioning in SQLite database
- Multiple context file support
- Environment variable configuration
- Model switching support
- Verbose mode
- API mode for specialized interaction

### Core Features
- **Assistant Core**: Main orchestrator with REPL loop
- **API Communication**: OpenRouter integration with streaming
- **Tool System**: Parallel execution with ThreadPoolExecutor
- **Autonomous Mode**: Max 50 iterations with completion detection
- **Context Management**: Automatic compression at the 30-message threshold
- **UI Components**: ANSI colors, markdown rendering, fancy displays

### Tools Included
- File Operations: read, write, list, mkdir, chdir, getpwd, index
- Commands: run_command, run_command_interactive
- Web: http_fetch, web_search, web_search_news
- Database: db_set, db_get, db_query
- Python: python_exec with persistent context

### Configuration
- Default model: x-ai/grok-code-fast-1
- Temperature: 0.7
- Max tokens: 8096
- Context threshold: 30 messages
- Max autonomous iterations: 50

## Version History

### Version Numbering

- **Major** (X.0.0): Breaking changes
- **Minor** (1.X.0): New features, backwards compatible
- **Patch** (1.0.X): Bug fixes, backwards compatible

### Release Schedule

- Major releases: As needed for breaking changes
- Minor releases: Monthly, or when significant features are ready
- Patch releases: As needed for bug fixes

## Migration Guides

### Upgrading to 1.0.0

This is the initial professional release. If upgrading from a development version:

1. Install with `pip install -e .`
2. Run `pr --create-config` to generate a configuration file
3. Set the `OPENROUTER_API_KEY` environment variable
4. Existing data in `~/.assistant_db.sqlite` will continue to work

## Deprecation Notices

None currently.

## Known Issues

None currently.

## Future Releases

### Planned for 2.0.0
- Multi-model conversations
- Enhanced plugin API with hooks
- Web UI dashboard
- Team collaboration features

### Under Consideration
- Docker containerization
- Cloud deployment options
- IDE integrations
- Advanced code analysis tools

---

**Legend:**
- `Added` - New features
- `Changed` - Changes to existing functionality
- `Deprecated` - Soon-to-be removed features
- `Removed` - Removed features
- `Fixed` - Bug fixes
- `Security` - Security fixes
CONTRIBUTING.md (new file, 362 lines)
@@ -0,0 +1,362 @@
# Contributing to PR Assistant

Thank you for your interest in contributing to PR Assistant! This document provides guidelines and instructions for contributing.

## Code of Conduct

- Be respectful and inclusive
- Welcome newcomers and encourage diverse perspectives
- Focus on constructive feedback
- Maintain professionalism in all interactions

## Getting Started

### Prerequisites

- Python 3.8 or higher
- Git
- OpenRouter API key (for testing)

### Development Setup

1. Fork the repository on GitHub
2. Clone your fork locally:
   ```bash
   git clone https://github.com/YOUR_USERNAME/pr-assistant.git
   cd pr-assistant
   ```

3. Install development dependencies:
   ```bash
   pip install -e ".[dev]"
   ```

4. Install pre-commit hooks:
   ```bash
   pre-commit install
   ```

5. Create a feature branch:
   ```bash
   git checkout -b feature/your-feature-name
   ```

## Development Workflow

### Making Changes

1. **Write tests first** - Follow TDD when possible
2. **Keep changes focused** - One feature/fix per PR
3. **Follow code style** - Use Black for formatting
4. **Add documentation** - Update docstrings and README as needed
5. **Update changelog** - Add an entry to CHANGELOG.md

### Code Style

We follow PEP 8 with some modifications:

- Line length: 100 characters (not 80)
- Use Black for automatic formatting
- Use descriptive variable names
- Add type hints where beneficial
- Write docstrings for all public functions/classes

Example:

```python
from typing import Any, Dict


def process_data(input_data: str, max_length: int = 100) -> Dict[str, Any]:
    """
    Process input data and return structured result.

    Args:
        input_data: The raw input string to process
        max_length: Maximum length of output (default: 100)

    Returns:
        Dictionary containing processed data

    Raises:
        ValidationError: If input_data is invalid
    """
    # Implementation here
    pass
```

### Running Tests

Run all tests:
```bash
pytest
```

Run with coverage:
```bash
pytest --cov=pr --cov-report=html
```

Run a specific test file:
```bash
pytest tests/test_tools.py
```

Run a specific test:
```bash
pytest tests/test_tools.py::TestFilesystemTools::test_read_file
```

### Code Quality Checks

Before committing, run:

```bash
# Format code
black pr tests

# Check linting
flake8 pr tests --max-line-length=100 --ignore=E203,W503

# Type checking
mypy pr --ignore-missing-imports

# Run all checks
pre-commit run --all-files
```

## Project Structure

### Adding New Features

#### Adding a New Tool

1. Implement the tool function in the appropriate `pr/tools/*.py` file
2. Add a tool definition to `pr/tools/base.py:get_tools_definition()`
3. Add a function mapping in `pr/core/assistant.py:execute_tool_calls()`
4. Add a function mapping in `pr/autonomous/mode.py:execute_single_tool()`
5. Write tests in `tests/test_tools.py`
6. Update documentation (a sketch of steps 1-2 follows this list)
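
As a minimal illustration of steps 1-2: the `word_count` tool and its definition below are hypothetical, not part of the codebase; the definition mirrors the tool schema shown in README.md's plugin example.

```python
# Step 1: a hypothetical tool, e.g. in pr/tools/text.py
def word_count(text: str) -> str:
    """Count the words in a piece of text."""
    return str(len(text.split()))


# Step 2: the matching definition, appended to the list returned by
# pr/tools/base.py:get_tools_definition()
WORD_COUNT_DEFINITION = {
    "type": "function",
    "function": {
        "name": "word_count",
        "description": "Count the words in a piece of text",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string", "description": "The text to count"}
            },
            "required": ["text"],
        },
    },
}
```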

#### Adding a New Configuration Option

1. Add to the default config in `pr/core/config_loader.py` (see the sketch below)
2. Update the config documentation in README.md
3. Add validation if needed in `pr/core/validation.py`
4. Write tests
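
A sketch of step 1 only, assuming the defaults live in a nested dict keyed by the same sections as `~/.prrc`; the real layout of `config_loader.py` may differ, and `compress_old_sessions` is a hypothetical option.

```python
# Hypothetical defaults in pr/core/config_loader.py, mirroring the .prrc sections
DEFAULT_CONFIG = {
    'session': {
        'auto_save': False,
        'max_history': 1000,
        'compress_old_sessions': False,  # hypothetical new option with its default
    },
}
```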

#### Adding a New Command

1. Add a handler to `pr/commands/handlers.py` (a hedged sketch follows this list)
2. Update the help text in `pr/__main__.py`
3. Write tests
4. Update README.md with command documentation
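
A sketch of step 1 under stated assumptions: the handler signature and the `assistant.messages` attribute are hypothetical, since `handlers.py` is not shown here, and `/stats` is an invented example command.

```python
# Hypothetical handler for pr/commands/handlers.py; the real dispatch API may differ.
def handle_stats(assistant, args: str) -> None:
    """Illustrative /stats command: report the size of the message history."""
    # assistant.messages is an assumed attribute holding the conversation so far
    print(f"Messages in history: {len(assistant.messages)}")
```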

### Plugin Development

Plugins should be self-contained Python files:

```python
# Plugin structure
def tool_function(args):
    """Implementation"""
    pass

def register_tools():
    """Return list of tool definitions"""
    return [...]
```

## Testing Guidelines

### Test Organization

- `tests/test_*.py` - Test files matching source files
- `tests/conftest.py` - Shared fixtures
- Use descriptive test names: `test_<function>_<scenario>_<expected_result>`

### Writing Good Tests

```python
import os

from pr.tools.filesystem import read_file, write_file


def test_read_file_with_valid_path_returns_content(temp_dir):
    # Arrange
    filepath = os.path.join(temp_dir, 'test.txt')
    expected_content = 'Hello, World!'
    write_file(filepath, expected_content)

    # Act
    result = read_file(filepath)

    # Assert
    assert expected_content in result
```

### Test Coverage

- Aim for >80% code coverage
- Cover edge cases and error conditions
- Test both success and failure paths

## Documentation

### Docstring Format

Use Google-style docstrings:

```python
def function_name(param1: str, param2: int) -> bool:
    """
    Short description of function.

    Longer description with more details about what the function
    does and how it works.

    Args:
        param1: Description of param1
        param2: Description of param2

    Returns:
        Description of return value

    Raises:
        ValueError: Description of when this is raised
    """
    pass
```

### Updating Documentation

When adding features, update:

- Function/class docstrings
- README.md
- CHANGELOG.md
- Code comments (where necessary)

## Pull Request Process

### Before Submitting

1. ✅ All tests pass
2. ✅ Code is formatted with Black
3. ✅ Linting passes (flake8)
4. ✅ No type errors (mypy)
5. ✅ Documentation is updated
6. ✅ CHANGELOG.md is updated
7. ✅ Commits are clean and descriptive

### PR Template

```markdown
## Description
Brief description of changes

## Type of Change
- [ ] Bug fix
- [ ] New feature
- [ ] Breaking change
- [ ] Documentation update

## Changes Made
- Change 1
- Change 2

## Testing
- Test scenario 1
- Test scenario 2

## Checklist
- [ ] Tests pass
- [ ] Code formatted with Black
- [ ] Documentation updated
- [ ] CHANGELOG.md updated
```

### Review Process

1. Automated checks must pass
2. At least one maintainer review is required
3. Address all review comments
4. Squash commits if requested
5. A maintainer will merge when approved

## Commit Messages

Follow Conventional Commits:

```
type(scope): subject

body (optional)

footer (optional)
```

Types:
- `feat`: New feature
- `fix`: Bug fix
- `docs`: Documentation
- `style`: Formatting
- `refactor`: Code restructuring
- `test`: Adding tests
- `chore`: Maintenance

Examples:
```
feat(tools): add text analysis tool

fix(api): handle timeout errors properly

docs(readme): update installation instructions
```

## Reporting Bugs

### Bug Report Template

**Description:**
Clear description of the bug

**To Reproduce:**
1. Step 1
2. Step 2
3. See error

**Expected Behavior:**
What should happen

**Actual Behavior:**
What actually happens

**Environment:**
- OS:
- Python version:
- PR Assistant version:

**Additional Context:**
Logs, screenshots, etc.

## Feature Requests

**Feature Description:**
What feature would you like?

**Use Case:**
Why is this needed?

**Proposed Solution:**
How might this work?

**Alternatives:**
Other approaches considered

## Questions?

- Open an issue for questions
- Check existing issues first
- Tag with the `question` label

## License

By contributing, you agree that your contributions will be licensed under the MIT License.

---

Thank you for contributing to PR Assistant! 🎉
LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 retoor

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Makefile (new file, 74 lines)
@@ -0,0 +1,74 @@
.PHONY: help install install-dev test test-cov lint format clean build publish

help:
	@echo "PR Assistant - Development Commands"
	@echo ""
	@echo "Available commands:"
	@echo "  make install      - Install package"
	@echo "  make install-dev  - Install with development dependencies"
	@echo "  make test         - Run tests"
	@echo "  make test-cov     - Run tests with coverage report"
	@echo "  make lint         - Run linters (flake8, mypy)"
	@echo "  make format       - Format code with black and isort"
	@echo "  make clean        - Remove build artifacts"
	@echo "  make build        - Build distribution packages"
	@echo "  make publish      - Publish to PyPI (use with caution)"
	@echo "  make pre-commit   - Run pre-commit hooks on all files"
	@echo "  make docs         - Generate documentation"

install:
	pip install -e .

install-dev:
	pip install -e ".[dev]"
	pre-commit install

test:
	pytest

test-cov:
	pytest --cov=pr --cov-report=html --cov-report=term-missing
	@echo "Coverage report generated in htmlcov/index.html"

lint:
	flake8 pr tests --max-line-length=100 --ignore=E203,W503
	mypy pr --ignore-missing-imports

format:
	black pr tests
	isort pr tests --profile black

clean:
	rm -rf build/
	rm -rf dist/
	rm -rf *.egg-info
	rm -rf .pytest_cache/
	rm -rf .mypy_cache/
	rm -rf htmlcov/
	rm -rf .coverage
	find . -type d -name __pycache__ -exec rm -rf {} +
	find . -type f -name "*.pyc" -delete

build: clean
	python -m build

publish: build
	python -m twine upload dist/*

pre-commit:
	pre-commit run --all-files

check: lint test
	@echo "All checks passed!"

backup:
	zip -r rp.zip *
	mv rp.zip ../

implode:
	python ../implode/imply.py rp.py
	mv imploded.py /home/retoor/bin/rp
	chmod +x /home/retoor/bin/rp
	rp --debug

.DEFAULT_GOAL := help
README.md (new file, 351 lines)
@@ -0,0 +1,351 @@
# rp Assistant

[CI](https://github.com/retoor/rp-assistant)
[Python 3.8+](https://www.python.org/downloads/)
[License: MIT](LICENSE)
[Code style: black](https://github.com/psf/black)

A professional Python CLI AI assistant with autonomous execution capabilities. It interfaces with the OpenRouter API (default model: x-ai/grok-code-fast-1) and supports tool calling for file operations, command execution, web search, and more.

## Features

- **Autonomous Mode** - Continuous execution until task completion (max 50 iterations)
- **Tool System** - 16 built-in tools for file ops, commands, web, database, Python execution
- **Plugin System** - Extend functionality with custom tools
- **Session Management** - Save, load, and export conversation sessions
- **Usage Tracking** - Token and cost tracking across all requests
- **Context Management** - Automatic context window management with summarization
- **Multiple Output Formats** - Text, JSON, and structured output
- **Configuration Files** - Flexible configuration via `.prrc` files
- **No External Dependencies** - Uses only the Python standard library

## Installation

### From Source

```bash
git clone https://github.com/retoor/rp-assistant.git
cd rp-assistant
pip install -e .
```

### Development Installation

```bash
pip install -e ".[dev]"
```

## Quick Start

### Setup

1. Set your OpenRouter API key:
   ```bash
   export OPENROUTER_API_KEY="your-api-key-here"
   ```

2. (Optional) Create a configuration file:
   ```bash
   rp --create-config
   ```

### Usage Examples

**Single query:**
```bash
rp "What is Python?"
```

**Interactive mode:**
```bash
rp -i
```

**Use a specific model:**
```bash
rp -i --model "gpt-4"
```

**Autonomous mode:**
```bash
rp -i
> /auto Create a Python script that analyzes log files
```

**Save and load sessions:**
```bash
rp --save-session my-project -i
rp --load-session my-project
rp --list-sessions
```

**Check usage statistics:**
```bash
rp --usage
```

**JSON output (for scripting):**
```bash
rp "List files in current directory" --output json
```

## Interactive Commands

When in interactive mode (`rp -i`), use these commands:

| Command | Description |
|---------|-------------|
| `/auto [task]` | Enter autonomous mode |
| `/reset` | Clear message history |
| `/verbose` | Toggle verbose output |
| `/models` | List available models |
| `/tools` | List available tools |
| `/usage` | Show usage statistics |
| `/save <name>` | Save current session |
| `/review <file>` | Review a file |
| `/refactor <file>` | Refactor code |
| `exit`, `quit`, `q` | Exit the program |

## Configuration

Create a configuration file at `~/.prrc`:

```ini
[api]
default_model = x-ai/grok-code-fast-1
timeout = 30
temperature = 0.7
max_tokens = 8096

[autonomous]
max_iterations = 50
context_threshold = 30
recent_messages_to_keep = 10

[ui]
syntax_highlighting = true
show_timestamps = false
color_output = true

[output]
format = text
verbose = false
quiet = false

[session]
auto_save = false
max_history = 1000
```

Project-specific settings can be placed in `.prrc` in your project directory.

## Architecture

### Directory Structure

```
rp/
├── __init__.py          # Package initialization
├── __main__.py          # Entry point
├── config.py            # Configuration constants
├── core/                # Core functionality
│   ├── assistant.py     # Main Assistant class
│   ├── api.py           # API communication
│   ├── context.py       # Context management
│   ├── logging.py       # Structured logging
│   ├── config_loader.py # Configuration loading
│   ├── usage_tracker.py # Token/cost tracking
│   ├── session.py       # Session persistence
│   ├── exceptions.py    # Custom exceptions
│   └── validation.py    # Input validation
├── autonomous/          # Autonomous mode
│   ├── mode.py          # Execution loop
│   └── detection.py     # Task completion detection
├── tools/               # Tool implementations
│   ├── base.py          # Tool definitions
│   ├── filesystem.py    # File operations
│   ├── command.py       # Command execution
│   ├── database.py      # Database operations
│   ├── web.py           # Web tools
│   └── python_exec.py   # Python execution
├── ui/                  # UI components
│   ├── colors.py        # ANSI color codes
│   ├── rendering.py     # Markdown rendering
│   ├── display.py       # Tool call visualization
│   ├── output.py        # Output formatting
│   └── progress.py      # Progress indicators
├── plugins/             # Plugin system
│   └── loader.py        # Plugin loader
└── commands/            # Command handlers
    └── handlers.py      # Interactive commands
```

## Plugin Development

Create custom tools by adding Python files to `~/.rp/plugins/`:

```python
# ~/.rp/plugins/my_plugin.py

def my_custom_tool(argument: str) -> str:
    """Process input and return result."""
    return f"Processed: {argument}"

def register_tools():
    """Register tools with the rp assistant."""
    return [
        {
            "type": "function",
            "function": {
                "name": "my_custom_tool",
                "description": "A custom tool that processes input",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "argument": {
                            "type": "string",
                            "description": "The input to process"
                        }
                    },
                    "required": ["argument"]
                }
            }
        }
    ]
```

List loaded plugins:
```bash
rp --plugins
```

## Built-in Tools

### File Operations
- `read_file` - Read file contents
- `write_file` - Write to file
- `list_directory` - List directory contents
- `make_directory` - Create directory
- `change_directory` - Change working directory
- `get_current_directory` - Get current directory
- `index_codebase` - Index codebase structure

### Command Execution
- `run_command` - Execute shell commands
- `run_command_interactive` - Interactive command execution

### Web Operations
- `http_fetch` - Fetch HTTP resources
- `web_search` - Web search
- `web_search_news` - News search

### Database
- `db_set` - Set key-value pair
- `db_get` - Get value by key
- `db_query` - Execute SQL query

### Python
- `python_exec` - Execute Python code

## Development

### Running Tests

```bash
pytest
```

### With coverage:

```bash
pytest --cov=rp --cov-report=html
```

### Code Formatting

```bash
black rp tests
```

### Linting

```bash
flake8 rp tests --max-line-length=100
mypy rp
```

### Pre-commit Hooks

```bash
pip install pre-commit
pre-commit install
pre-commit run --all-files
```

## Environment Variables

| Variable | Description | Default |
|----------|-------------|---------|
| `OPENROUTER_API_KEY` | OpenRouter API key | (required) |
| `AI_MODEL` | Default model | x-ai/grok-code-fast-1 |
| `API_URL` | API endpoint | https://openrouter.ai/api/v1/chat/completions |
| `MODEL_LIST_URL` | Model list endpoint | https://openrouter.ai/api/v1/models |
| `USE_TOOLS` | Enable tools | 1 |
| `STRICT_MODE` | Strict mode | 0 |

## Data Storage

- **Configuration**: `~/.prrc` and `.prrc`
- **Database**: `~/.assistant_db.sqlite`
- **Sessions**: `~/.assistant_sessions/`
- **Usage Data**: `~/.assistant_usage.json`
- **Logs**: `~/.assistant_error.log`
- **History**: `~/.assistant_history`
- **Context**: `.rcontext.txt` and `~/.rcontext.txt`
- **Plugins**: `~/.rp/plugins/`

## Contributing

Contributions are welcome! Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes
4. Run tests (`pytest`)
5. Commit your changes (`git commit -m 'Add amazing feature'`)
6. Push to the branch (`git push origin feature/amazing-feature`)
7. Open a Pull Request

## Changelog

See [CHANGELOG.md](CHANGELOG.md) for version history.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

- Built with the OpenRouter API
- Uses only the Python standard library (no external dependencies for core functionality)
- Inspired by modern AI assistants, with a focus on autonomy and extensibility

## Support

- Issues: [GitHub Issues](https://github.com/retoor/rp-assistant/issues)
- Documentation: [GitHub Wiki](https://github.com/retoor/rp-assistant/wiki)

## Roadmap

- [ ] Multi-model conversation (switch models mid-session)
- [ ] Enhanced plugin API with hooks
- [ ] Web UI dashboard
- [ ] Team collaboration features
- [ ] Advanced code analysis tools
- [ ] Integration with popular IDEs
- [ ] Docker containerization
- [ ] Cloud deployment options

---

**Made with ❤️ by the rp Assistant team**
pr/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from pr.core import Assistant

__version__ = '1.0.0'
__all__ = ['Assistant']
pr/__main__.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import argparse
import sys
from pr.core import Assistant
from pr import __version__


def main():
    parser = argparse.ArgumentParser(
        description='PR Assistant - Professional CLI AI assistant with autonomous execution',
        epilog='''
Examples:
  pr "What is Python?"              # Single query
  pr -i                             # Interactive mode
  pr -i --model gpt-4               # Use specific model
  pr --save-session my-task -i      # Save session
  pr --load-session my-task         # Load session
  pr --list-sessions                # List all sessions
  pr --usage                        # Show token usage stats

Commands in interactive mode:
  /auto [task]   - Enter autonomous mode
  /reset         - Clear message history
  /verbose       - Toggle verbose output
  /models        - List available models
  /tools         - List available tools
  /usage         - Show usage statistics
  /save <name>   - Save current session
  exit, quit, q  - Exit the program
''',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('message', nargs='?', help='Message to send to assistant')
    parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}')
    parser.add_argument('-m', '--model', help='AI model to use')
    parser.add_argument('-u', '--api-url', help='API endpoint URL')
    parser.add_argument('--model-list-url', help='Model list endpoint URL')
    parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode')
    parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging')
    parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting')
    parser.add_argument('--include-env', action='store_true', help='Include environment variables in context')
    parser.add_argument('-c', '--context', action='append', help='Additional context files')
    parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction')

    parser.add_argument('--output', choices=['text', 'json', 'structured'],
                        default='text', help='Output format')
    parser.add_argument('--quiet', action='store_true', help='Minimal output')

    parser.add_argument('--save-session', metavar='NAME', help='Save session with given name')
    parser.add_argument('--load-session', metavar='NAME', help='Load session with given name')
    parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions')
    parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session')
    parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'),
                        help='Export session to file')

    parser.add_argument('--usage', action='store_true', help='Show token usage statistics')
    parser.add_argument('--create-config', action='store_true',
                        help='Create default configuration file')
    parser.add_argument('--plugins', action='store_true', help='List loaded plugins')

    args = parser.parse_args()

    if args.create_config:
        from pr.core.config_loader import create_default_config
        if create_default_config():
            print("Configuration file created at ~/.prrc")
        else:
            print("Error creating configuration file", file=sys.stderr)
        return

    if args.list_sessions:
        from pr.core.session import SessionManager
        sm = SessionManager()
        sessions = sm.list_sessions()
        if not sessions:
            print("No saved sessions found")
        else:
            print(f"Found {len(sessions)} saved sessions:\n")
            for sess in sessions:
                print(f"  {sess['name']}")
                print(f"    Created: {sess['created_at']}")
                print(f"    Messages: {sess['message_count']}")
                print()
        return

    if args.delete_session:
        from pr.core.session import SessionManager
        sm = SessionManager()
        if sm.delete_session(args.delete_session):
            print(f"Session '{args.delete_session}' deleted")
        else:
            print(f"Error deleting session '{args.delete_session}'", file=sys.stderr)
        return

    if args.export_session:
        from pr.core.session import SessionManager
        sm = SessionManager()
        name, output_file = args.export_session
        format_type = 'json'
        if output_file.endswith('.md'):
            format_type = 'markdown'
        elif output_file.endswith('.txt'):
            format_type = 'txt'

        if sm.export_session(name, output_file, format_type):
            print(f"Session exported to {output_file}")
        else:
            print("Error exporting session", file=sys.stderr)
        return

    if args.usage:
        from pr.core.usage_tracker import UsageTracker
        usage = UsageTracker.get_total_usage()
        print("\nTotal Usage Statistics:")
        print(f"  Requests: {usage['total_requests']}")
        print(f"  Tokens: {usage['total_tokens']:,}")
        print(f"  Estimated Cost: ${usage['total_cost']:.4f}")
        return

    if args.plugins:
        from pr.plugins.loader import PluginLoader
        loader = PluginLoader()
        loader.load_plugins()
        plugins = loader.list_loaded_plugins()
        if not plugins:
            print("No plugins loaded")
        else:
            print(f"Loaded {len(plugins)} plugins:")
            for plugin in plugins:
                print(f"  - {plugin}")
        return

    assistant = Assistant(args)
    assistant.run()


if __name__ == '__main__':
    main()
pr/agents/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from .agent_roles import AgentRole, get_agent_role, list_agent_roles
from .agent_manager import AgentManager, AgentInstance
from .agent_communication import AgentMessage, AgentCommunicationBus

__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance',
           'AgentMessage', 'AgentCommunicationBus']
pr/agents/agent_communication.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import sqlite3
import json
from typing import List, Optional
from dataclasses import dataclass
from enum import Enum


class MessageType(Enum):
    REQUEST = "request"
    RESPONSE = "response"
    NOTIFICATION = "notification"


@dataclass
class AgentMessage:
    message_id: str
    from_agent: str
    to_agent: str
    message_type: MessageType
    content: str
    metadata: dict
    timestamp: float

    def to_dict(self) -> dict:
        return {
            'message_id': self.message_id,
            'from_agent': self.from_agent,
            'to_agent': self.to_agent,
            'message_type': self.message_type.value,
            'content': self.content,
            'metadata': self.metadata,
            'timestamp': self.timestamp
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'AgentMessage':
        return cls(
            message_id=data['message_id'],
            from_agent=data['from_agent'],
            to_agent=data['to_agent'],
            message_type=MessageType(data['message_type']),
            content=data['content'],
            metadata=data['metadata'],
            timestamp=data['timestamp']
        )


class AgentCommunicationBus:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self._create_tables()

    def _create_tables(self):
        cursor = self.conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS agent_messages (
                message_id TEXT PRIMARY KEY,
                from_agent TEXT,
                to_agent TEXT,
                message_type TEXT,
                content TEXT,
                metadata TEXT,
                timestamp REAL,
                session_id TEXT,
                read INTEGER DEFAULT 0
            )
        ''')
        self.conn.commit()

    def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
        cursor = self.conn.cursor()

        cursor.execute('''
            INSERT INTO agent_messages
            (message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            message.message_id,
            message.from_agent,
            message.to_agent,
            message.message_type.value,
            message.content,
            json.dumps(message.metadata),
            message.timestamp,
            session_id
        ))

        self.conn.commit()

    def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
        cursor = self.conn.cursor()
        if unread_only:
            cursor.execute('''
                SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
                FROM agent_messages
                WHERE to_agent = ? AND read = 0
                ORDER BY timestamp ASC
            ''', (agent_id,))
        else:
            cursor.execute('''
                SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
                FROM agent_messages
                WHERE to_agent = ?
                ORDER BY timestamp ASC
            ''', (agent_id,))

        messages = []
        for row in cursor.fetchall():
            messages.append(AgentMessage(
                message_id=row[0],
                from_agent=row[1],
                to_agent=row[2],
                message_type=MessageType(row[3]),
                content=row[4],
                metadata=json.loads(row[5]) if row[5] else {},
                timestamp=row[6]
            ))
        return messages

    def mark_as_read(self, message_id: str):
        cursor = self.conn.cursor()
        cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,))
        self.conn.commit()

    def clear_messages(self, session_id: Optional[str] = None):
        cursor = self.conn.cursor()
        if session_id:
            cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,))
        else:
            cursor.execute('DELETE FROM agent_messages')
        self.conn.commit()

    def close(self):
        self.conn.close()

    def receive_messages(self, agent_id: str) -> List[AgentMessage]:
        return self.get_messages(agent_id, unread_only=True)

    def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
        cursor = self.conn.cursor()
        cursor.execute('''
            SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
            FROM agent_messages
            WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?)
            ORDER BY timestamp ASC
        ''', (agent_a, agent_b, agent_b, agent_a))

        messages = []
        for row in cursor.fetchall():
            messages.append(AgentMessage(
                message_id=row[0],
                from_agent=row[1],
                to_agent=row[2],
                message_type=MessageType(row[3]),
                content=row[4],
                metadata=json.loads(row[5]) if row[5] else {},
                timestamp=row[6]
            ))
        return messages
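
To make the bus API above concrete, a minimal usage sketch (not part of the commit; the database path, agent IDs, and message content are illustrative):

```python
import time
import uuid

from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType

bus = AgentCommunicationBus('/tmp/agents.db')  # illustrative path

# One agent posts a request to another over the shared SQLite-backed bus.
bus.send_message(AgentMessage(
    message_id=str(uuid.uuid4())[:16],
    from_agent='coding_1',
    to_agent='testing_1',
    message_type=MessageType.REQUEST,
    content='Please write tests for utils.py',
    metadata={},
    timestamp=time.time(),
), session_id='demo')

# The receiver polls its unread messages and acknowledges each one.
for msg in bus.receive_messages('testing_1'):
    print(msg.from_agent, '->', msg.content)
    bus.mark_as_read(msg.message_id)

bus.close()
```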
pr/agents/agent_manager.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import time
import json
import uuid
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field
from .agent_roles import AgentRole, get_agent_role
from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType


@dataclass
class AgentInstance:
    agent_id: str
    role: AgentRole
    message_history: List[Dict[str, Any]] = field(default_factory=list)
    context: Dict[str, Any] = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)
    task_count: int = 0

    def add_message(self, role: str, content: str):
        self.message_history.append({
            'role': role,
            'content': content,
            'timestamp': time.time()
        })

    def get_system_message(self) -> Dict[str, str]:
        return {'role': 'system', 'content': self.role.system_prompt}

    def get_messages_for_api(self) -> List[Dict[str, str]]:
        return [self.get_system_message()] + [
            {'role': msg['role'], 'content': msg['content']}
            for msg in self.message_history
        ]


class AgentManager:
    def __init__(self, db_path: str, api_caller: Callable):
        self.db_path = db_path
        self.api_caller = api_caller
        self.communication_bus = AgentCommunicationBus(db_path)
        self.active_agents: Dict[str, AgentInstance] = {}
        self.session_id = str(uuid.uuid4())[:16]

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        if agent_id is None:
            agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"

        role = get_agent_role(role_name)
        agent = AgentInstance(
            agent_id=agent_id,
            role=role
        )

        self.active_agents[agent_id] = agent
        return agent_id

    def get_agent(self, agent_id: str) -> Optional[AgentInstance]:
        return self.active_agents.get(agent_id)

    def remove_agent(self, agent_id: str) -> bool:
        if agent_id in self.active_agents:
            del self.active_agents[agent_id]
            return True
        return False

    def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        agent = self.get_agent(agent_id)
        if not agent:
            return {'error': f'Agent {agent_id} not found'}

        if context:
            agent.context.update(context)

        agent.add_message('user', task)
        agent.task_count += 1

        messages = agent.get_messages_for_api()

        try:
            response = self.api_caller(
                messages=messages,
                temperature=agent.role.temperature,
                max_tokens=agent.role.max_tokens
            )

            if response and 'choices' in response:
                assistant_message = response['choices'][0]['message']['content']
                agent.add_message('assistant', assistant_message)

                return {
                    'success': True,
                    'agent_id': agent_id,
                    'response': assistant_message,
                    'role': agent.role.name,
                    'task_count': agent.task_count
                }
            else:
                return {'error': 'Invalid API response', 'agent_id': agent_id}

        except Exception as e:
            return {'error': str(e), 'agent_id': agent_id}

    def send_agent_message(self, from_agent_id: str, to_agent_id: str,
                           content: str, message_type: MessageType = MessageType.REQUEST,
                           metadata: Optional[Dict[str, Any]] = None):
        message = AgentMessage(
            from_agent=from_agent_id,
            to_agent=to_agent_id,
            message_type=message_type,
            content=content,
            metadata=metadata or {},
            timestamp=time.time(),
            message_id=str(uuid.uuid4())[:16]
        )

        self.communication_bus.send_message(message, self.session_id)
        return message.message_id

    def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
        return self.communication_bus.get_messages(agent_id, unread_only)

    def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
        orchestrator = self.get_agent(orchestrator_id)
        if not orchestrator:
            orchestrator_id = self.create_agent('orchestrator')
            orchestrator = self.get_agent(orchestrator_id)

        worker_agents = []
        for role in agent_roles:
            agent_id = self.create_agent(role)
            worker_agents.append({
                'agent_id': agent_id,
                'role': role
            })

        orchestration_prompt = f'''Task: {task}

Available specialized agents:
{chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])}

Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.'''

        orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)

        results = {
            'orchestrator': orchestrator_result,
            'agents': []
        }

        for agent_info in worker_agents:
            agent_id = agent_info['agent_id']
            messages = self.get_agent_messages(agent_id)

            for msg in messages:
                subtask = msg.content
                result = self.execute_agent_task(agent_id, subtask)
                results['agents'].append(result)

                self.send_agent_message(
                    from_agent_id=agent_id,
                    to_agent_id=orchestrator_id,
                    content=result.get('response', ''),
                    message_type=MessageType.RESPONSE
                )
                self.communication_bus.mark_as_read(msg.message_id)

        return results

    def get_session_summary(self) -> Dict[str, Any]:
        summary = {
            'session_id': self.session_id,
            'active_agents': len(self.active_agents),
            'agents': [
                {
                    'agent_id': agent_id,
                    'role': agent.role.name,
                    'task_count': agent.task_count,
                    'message_count': len(agent.message_history)
                }
                for agent_id, agent in self.active_agents.items()
            ]
        }
        return summary

    def clear_session(self):
        self.active_agents.clear()
        self.communication_bus.clear_messages(session_id=self.session_id)
        self.session_id = str(uuid.uuid4())[:16]
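
A minimal wiring sketch for the manager (not part of the commit): `AgentManager` only needs a callable that returns the OpenRouter-style response shape consumed in `execute_agent_task` above, so a stub can stand in for the real API layer; the path and task text are illustrative.

```python
from pr.agents.agent_manager import AgentManager


def stub_api_caller(messages, temperature, max_tokens):
    # Stand-in for the real API call; returns the shape execute_agent_task expects.
    return {'choices': [{'message': {'content': f'(stub reply to {len(messages)} messages)'}}]}


manager = AgentManager('/tmp/agents.db', api_caller=stub_api_caller)  # illustrative path
agent_id = manager.create_agent('coding')
result = manager.execute_agent_task(agent_id, 'Refactor the parser module')
print(result['success'], result['response'])
print(manager.get_session_summary())
```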
pr/agents/agent_roles.py (new file, 192 lines)
@@ -0,0 +1,192 @@
from dataclasses import dataclass
from typing import List, Dict, Set


@dataclass
class AgentRole:
    name: str
    description: str
    system_prompt: str
    allowed_tools: Set[str]
    specialization_areas: List[str]
    temperature: float = 0.7
    max_tokens: int = 4096


AGENT_ROLES = {
    'coding': AgentRole(
        name='coding',
        description='Specialized in writing, reviewing, and debugging code',
        system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities:
- Write clean, efficient, well-structured code
- Review code for bugs, security issues, and best practices
- Refactor and optimize existing code
- Implement features based on specifications
- Follow language-specific conventions and patterns
Focus on code quality, maintainability, and performance.''',
        allowed_tools={
            'read_file', 'write_file', 'list_directory', 'create_directory',
            'change_directory', 'get_current_directory', 'python_exec',
            'run_command', 'index_directory'
        },
        specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'],
        temperature=0.3
    ),

    'research': AgentRole(
        name='research',
        description='Specialized in information gathering and analysis',
        system_prompt='''You are a research specialist AI assistant. Your primary responsibilities:
- Search for and gather relevant information
- Analyze data and documentation
- Synthesize findings into clear summaries
- Verify facts and cross-reference sources
- Identify trends and patterns in information
Focus on accuracy, thoroughness, and clear communication of findings.''',
        allowed_tools={
            'read_file', 'list_directory', 'index_directory',
            'http_fetch', 'web_search', 'web_search_news',
            'db_query', 'db_get'
        },
        specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'],
        temperature=0.5
    ),

    'data_analysis': AgentRole(
        name='data_analysis',
        description='Specialized in data processing and analysis',
        system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities:
- Process and analyze structured and unstructured data
- Perform statistical analysis and pattern recognition
- Query databases and extract insights
- Create data summaries and reports
- Identify anomalies and trends
Focus on accuracy, data integrity, and actionable insights.''',
        allowed_tools={
            'db_query', 'db_get', 'db_set', 'read_file', 'write_file',
            'python_exec', 'run_command', 'list_directory'
        },
        specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'],
        temperature=0.3
    ),

    'planning': AgentRole(
        name='planning',
        description='Specialized in task planning and coordination',
        system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities:
- Break down complex tasks into manageable steps
- Create execution plans and workflows
- Identify dependencies and prerequisites
- Estimate effort and resource requirements
- Coordinate between different components
Focus on logical organization, completeness, and feasibility.''',
        allowed_tools={
            'read_file', 'write_file', 'list_directory', 'index_directory',
            'db_set', 'db_get'
        },
        specialization_areas=['task_decomposition', 'workflow_design', 'coordination'],
        temperature=0.6
    ),

    'testing': AgentRole(
        name='testing',
        description='Specialized in testing and quality assurance',
        system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities:
- Design and execute test cases
- Identify edge cases and potential failures
- Verify functionality and correctness
- Test error handling and edge conditions
- Ensure code meets quality standards
Focus on thoroughness, coverage, and issue identification.''',
        allowed_tools={
            'read_file', 'write_file', 'python_exec', 'run_command',
            'list_directory', 'db_query'
        },
        specialization_areas=['test_design', 'quality_assurance', 'validation'],
        temperature=0.4
    ),

    'documentation': AgentRole(
        name='documentation',
        description='Specialized in creating and maintaining documentation',
        system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities:
- Write clear, comprehensive documentation
- Create API references and user guides
- Document code with comments and docstrings
- Organize and structure information logically
- Ensure documentation is up-to-date and accurate
Focus on clarity, completeness, and user-friendliness.''',
        allowed_tools={
            'read_file', 'write_file', 'list_directory', 'index_directory',
            'http_fetch', 'web_search'
        },
        specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'],
        temperature=0.6
    ),

    'orchestrator': AgentRole(
        name='orchestrator',
        description='Coordinates multiple agents and manages overall execution',
        system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities:
- Coordinate multiple specialized agents
- Delegate tasks to appropriate agents
- Integrate results from different agents
- Manage overall workflow execution
- Ensure task completion and quality
Focus on effective delegation, integration, and overall success.''',
        allowed_tools={
            'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query'
        },
        specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'],
        temperature=0.5
    ),

    'general': AgentRole(
        name='general',
        description='General purpose agent for miscellaneous tasks',
        system_prompt='''You are a general purpose AI assistant. Your responsibilities:
- Handle diverse tasks across multiple domains
- Provide balanced assistance for various needs
- Adapt to different types of requests
- Collaborate with specialized agents when needed
Focus on versatility, helpfulness, and task completion.''',
        allowed_tools={
            'read_file', 'write_file', 'list_directory', 'create_directory',
            'change_directory', 'get_current_directory', 'python_exec',
            'run_command', 'run_command_interactive', 'http_fetch',
            'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query',
            'index_directory'
        },
        specialization_areas=['general_assistance'],
        temperature=0.7
    )
}


def get_agent_role(role_name: str) -> AgentRole:
    return AGENT_ROLES.get(role_name, AGENT_ROLES['general'])


def list_agent_roles() -> Dict[str, AgentRole]:
    return AGENT_ROLES.copy()


def get_recommended_agent(task_description: str) -> str:
    task_lower = task_description.lower()

    code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize']
    research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate']
    data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process']
    planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate']
    testing_keywords = ['test', 'verify', 'validate', 'check', 'quality']
    doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual']

    if any(keyword in task_lower for keyword in code_keywords):
        return 'coding'
    elif any(keyword in task_lower for keyword in research_keywords):
        return 'research'
    elif any(keyword in task_lower for keyword in data_keywords):
        return 'data_analysis'
    elif any(keyword in task_lower for keyword in planning_keywords):
        return 'planning'
    elif any(keyword in task_lower for keyword in testing_keywords):
        return 'testing'
    elif any(keyword in task_lower for keyword in doc_keywords):
        return 'documentation'
    else:
        return 'general'
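
For orientation, a minimal usage sketch of the role registry above (illustrative, not part of the commit; the output comments follow the keyword lists in get_recommended_agent):

    from pr.agents.agent_roles import get_agent_role, get_recommended_agent

    task = "Debug the failing unit test in the parser module"
    role_name = get_recommended_agent(task)  # 'coding': 'debug' matches the code keywords first
    role = get_agent_role(role_name)

    print(role.name, role.temperature)           # coding 0.3
    print('python_exec' in role.allowed_tools)   # True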
4
pr/autonomous/__init__.py
Normal file
@ -0,0 +1,4 @@
from pr.autonomous.detection import is_task_complete
from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous

__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous']
42
pr/autonomous/detection.py
Normal file
@ -0,0 +1,42 @@
from pr.config import MAX_AUTONOMOUS_ITERATIONS
from pr.ui import Colors


def is_task_complete(response, iteration):
    if 'error' in response:
        return True

    if 'choices' not in response or not response['choices']:
        return True

    message = response['choices'][0]['message']
    # content may be None when the model only returns tool calls
    content = (message.get('content') or '').lower()

    completion_keywords = [
        'task complete', 'task is complete', 'finished', 'done',
        'successfully completed', 'task accomplished', 'all done',
        'implementation complete', 'setup complete', 'installation complete'
    ]

    error_keywords = [
        'cannot proceed', 'unable to continue', 'fatal error',
        'cannot complete', 'impossible to'
    ]

    has_tool_calls = 'tool_calls' in message and message['tool_calls']
    mentions_completion = any(keyword in content for keyword in completion_keywords)
    mentions_error = any(keyword in content for keyword in error_keywords)

    if mentions_error:
        return True

    if mentions_completion and not has_tool_calls:
        return True

    if iteration > 5 and not has_tool_calls:
        return True

    if iteration >= MAX_AUTONOMOUS_ITERATIONS:
        print(f"{Colors.YELLOW}⚠ Maximum iterations reached{Colors.RESET}")
        return True

    return False
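
A short illustration of the completion heuristic with hypothetical OpenAI-style response dicts (assuming the iteration counts stay below the configured limits):

    from pr.autonomous.detection import is_task_complete

    # No pending tool calls and a completion phrase: the loop stops.
    response = {'choices': [{'message': {
        'role': 'assistant',
        'content': 'Task complete: the scraper now runs end to end.'}}]}
    print(is_task_complete(response, iteration=2))   # True

    # Pending tool calls and no completion phrase: the loop continues.
    response['choices'][0]['message']['tool_calls'] = [{'id': 'call_1'}]
    response['choices'][0]['message']['content'] = 'Running the tests next.'
    print(is_task_complete(response, iteration=2))   # False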
200
pr/autonomous/mode.py
Normal file
@ -0,0 +1,200 @@
import time
import json
import logging
from pr.ui import Colors, display_tool_call, print_autonomous_header
from pr.autonomous.detection import is_task_complete
from pr.core.context import truncate_tool_result

logger = logging.getLogger('pr')


def run_autonomous_mode(assistant, task):
    assistant.autonomous_mode = True
    assistant.autonomous_iterations = 0

    logger.debug("=== AUTONOMOUS MODE START ===")
    logger.debug(f"Task: {task}")

    if assistant.verbose:
        print_autonomous_header(task)

    assistant.messages.append({
        "role": "user",
        "content": f"AUTONOMOUS TASK: {task}\n\nPlease work on this task step by step. Use tools as needed. When the task is fully complete, clearly state 'Task complete'."
    })

    try:
        while True:
            assistant.autonomous_iterations += 1

            logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
            logger.debug(f"Messages before context management: {len(assistant.messages)}")

            if assistant.verbose:
                print(f"\n{Colors.BOLD}{Colors.MAGENTA}{'▶' * 3} Iteration {assistant.autonomous_iterations} {'◀' * 3}{Colors.RESET}\n")

            from pr.core.context import manage_context_window
            assistant.messages = manage_context_window(assistant.messages, assistant.verbose)

            logger.debug(f"Messages after context management: {len(assistant.messages)}")

            if assistant.verbose:
                print(f"{Colors.GRAY}Calling API...{Colors.RESET}")

            from pr.core.api import call_api
            from pr.tools.base import get_tools_definition
            response = call_api(
                assistant.messages,
                assistant.model,
                assistant.api_url,
                assistant.api_key,
                assistant.use_tools,
                get_tools_definition(),
                verbose=assistant.verbose
            )

            if 'error' in response:
                logger.error(f"API error in autonomous mode: {response['error']}")
                print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}")
                break

            is_complete = is_task_complete(response, assistant.autonomous_iterations)
            logger.debug(f"Task completion check: {is_complete}")

            if is_complete:
                result = process_response_autonomous(assistant, response)
                print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")

                logger.debug("=== AUTONOMOUS MODE COMPLETE ===")
                logger.debug(f"Total iterations: {assistant.autonomous_iterations}")
                logger.debug(f"Final message count: {len(assistant.messages)}")

                print(f"{Colors.BOLD}Total Iterations:{Colors.RESET} {assistant.autonomous_iterations}")
                print(f"{Colors.BOLD}Messages in Context:{Colors.RESET} {len(assistant.messages)}\n")
                break

            result = process_response_autonomous(assistant, response)

            if result:
                print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")

            time.sleep(0.5)

    except KeyboardInterrupt:
        logger.debug("Autonomous mode interrupted by user")
        print(f"\n{Colors.YELLOW}Autonomous mode interrupted by user{Colors.RESET}")
    finally:
        assistant.autonomous_mode = False
        logger.debug("=== AUTONOMOUS MODE END ===")


def process_response_autonomous(assistant, response):
    if 'error' in response:
        return f"Error: {response['error']}"

    if 'choices' not in response or not response['choices']:
        return "No response from API"

    message = response['choices'][0]['message']
    assistant.messages.append(message)

    if 'tool_calls' in message and message['tool_calls']:
        print(f"{Colors.BOLD}{Colors.CYAN}🔧 Executing {len(message['tool_calls'])} tool(s)...{Colors.RESET}\n")

        tool_results = []

        for tool_call in message['tool_calls']:
            func_name = tool_call['function']['name']
            arguments = json.loads(tool_call['function']['arguments'])

            display_tool_call(func_name, arguments, "running")

            result = execute_single_tool(assistant, func_name, arguments)
            result = truncate_tool_result(result)

            status = "success" if result.get("status") == "success" else "error"
            display_tool_call(func_name, arguments, status, result)

            tool_results.append({
                "tool_call_id": tool_call['id'],
                "role": "tool",
                "content": json.dumps(result)
            })

        for result in tool_results:
            assistant.messages.append(result)

        print(f"{Colors.GRAY}Processing tool results...{Colors.RESET}\n")
        from pr.core.api import call_api
        from pr.tools.base import get_tools_definition
        follow_up = call_api(
            assistant.messages,
            assistant.model,
            assistant.api_url,
            assistant.api_key,
            assistant.use_tools,
            get_tools_definition(),
            verbose=assistant.verbose
        )
        return process_response_autonomous(assistant, follow_up)

    content = message.get('content', '')
    from pr.ui import render_markdown
    return render_markdown(content, assistant.syntax_highlighting)


def execute_single_tool(assistant, func_name, arguments):
    logger.debug(f"Executing tool in autonomous mode: {func_name}")
    logger.debug(f"Tool arguments: {arguments}")

    from pr.tools import (
        http_fetch, run_command, run_command_interactive, read_file, write_file,
        list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
        web_search, web_search_news, python_exec, index_source_directory,
        search_replace, open_editor, editor_insert_text, editor_replace_text,
        editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process
    )
    from pr.tools.patch import display_file_diff
    from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker

    func_map = {
        'http_fetch': lambda **kw: http_fetch(**kw),
        'run_command': lambda **kw: run_command(**kw),
        'tail_process': lambda **kw: tail_process(**kw),
        'kill_process': lambda **kw: kill_process(**kw),
        'run_command_interactive': lambda **kw: run_command_interactive(**kw),
        'read_file': lambda **kw: read_file(**kw),
        'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
        'list_directory': lambda **kw: list_directory(**kw),
        'mkdir': lambda **kw: mkdir(**kw),
        'chdir': lambda **kw: chdir(**kw),
        'getpwd': lambda **kw: getpwd(**kw),
        'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
        'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
        'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
        'web_search': lambda **kw: web_search(**kw),
        'web_search_news': lambda **kw: web_search_news(**kw),
        'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
        'index_source_directory': lambda **kw: index_source_directory(**kw),
        'search_replace': lambda **kw: search_replace(**kw),
        'open_editor': lambda **kw: open_editor(**kw),
        'editor_insert_text': lambda **kw: editor_insert_text(**kw),
        'editor_replace_text': lambda **kw: editor_replace_text(**kw),
        'editor_search': lambda **kw: editor_search(**kw),
        'close_editor': lambda **kw: close_editor(**kw),
        'create_diff': lambda **kw: create_diff(**kw),
        'apply_patch': lambda **kw: apply_patch(**kw),
        'display_file_diff': lambda **kw: display_file_diff(**kw),
        'display_edit_summary': lambda **kw: display_edit_summary(),
        'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
        'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
    }

    if func_name in func_map:
        try:
            result = func_map[func_name](**arguments)
            logger.debug(f"Tool execution result: {str(result)[:200]}...")
            return result
        except Exception as e:
            logger.error(f"Tool execution error: {str(e)}")
            return {"status": "error", "error": str(e)}
    else:
        logger.error(f"Unknown function requested: {func_name}")
        return {"status": "error", "error": f"Unknown function: {func_name}"}
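
For reference, the message shapes that process_response_autonomous appends each iteration look roughly like this (values illustrative; the tool turn mirrors what the loop above builds from each tool_call):

    import json

    assistant_turn = {'role': 'assistant', 'tool_calls': [{
        'id': 'call_1',
        'function': {'name': 'read_file',
                     'arguments': json.dumps({'filepath': 'setup.py'})}}]}
    tool_turn = {'tool_call_id': 'call_1', 'role': 'tool',
                 'content': json.dumps({'status': 'success', 'content': '...'})}

The recursion bottoms out when a follow-up response carries plain content and no tool_calls.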
4
pr/cache/__init__.py
vendored
Normal file
@ -0,0 +1,4 @@
from .api_cache import APICache
from .tool_cache import ToolCache

__all__ = ['APICache', 'ToolCache']
127
pr/cache/api_cache.py
vendored
Normal file
@ -0,0 +1,127 @@
import hashlib
import json
import sqlite3
import time
from typing import Optional, Dict, Any


class APICache:
    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _initialize_cache(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS api_cache (
                cache_key TEXT PRIMARY KEY,
                response_data TEXT NOT NULL,
                created_at INTEGER NOT NULL,
                expires_at INTEGER NOT NULL,
                model TEXT,
                token_count INTEGER
            )
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
        ''')
        conn.commit()
        conn.close()

    def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str:
        cache_data = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]:
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        current_time = int(time.time())
        cursor.execute('''
            SELECT response_data FROM api_cache
            WHERE cache_key = ? AND expires_at > ?
        ''', (cache_key, current_time))

        row = cursor.fetchone()
        conn.close()

        if row:
            return json.loads(row[0])
        return None

    def set(self, model: str, messages: list, temperature: float, max_tokens: int,
            response: Dict[str, Any], token_count: int = 0):
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)

        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            INSERT OR REPLACE INTO api_cache
            (cache_key, response_data, created_at, expires_at, model, token_count)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count))

        conn.commit()
        conn.close()

    def clear_expired(self):
        current_time = int(time.time())

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,))
        deleted_count = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted_count

    def clear_all(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM api_cache')
        deleted_count = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted_count

    def get_statistics(self) -> Dict[str, Any]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT COUNT(*) FROM api_cache')
        total_entries = cursor.fetchone()[0]

        current_time = int(time.time())
        cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,))
        valid_entries = cursor.fetchone()[0]

        cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,))
        total_tokens = cursor.fetchone()[0] or 0

        conn.close()

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cached_tokens': total_tokens
        }
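
A usage sketch for APICache (hypothetical path and payload; the response dict stands in for a real completion):

    import os
    import tempfile
    from pr.cache.api_cache import APICache

    cache = APICache(os.path.join(tempfile.gettempdir(), 'api_cache.sqlite'), ttl_seconds=60)
    messages = [{'role': 'user', 'content': 'hello'}]

    if cache.get('x-ai/grok-code-fast-1', messages, 0.1, 4096) is None:
        response = {'choices': [{'message': {'role': 'assistant', 'content': 'hi'}}]}
        cache.set('x-ai/grok-code-fast-1', messages, 0.1, 4096, response, token_count=12)

    print(cache.get_statistics())  # e.g. {'total_entries': 1, 'valid_entries': 1, ...}

Note the key covers model, messages, temperature, and max_tokens, so any change to the conversation yields a fresh entry.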
179
pr/cache/tool_cache.py
vendored
Normal file
@ -0,0 +1,179 @@
import hashlib
import json
import sqlite3
import time
from typing import Optional, Any, Set


class ToolCache:
    DETERMINISTIC_TOOLS: Set[str] = {
        'read_file',
        'list_directory',
        'get_current_directory',
        'db_get',
        'db_query',
        'index_directory'
    }

    def __init__(self, db_path: str, ttl_seconds: int = 300):
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _initialize_cache(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tool_cache (
                cache_key TEXT PRIMARY KEY,
                tool_name TEXT NOT NULL,
                result_data TEXT NOT NULL,
                created_at INTEGER NOT NULL,
                expires_at INTEGER NOT NULL,
                hit_count INTEGER DEFAULT 0
            )
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
        ''')
        conn.commit()
        conn.close()

    def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
        cache_data = {
            'tool': tool_name,
            'args': arguments
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def is_cacheable(self, tool_name: str) -> bool:
        return tool_name in self.DETERMINISTIC_TOOLS

    def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
        if not self.is_cacheable(tool_name):
            return None

        cache_key = self._generate_cache_key(tool_name, arguments)

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        current_time = int(time.time())
        cursor.execute('''
            SELECT result_data, hit_count FROM tool_cache
            WHERE cache_key = ? AND expires_at > ?
        ''', (cache_key, current_time))

        row = cursor.fetchone()

        if row:
            cursor.execute('''
                UPDATE tool_cache SET hit_count = hit_count + 1
                WHERE cache_key = ?
            ''', (cache_key,))
            conn.commit()
            conn.close()
            return json.loads(row[0])

        conn.close()
        return None

    def set(self, tool_name: str, arguments: dict, result: Any):
        if not self.is_cacheable(tool_name):
            return

        cache_key = self._generate_cache_key(tool_name, arguments)

        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            INSERT OR REPLACE INTO tool_cache
            (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
            VALUES (?, ?, ?, ?, ?, 0)
        ''', (cache_key, tool_name, json.dumps(result), current_time, expires_at))

        conn.commit()
        conn.close()

    def invalidate_tool(self, tool_name: str):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,))
        deleted_count = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted_count

    def clear_expired(self):
        current_time = int(time.time())

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,))
        deleted_count = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted_count

    def clear_all(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM tool_cache')
        deleted_count = cursor.rowcount

        conn.commit()
        conn.close()

        return deleted_count

    def get_statistics(self) -> dict:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT COUNT(*) FROM tool_cache')
        total_entries = cursor.fetchone()[0]

        current_time = int(time.time())
        cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,))
        valid_entries = cursor.fetchone()[0]

        cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,))
        total_hits = cursor.fetchone()[0] or 0

        cursor.execute('''
            SELECT tool_name, COUNT(*), SUM(hit_count)
            FROM tool_cache
            WHERE expires_at > ?
            GROUP BY tool_name
        ''', (current_time,))

        tool_stats = {}
        for row in cursor.fetchall():
            tool_stats[row[0]] = {
                'cached_entries': row[1],
                'total_hits': row[2] or 0
            }

        conn.close()

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cache_hits': total_hits,
            'by_tool': tool_stats
        }
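
And the tool-side counterpart (hypothetical path and result; only the whitelisted deterministic tools are ever cached):

    import os
    import tempfile
    from pr.cache.tool_cache import ToolCache

    cache = ToolCache(os.path.join(tempfile.gettempdir(), 'tool_cache.sqlite'))

    print(cache.is_cacheable('read_file'))    # True
    print(cache.is_cacheable('run_command'))  # False: side effects, never cached

    args = {'filepath': 'README.md'}
    if cache.get('read_file', args) is None:
        cache.set('read_file', args, {'status': 'success', 'content': '...'})

    print(cache.get('read_file', args)['status'])  # 'success'; hit_count is incremented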
3
pr/commands/__init__.py
Normal file
@ -0,0 +1,3 @@
from pr.commands.handlers import handle_command

__all__ = ['handle_command']
391
pr/commands/handlers.py
Normal file
@ -0,0 +1,391 @@
import json
from pr.ui import Colors
from pr.tools import read_file
from pr.tools.base import get_tools_definition
from pr.core.api import list_models
from pr.autonomous import run_autonomous_mode


def handle_command(assistant, command):
    command_parts = command.strip().split(maxsplit=1)
    cmd = command_parts[0].lower()

    if cmd == '/auto':
        if len(command_parts) < 2:
            print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")
            print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}")
            return True

        task = command_parts[1]
        run_autonomous_mode(assistant, task)
        return True

    if cmd in ['exit', 'quit', 'q']:
        return False

    elif cmd in ['help', '/help']:
        print(f"""
{Colors.BOLD}Available Commands:{Colors.RESET}

{Colors.BOLD}Basic:{Colors.RESET}
  exit, quit, q - Exit the assistant
  /help - Show this help message
  /reset - Clear message history
  /dump - Show message history as JSON
  /verbose - Toggle verbose mode
  /models - List available models
  /tools - List available tools

{Colors.BOLD}File Operations:{Colors.RESET}
  /review <file> - Review a file
  /refactor <file> - Refactor code in a file
  /obfuscate <file> - Obfuscate code in a file

{Colors.BOLD}Advanced Features:{Colors.RESET}
  {Colors.CYAN}/auto <task>{Colors.RESET} - Enter autonomous mode
  {Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a workflow
  {Colors.CYAN}/workflows{Colors.RESET} - List all workflows
  {Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create specialized agent and assign task
  {Colors.CYAN}/agents{Colors.RESET} - Show active agents
  {Colors.CYAN}/collaborate <task>{Colors.RESET} - Use multiple agents to collaborate
  {Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
  {Colors.CYAN}/remember <content>{Colors.RESET} - Store information in knowledge base
  {Colors.CYAN}/history{Colors.RESET} - Show conversation history
  {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
  {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
  {Colors.CYAN}/stats{Colors.RESET} - Show system statistics
""")

    elif cmd == '/reset':
        assistant.messages = assistant.messages[:1]
        print(f"{Colors.GREEN}Message history cleared{Colors.RESET}")

    elif cmd == '/dump':
        print(json.dumps(assistant.messages, indent=2))

    elif cmd == '/verbose':
        assistant.verbose = not assistant.verbose
        print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}")

    elif cmd == '/model':
        # Exact match: a startswith() check here would also swallow /models.
        if len(command_parts) < 2:
            print("Current model: " + Colors.GREEN + assistant.model + Colors.RESET)
        else:
            assistant.model = command_parts[1]
            print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}")

    elif cmd == '/models':
        models = list_models(assistant.model_list_url, assistant.api_key)
        if isinstance(models, dict) and 'error' in models:
            print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}")
        else:
            print(f"{Colors.BOLD}Available Models:{Colors.RESET}")
            for model in models:
                print(f"  • {Colors.CYAN}{model['id']}{Colors.RESET}")

    elif cmd == '/tools':
        print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
        for tool in get_tools_definition():
            func = tool['function']
            print(f"  • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")

    elif cmd == '/review' and len(command_parts) > 1:
        filename = command_parts[1]
        review_file(assistant, filename)

    elif cmd == '/refactor' and len(command_parts) > 1:
        filename = command_parts[1]
        refactor_file(assistant, filename)

    elif cmd == '/obfuscate' and len(command_parts) > 1:
        filename = command_parts[1]
        obfuscate_file(assistant, filename)

    elif cmd == '/workflows':
        show_workflows(assistant)

    elif cmd == '/workflow' and len(command_parts) > 1:
        workflow_name = command_parts[1]
        execute_workflow_command(assistant, workflow_name)

    elif cmd == '/agent' and len(command_parts) > 1:
        args = command_parts[1].split(maxsplit=1)
        if len(args) < 2:
            print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}")
            print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}")
        else:
            role, task = args[0], args[1]
            execute_agent_task(assistant, role, task)

    elif cmd == '/agents':
        show_agents(assistant)

    elif cmd == '/collaborate' and len(command_parts) > 1:
        task = command_parts[1]
        collaborate_agents_command(assistant, task)

    elif cmd == '/knowledge' and len(command_parts) > 1:
        query = command_parts[1]
        search_knowledge(assistant, query)

    elif cmd == '/remember' and len(command_parts) > 1:
        content = command_parts[1]
        store_knowledge(assistant, content)

    elif cmd == '/history':
        show_conversation_history(assistant)

    elif cmd == '/cache':
        if len(command_parts) > 1 and command_parts[1].lower() == 'clear':
            clear_caches(assistant)
        else:
            show_cache_stats(assistant)

    elif cmd == '/stats':
        show_system_stats(assistant)

    else:
        return None

    return True


def review_file(assistant, filename):
    result = read_file(filename)
    if result['status'] == 'success':
        message = f"Please review this file and provide feedback:\n\n{result['content']}"
        from pr.core.assistant import process_message
        process_message(assistant, message)
    else:
        print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")


def refactor_file(assistant, filename):
    result = read_file(filename)
    if result['status'] == 'success':
        message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
        from pr.core.assistant import process_message
        process_message(assistant, message)
    else:
        print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")


def obfuscate_file(assistant, filename):
    result = read_file(filename)
    if result['status'] == 'success':
        message = f"Please obfuscate this code:\n\n{result['content']}"
        from pr.core.assistant import process_message
        process_message(assistant, message)
    else:
        print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")


def show_workflows(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    workflows = assistant.enhanced.get_workflow_list()
    if not workflows:
        print(f"{Colors.YELLOW}No workflows found{Colors.RESET}")
        return

    print(f"\n{Colors.BOLD}Available Workflows:{Colors.RESET}")
    for wf in workflows:
        print(f"  • {Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}")
        print(f"    Executions: {wf['execution_count']}")


def execute_workflow_command(assistant, workflow_name):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}")
    result = assistant.enhanced.execute_workflow(workflow_name)

    if 'error' in result:
        print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
    else:
        print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}")
        print(f"Execution ID: {result['execution_id']}")
        print(f"Results: {json.dumps(result['results'], indent=2)}")


def execute_agent_task(assistant, role, task):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    print(f"{Colors.YELLOW}Creating {role} agent...{Colors.RESET}")
    agent_id = assistant.enhanced.create_agent(role)
    print(f"{Colors.GREEN}Agent created: {agent_id}{Colors.RESET}")

    print(f"{Colors.YELLOW}Executing task...{Colors.RESET}")
    result = assistant.enhanced.agent_task(agent_id, task)

    if 'error' in result:
        print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
    else:
        print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}")
        print(result['response'])


def show_agents(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    summary = assistant.enhanced.get_agent_summary()
    print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}")
    print(f"Active agents: {summary['active_agents']}")

    if summary['agents']:
        for agent in summary['agents']:
            print(f"\n  • {Colors.CYAN}{agent['agent_id']}{Colors.RESET}")
            print(f"    Role: {agent['role']}")
            print(f"    Tasks completed: {agent['task_count']}")
            print(f"    Messages: {agent['message_count']}")


def collaborate_agents_command(assistant, task):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}")
    roles = ['coding', 'research', 'planning']

    result = assistant.enhanced.collaborate_agents(task, roles)

    print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}")
    print("\nOrchestrator response:")
    if 'orchestrator' in result and 'response' in result['orchestrator']:
        print(result['orchestrator']['response'])

    if result.get('agents'):
        print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}")
        for agent_result in result['agents']:
            if 'role' in agent_result:
                print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}")
                print(agent_result.get('response', 'No response'))


def search_knowledge(assistant, query):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    results = assistant.enhanced.search_knowledge(query)

    if not results:
        print(f"{Colors.YELLOW}No knowledge entries found for: {query}{Colors.RESET}")
        return

    print(f"\n{Colors.BOLD}Knowledge Search Results:{Colors.RESET}")
    for entry in results:
        print(f"\n  • {Colors.CYAN}[{entry.category}]{Colors.RESET}")
        print(f"    {entry.content[:200]}...")
        print(f"    Accessed: {entry.access_count} times")


def store_knowledge(assistant, content):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    import uuid
    import time
    from pr.memory import KnowledgeEntry

    categories = assistant.enhanced.fact_extractor.categorize_content(content)
    entry_id = str(uuid.uuid4())[:16]

    entry = KnowledgeEntry(
        entry_id=entry_id,
        category=categories[0] if categories else 'general',
        content=content,
        metadata={'manual_entry': True},
        created_at=time.time(),
        updated_at=time.time()
    )

    assistant.enhanced.knowledge_store.add_entry(entry)
    print(f"{Colors.GREEN}Knowledge stored successfully{Colors.RESET}")
    print(f"Entry ID: {entry_id}")
    print(f"Category: {entry.category}")


def show_conversation_history(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    history = assistant.enhanced.get_conversation_history(limit=10)

    if not history:
        print(f"{Colors.YELLOW}No conversation history found{Colors.RESET}")
        return

    import datetime

    print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}")
    for conv in history:
        started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M')
        print(f"\n  • {Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
        print(f"    Started: {started}")
        print(f"    Messages: {conv['message_count']}")
        if conv.get('summary'):
            print(f"    Summary: {conv['summary'][:100]}...")
        if conv.get('topics'):
            print(f"    Topics: {', '.join(conv['topics'])}")


def show_cache_stats(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    stats = assistant.enhanced.get_cache_statistics()

    print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}")

    if 'api_cache' in stats:
        api_stats = stats['api_cache']
        print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}")
        print(f"  Total entries: {api_stats['total_entries']}")
        print(f"  Valid entries: {api_stats['valid_entries']}")
        print(f"  Expired entries: {api_stats['expired_entries']}")
        print(f"  Cached tokens: {api_stats['total_cached_tokens']}")

    if 'tool_cache' in stats:
        tool_stats = stats['tool_cache']
        print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}")
        print(f"  Total entries: {tool_stats['total_entries']}")
        print(f"  Valid entries: {tool_stats['valid_entries']}")
        print(f"  Total cache hits: {tool_stats['total_cache_hits']}")

        if tool_stats.get('by_tool'):
            print("\n  Per-tool statistics:")
            for tool_name, tool_stat in tool_stats['by_tool'].items():
                print(f"    {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits")


def clear_caches(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    assistant.enhanced.clear_caches()
    print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}")


def show_system_stats(assistant):
    if not hasattr(assistant, 'enhanced'):
        print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
        return

    print(f"\n{Colors.BOLD}System Statistics:{Colors.RESET}")

    cache_stats = assistant.enhanced.get_cache_statistics()
    knowledge_stats = assistant.enhanced.get_knowledge_statistics()
    agent_summary = assistant.enhanced.get_agent_summary()

    print(f"\n{Colors.CYAN}Knowledge Base:{Colors.RESET}")
    print(f"  Total entries: {knowledge_stats['total_entries']}")
    print(f"  Categories: {knowledge_stats['total_categories']}")
    print(f"  Total accesses: {knowledge_stats['total_accesses']}")
    print(f"  Vocabulary size: {knowledge_stats['vocabulary_size']}")

    print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}")
    print(f"  Count: {agent_summary['active_agents']}")

    if 'api_cache' in cache_stats:
        print(f"\n{Colors.CYAN}Caching:{Colors.RESET}")
        print(f"  API cache entries: {cache_stats['api_cache']['valid_entries']}")
    if 'tool_cache' in cache_stats:
        print(f"  Tool cache entries: {cache_stats['tool_cache']['valid_entries']}")
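
handle_command's return contract: False means exit, True means the input was consumed, and None means "not a command", so the caller should forward the line to the model. A hypothetical REPL built on that contract (process_message is the pr.core.assistant helper already imported above):

    from pr.commands import handle_command
    from pr.core.assistant import process_message

    def repl(assistant):
        while True:
            line = input('> ')
            handled = handle_command(assistant, line)
            if handled is False:    # exit, quit, q
                break
            if handled is None:     # plain chat input
                process_message(assistant, line)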
61
pr/config.py
Normal file
@ -0,0 +1,61 @@
import os

DEFAULT_MODEL = "x-ai/grok-code-fast-1"
DEFAULT_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_LIST_URL = "https://openrouter.ai/api/v1/models"

DB_PATH = os.path.expanduser("~/.assistant_db.sqlite")
LOG_FILE = os.path.expanduser("~/.assistant_error.log")
CONTEXT_FILE = ".rcontext.txt"
GLOBAL_CONTEXT_FILE = os.path.expanduser("~/.rcontext.txt")
HISTORY_FILE = os.path.expanduser("~/.assistant_history")

DEFAULT_TEMPERATURE = 0.1
DEFAULT_MAX_TOKENS = 4096

MAX_AUTONOMOUS_ITERATIONS = 50
CONTEXT_COMPRESSION_THRESHOLD = 15
RECENT_MESSAGES_TO_KEEP = 20

API_TOTAL_TOKEN_LIMIT = 256000
MAX_OUTPUT_TOKENS = 30000
SAFETY_BUFFER_TOKENS = 30000
MAX_TOKENS_LIMIT = API_TOTAL_TOKEN_LIMIT - MAX_OUTPUT_TOKENS - SAFETY_BUFFER_TOKENS
CHARS_PER_TOKEN = 2.0
EMERGENCY_MESSAGES_TO_KEEP = 3
CONTENT_TRIM_LENGTH = 30000
MAX_TOOL_RESULT_LENGTH = 30000

LANGUAGE_KEYWORDS = {
    'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while',
               'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield',
               'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'],
    'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while',
                   'return', 'try', 'catch', 'finally', 'class', 'extends', 'new',
                   'this', 'null', 'undefined', 'true', 'false'],
    'java': ['public', 'private', 'protected', 'class', 'interface', 'extends',
             'implements', 'static', 'final', 'void', 'int', 'String', 'boolean',
             'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'],
}

CACHE_ENABLED = True
API_CACHE_TTL = 3600
TOOL_CACHE_TTL = 300

WORKFLOW_MAX_RETRIES = 3
WORKFLOW_DEFAULT_TIMEOUT = 300
WORKFLOW_EXECUTOR_MAX_WORKERS = 5

AGENT_DEFAULT_TEMPERATURE = 0.7
AGENT_MAX_WORKERS = 3
AGENT_SESSION_TIMEOUT = 7200

KNOWLEDGE_IMPORTANCE_THRESHOLD = 0.5
KNOWLEDGE_SEARCH_LIMIT = 5
MEMORY_AUTO_SUMMARIZE = True
CONVERSATION_SUMMARY_THRESHOLD = 20

ADVANCED_CONTEXT_ENABLED = True
CONTEXT_RELEVANCE_THRESHOLD = 0.3
ADAPTIVE_CONTEXT_MIN = 10
ADAPTIVE_CONTEXT_MAX = 50
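
With the values above, the input budget works out to MAX_TOKENS_LIMIT = 256000 - 30000 - 30000 = 196000 tokens; at CHARS_PER_TOKEN = 2.0 that is roughly 392,000 characters of conversation before trimming has to kick in.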
5
pr/core/__init__.py
Normal file
@ -0,0 +1,5 @@
from pr.core.assistant import Assistant
from pr.core.api import call_api, list_models
from pr.core.context import init_system_message, manage_context_window

__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window']
82
pr/core/advanced_context.py
Normal file
@ -0,0 +1,82 @@
import re
from typing import List, Dict, Any


class AdvancedContextManager:
    def __init__(self, knowledge_store=None, conversation_memory=None):
        self.knowledge_store = knowledge_store
        self.conversation_memory = conversation_memory

    def adaptive_context_window(self, messages: List[Dict[str, Any]],
                                task_complexity: str = 'medium') -> int:
        complexity_thresholds = {
            'simple': 10,
            'medium': 20,
            'complex': 35,
            'very_complex': 50
        }

        base_threshold = complexity_thresholds.get(task_complexity, 20)

        message_complexity_score = self._analyze_message_complexity(messages)

        if message_complexity_score > 0.7:
            adjusted = int(base_threshold * 1.5)
        elif message_complexity_score < 0.3:
            adjusted = int(base_threshold * 0.7)
        else:
            adjusted = base_threshold

        # Return the adjusted value directly: flooring it at base_threshold
        # with max() would make the downward branch above unreachable.
        return adjusted

    def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
        total_length = sum(len(msg.get('content', '')) for msg in messages)
        avg_length = total_length / len(messages) if messages else 0

        unique_words = set()
        for msg in messages:
            content = msg.get('content', '')
            words = re.findall(r'\b\w+\b', content.lower())
            unique_words.update(words)

        vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0

        # Simple complexity score based on length and richness
        complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
        return complexity

    def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if not sentences:
            return []

        # Simple scoring based on length and position
        scored_sentences = []
        for i, sentence in enumerate(sentences):
            length_score = min(1.0, len(sentence) / 50)
            position_score = 1.0 if i == 0 else 0.8 if i < len(sentences) / 2 else 0.6
            score = (length_score + position_score) / 2
            scored_sentences.append((sentence, score))

        scored_sentences.sort(key=lambda x: x[1], reverse=True)
        return [s[0] for s in scored_sentences[:top_k]]

    def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str:
        all_content = ' '.join([msg.get('content', '') for msg in messages])
        key_sentences = self.extract_key_sentences(all_content, top_k=3)
        summary = ' '.join(key_sentences)
        return summary if summary else "No content to summarize."

    def score_message_relevance(self, message: Dict[str, Any], context: str) -> float:
        content = message.get('content', '')
        content_words = set(re.findall(r'\b\w+\b', content.lower()))
        context_words = set(re.findall(r'\b\w+\b', context.lower()))

        intersection = content_words & context_words
        union = content_words | context_words

        if not union:
            return 0.0

        return len(intersection) / len(union)
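
A quick sketch of the manager in isolation (illustrative messages; the relevance score is a plain Jaccard overlap in [0, 1]):

    from pr.core.advanced_context import AdvancedContextManager

    mgr = AdvancedContextManager()
    msgs = [
        {'role': 'user', 'content': 'Explain the caching layer.'},
        {'role': 'assistant', 'content': 'The tool cache stores deterministic results in SQLite.'},
    ]

    print(mgr.adaptive_context_window(msgs, task_complexity='complex'))  # 35 for this mid-complexity exchange
    print(mgr.score_message_relevance(msgs[1], 'sqlite cache'))          # 0.25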
95
pr/core/api.py
Normal file
@ -0,0 +1,95 @@
import json
import urllib.request
import urllib.error
import logging
from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
from pr.core.context import auto_slim_messages

logger = logging.getLogger('pr')


def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
    try:
        messages = auto_slim_messages(messages, verbose=verbose)

        logger.debug("=== API CALL START ===")
        logger.debug(f"Model: {model}")
        logger.debug(f"API URL: {api_url}")
        logger.debug(f"Use tools: {use_tools}")
        logger.debug(f"Message count: {len(messages)}")
        headers = {
            'Content-Type': 'application/json',
        }

        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        data = {
            'model': model,
            'messages': messages,
            'temperature': DEFAULT_TEMPERATURE,
            'max_tokens': DEFAULT_MAX_TOKENS
        }

        if "gpt-5" in model:
            del data['temperature']
            del data['max_tokens']
            logger.debug("GPT-5 detected: removed temperature and max_tokens")

        if use_tools:
            data['tools'] = tools_definition
            data['tool_choice'] = 'auto'
            logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")

        request_json = json.dumps(data)
        logger.debug(f"Request payload size: {len(request_json)} bytes")

        req = urllib.request.Request(
            api_url,
            data=request_json.encode('utf-8'),
            headers=headers,
            method='POST'
        )

        logger.debug("Sending HTTP request...")
        with urllib.request.urlopen(req) as response:
            response_data = response.read().decode('utf-8')
            logger.debug(f"Response received: {len(response_data)} bytes")
            result = json.loads(response_data)

        if 'usage' in result:
            logger.debug(f"Token usage: {result['usage']}")
        if 'choices' in result and result['choices']:
            choice = result['choices'][0]
            if 'message' in choice:
                msg = choice['message']
                logger.debug(f"Response role: {msg.get('role', 'N/A')}")
                if 'content' in msg and msg['content']:
                    logger.debug(f"Response content length: {len(msg['content'])} chars")
                if 'tool_calls' in msg:
                    logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")

        logger.debug("=== API CALL END ===")
        return result

    except urllib.error.HTTPError as e:
        error_body = e.read().decode('utf-8')
        logger.error(f"API HTTP Error: {e.code} - {error_body}")
        logger.debug("=== API CALL FAILED ===")
        return {"error": f"API Error: {e.code}", "message": error_body}
    except Exception as e:
        logger.error(f"API call failed: {e}")
        logger.debug("=== API CALL FAILED ===")
        return {"error": str(e)}


def list_models(model_list_url, api_key):
    try:
        req = urllib.request.Request(model_list_url)
        if api_key:
            req.add_header('Authorization', f'Bearer {api_key}')

        with urllib.request.urlopen(req) as response:
            data = json.loads(response.read().decode('utf-8'))

        return data.get('data', [])
    except Exception as e:
        return {"error": str(e)}
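
A direct invocation sketch (needs OPENROUTER_API_KEY and network access; tools disabled for brevity):

    import os
    from pr.core.api import call_api
    from pr.config import DEFAULT_MODEL, DEFAULT_API_URL

    messages = [{'role': 'user', 'content': 'Say hello in one word.'}]
    response = call_api(messages, DEFAULT_MODEL, DEFAULT_API_URL,
                        os.environ.get('OPENROUTER_API_KEY', ''),
                        use_tools=False, tools_definition=[])

    if 'error' not in response:
        print(response['choices'][0]['message']['content'])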
325
pr/core/assistant.py
Normal file
@ -0,0 +1,325 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import signal
|
||||
import logging
|
||||
import traceback
|
||||
import readline
|
||||
import glob as glob_module
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE
|
||||
from pr.ui import Colors, render_markdown
|
||||
from pr.core.context import init_system_message, truncate_tool_result
|
||||
from pr.core.api import call_api
|
||||
from pr.tools import (
|
||||
http_fetch, run_command, run_command_interactive, read_file, write_file,
|
||||
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
|
||||
web_search, web_search_news, python_exec, index_source_directory,
|
||||
open_editor, editor_insert_text, editor_replace_text, editor_search,
|
||||
search_replace,close_editor,create_diff,apply_patch,
|
||||
tail_process, kill_process
|
||||
)
|
||||
from pr.tools.patch import display_file_diff
|
||||
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
|
||||
from pr.tools.base import get_tools_definition
from pr.commands import handle_command

logger = logging.getLogger('pr')
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler(LOG_FILE)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

class Assistant:
    def __init__(self, args):
        self.args = args
        self.messages = []
        self.verbose = args.verbose
        self.debug = getattr(args, 'debug', False)
        self.syntax_highlighting = not args.no_syntax

        if self.debug:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.DEBUG)
            console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
            logger.addHandler(console_handler)
            logger.debug("Debug mode enabled")
        self.api_key = os.environ.get('OPENROUTER_API_KEY', '')
        self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL)
        self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL)
        self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL)
        self.use_tools = os.environ.get('USE_TOOLS', '1') == '1'
        self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1'
        self.interrupt_count = 0
        self.python_globals = {}
        self.db_conn = None
        self.autonomous_mode = False
        self.autonomous_iterations = 0
        self.init_database()
        self.messages.append(init_system_message(args))

        try:
            from pr.core.enhanced_assistant import EnhancedAssistant
            self.enhanced = EnhancedAssistant(self)
            if self.debug:
                logger.debug("Enhanced assistant features initialized")
        except Exception as e:
            logger.warning(f"Could not initialize enhanced features: {e}")
            self.enhanced = None

    def init_database(self):
        try:
            logger.debug(f"Initializing database at {DB_PATH}")
            self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
            cursor = self.db_conn.cursor()

            cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store
                              (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''')

            cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions
                              (id INTEGER PRIMARY KEY AUTOINCREMENT,
                               filepath TEXT, content TEXT, hash TEXT,
                               timestamp REAL, version INTEGER)''')

            self.db_conn.commit()
            logger.debug("Database initialized successfully")
        except Exception as e:
            logger.error(f"Database initialization error: {e}")
            self.db_conn = None

    def execute_tool_calls(self, tool_calls):
        results = []

        logger.debug(f"Executing {len(tool_calls)} tool call(s)")

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []

            for tool_call in tool_calls:
                func_name = tool_call['function']['name']
                arguments = json.loads(tool_call['function']['arguments'])
                logger.debug(f"Tool call: {func_name} with arguments: {arguments}")

                func_map = {
                    'http_fetch': lambda **kw: http_fetch(**kw),
                    'run_command': lambda **kw: run_command(**kw),
                    'tail_process': lambda **kw: tail_process(**kw),
                    'kill_process': lambda **kw: kill_process(**kw),
                    'run_command_interactive': lambda **kw: run_command_interactive(**kw),
                    'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn),
                    'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn),
                    'list_directory': lambda **kw: list_directory(**kw),
                    'mkdir': lambda **kw: mkdir(**kw),
                    'chdir': lambda **kw: chdir(**kw),
                    'getpwd': lambda **kw: getpwd(**kw),
                    'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn),
                    'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn),
                    'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn),
                    'web_search': lambda **kw: web_search(**kw),
                    'web_search_news': lambda **kw: web_search_news(**kw),
                    'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals),
                    'index_source_directory': lambda **kw: index_source_directory(**kw),
                    'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn),
                    'open_editor': lambda **kw: open_editor(**kw),
                    'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn),
                    'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn),
                    'editor_search': lambda **kw: editor_search(**kw),
                    'close_editor': lambda **kw: close_editor(**kw),
                    'create_diff': lambda **kw: create_diff(**kw),
                    'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
                    'display_file_diff': lambda **kw: display_file_diff(**kw),
                    'display_edit_summary': lambda **kw: display_edit_summary(),
                    'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
                    'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
                }
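                # Editorial note (not part of the original commit): the plain
                # pass-through entries above could be direct references such as
                # 'http_fetch': http_fetch; lambdas are only strictly needed
                # where extra context (db_conn, python_globals) is bound in.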

                if func_name in func_map:
                    future = executor.submit(func_map[func_name], **arguments)
                    futures.append((tool_call['id'], future))

            for tool_id, future in futures:
                try:
                    result = future.result(timeout=30)
                    result = truncate_tool_result(result)
                    logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...")
                    results.append({
                        "tool_call_id": tool_id,
                        "role": "tool",
                        "content": json.dumps(result)
                    })
                except Exception as e:
                    logger.debug(f"Tool error for {tool_id}: {str(e)}")
                    error_msg = str(e)[:200] if len(str(e)) > 200 else str(e)
                    results.append({
                        "tool_call_id": tool_id,
                        "role": "tool",
                        "content": json.dumps({"status": "error", "error": error_msg})
                    })

        return results

    def process_response(self, response):
        if 'error' in response:
            return f"Error: {response['error']}"

        if 'choices' not in response or not response['choices']:
            return "No response from API"

        message = response['choices'][0]['message']
        self.messages.append(message)

        if 'tool_calls' in message and message['tool_calls']:
            if self.verbose:
                print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")

            tool_results = self.execute_tool_calls(message['tool_calls'])

            for result in tool_results:
                self.messages.append(result)

            follow_up = call_api(
                self.messages, self.model, self.api_url, self.api_key,
                self.use_tools, get_tools_definition(), verbose=self.verbose
            )
            return self.process_response(follow_up)

        content = message.get('content', '')
        return render_markdown(content, self.syntax_highlighting)

    def signal_handler(self, signum, frame):
        if self.autonomous_mode:
            self.interrupt_count += 1
            if self.interrupt_count >= 2:
                print(f"\n{Colors.RED}Force exiting autonomous mode...{Colors.RESET}")
                self.autonomous_mode = False
                sys.exit(0)
            else:
                print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
            return

        self.interrupt_count += 1
        if self.interrupt_count >= 2:
            print(f"\n{Colors.RED}Exiting...{Colors.RESET}")
            self.cleanup()
            sys.exit(0)
        else:
            print(f"\n{Colors.YELLOW}Press Ctrl+C again to exit{Colors.RESET}")

    def setup_readline(self):
        try:
            readline.read_history_file(HISTORY_FILE)
        except FileNotFoundError:
            pass

        readline.set_history_length(1000)

        import atexit
        atexit.register(readline.write_history_file, HISTORY_FILE)

        commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose',
                    'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto']

        def completer(text, state):
            options = [cmd for cmd in commands if cmd.startswith(text)]

            glob_pattern = os.path.expanduser(text) + '*'
            path_options = glob_module.glob(glob_pattern)

            path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]

            combined_options = sorted(list(set(options + path_options)))

            if state < len(combined_options):
                return combined_options[state]

            return None

        delims = readline.get_completer_delims()
        readline.set_completer_delims(delims.replace('/', ''))

        readline.set_completer(completer)
        readline.parse_and_bind('tab: complete')

    def run_repl(self):
        self.setup_readline()
        signal.signal(signal.SIGINT, self.signal_handler)

        print(f"{Colors.BOLD}r{Colors.RESET}")
        print("Type 'help' for commands or start chatting")

        while True:
            try:
                user_input = input(f"{Colors.BLUE}You>{Colors.RESET} ").strip()

                if not user_input:
                    continue

                cmd_result = handle_command(self, user_input)

                if cmd_result is False:
                    break
                elif cmd_result is True:
                    continue

                process_message(self, user_input)

            except EOFError:
                break
            except KeyboardInterrupt:
                self.signal_handler(None, None)
            except Exception as e:
                print(f"{Colors.RED}Error: {e}{Colors.RESET}")
                logger.error(f"REPL error: {e}\n{traceback.format_exc()}")

    def run_single(self):
        if self.args.message:
            message = self.args.message
        else:
            message = sys.stdin.read()

        process_message(self, message)

    def cleanup(self):
        if hasattr(self, 'enhanced') and self.enhanced:
            try:
                self.enhanced.cleanup()
            except Exception as e:
                logger.error(f"Error cleaning up enhanced features: {e}")

        try:
            from pr.multiplexer import cleanup_all_multiplexers
            cleanup_all_multiplexers()
        except Exception as e:
            logger.error(f"Error cleaning up multiplexers: {e}")

        if self.db_conn:
            self.db_conn.close()

    def run(self):
        try:
            if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
                self.run_repl()
            else:
                self.run_single()
        finally:
            self.cleanup()

def process_message(assistant, message):
    assistant.messages.append({"role": "user", "content": message})

    logger.debug(f"Processing user message: {message[:100]}...")
    logger.debug(f"Current message count: {len(assistant.messages)}")

    if assistant.verbose:
        print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")

    response = call_api(
        assistant.messages, assistant.model, assistant.api_url,
        assistant.api_key, assistant.use_tools, get_tools_definition(),
        verbose=assistant.verbose
    )
    result = assistant.process_response(response)

    print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")
108
pr/core/config_loader.py
Normal file
@ -0,0 +1,108 @@
import os
import configparser
from typing import Dict, Any
from pr.core.logging import get_logger

logger = get_logger('config')

CONFIG_FILE = os.path.expanduser("~/.prrc")
LOCAL_CONFIG_FILE = ".prrc"


def load_config() -> Dict[str, Any]:
    config = {
        'api': {},
        'autonomous': {},
        'ui': {},
        'output': {},
        'session': {}
    }

    global_config = _load_config_file(CONFIG_FILE)
    local_config = _load_config_file(LOCAL_CONFIG_FILE)

    for section in config.keys():
        if section in global_config:
            config[section].update(global_config[section])
        if section in local_config:
            config[section].update(local_config[section])

    return config


def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
    if not os.path.exists(filepath):
        return {}

    try:
        parser = configparser.ConfigParser()
        parser.read(filepath)

        config = {}
        for section in parser.sections():
            config[section] = {}
            for key, value in parser.items(section):
                config[section][key] = _parse_value(value)

        logger.debug(f"Loaded configuration from {filepath}")
        return config

    except Exception as e:
        logger.error(f"Error loading config from {filepath}: {e}")
        return {}


def _parse_value(value: str) -> Any:
    value = value.strip()

    if value.lower() == 'true':
        return True
    if value.lower() == 'false':
        return False

    if value.isdigit():
        return int(value)

    try:
        return float(value)
    except ValueError:
        pass

    return value
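
# Example behaviour of _parse_value (editorial sketch, not in the original
# commit): "true" -> True, "42" -> 42, "0.75" -> 0.75, "hello" -> "hello".
# Negative numbers such as "-5" fail isdigit() and are returned via the
# float() branch as -5.0 rather than as an int.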


def create_default_config(filepath: str = CONFIG_FILE):
    default_config = """[api]
default_model = x-ai/grok-code-fast-1
timeout = 30
temperature = 0.7
max_tokens = 8096

[autonomous]
max_iterations = 50
context_threshold = 30
recent_messages_to_keep = 10

[ui]
syntax_highlighting = true
show_timestamps = false
color_output = true

[output]
format = text
verbose = false
quiet = false

[session]
auto_save = false
max_history = 1000
"""

    try:
        with open(filepath, 'w') as f:
            f.write(default_config)
        logger.info(f"Created default configuration at {filepath}")
        return True
    except Exception as e:
        logger.error(f"Error creating config file: {e}")
        return False
289
pr/core/context.py
Normal file
@ -0,0 +1,289 @@
import os
import json
import logging
from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD,
                       RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN,
                       EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH)
from pr.ui import Colors

def truncate_tool_result(result, max_length=None):
    if max_length is None:
        max_length = MAX_TOOL_RESULT_LENGTH

    if not isinstance(result, dict):
        return result

    result_copy = result.copy()

    if "output" in result_copy and isinstance(result_copy["output"], str):
        if len(result_copy["output"]) > max_length:
            result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]"

    if "content" in result_copy and isinstance(result_copy["content"], str):
        if len(result_copy["content"]) > max_length:
            result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]"

    if "data" in result_copy and isinstance(result_copy["data"], str):
        if len(result_copy["data"]) > max_length:
            result_copy["data"] = result_copy["data"][:max_length] + "\n... [truncated]"

    if "error" in result_copy and isinstance(result_copy["error"], str):
        if len(result_copy["error"]) > max_length // 2:
            result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]"

    return result_copy
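
# Editorial sketch (not in the original commit): given {"output": "x" * 50000}
# and a max_length of, say, 30000, truncate_tool_result returns
# {"output": "xx...x\n... [truncated 20000 chars]"}; the input dict is left
# untouched because the function works on a shallow copy.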

def init_system_message(args):
    context_parts = ["""You are a professional AI assistant with access to advanced tools.

File Operations:
- Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications
- Always close editor files when finished
- Use write_file for complete file rewrites, search_replace for simple text replacements

Process Management:
- run_command executes shell commands with a timeout (default 30s)
- If a command times out, you receive a PID in the response
- Use tail_process(pid) to monitor running processes
- Use kill_process(pid) to terminate processes
- Manage long-running commands effectively using these tools

Shell Commands:
- Be a shell ninja using native OS tools
- Prefer standard Unix utilities over complex scripts
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""]
    #context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
    max_context_size = 10000

    if args.include_env:
        env_context = "Environment Variables:\n"
        for key, value in os.environ.items():
            if not key.startswith('_'):
                env_context += f"{key}={value}\n"
        if len(env_context) > max_context_size:
            env_context = env_context[:max_context_size] + "\n... [truncated]"
        context_parts.append(env_context)

    for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]:
        if os.path.exists(context_file):
            try:
                with open(context_file, 'r') as f:
                    content = f.read()
                if len(content) > max_context_size:
                    content = content[:max_context_size] + "\n... [truncated]"
                context_parts.append(f"Context from {context_file}:\n{content}")
            except Exception as e:
                logging.error(f"Error reading context file {context_file}: {e}")

    if args.context:
        for ctx_file in args.context:
            try:
                with open(ctx_file, 'r') as f:
                    content = f.read()
                if len(content) > max_context_size:
                    content = content[:max_context_size] + "\n... [truncated]"
                context_parts.append(f"Context from {ctx_file}:\n{content}")
            except Exception as e:
                logging.error(f"Error reading context file {ctx_file}: {e}")

    system_message = "\n\n".join(context_parts)
    if len(system_message) > max_context_size * 3:
        system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]"

    return {"role": "system", "content": system_message}

def should_compress_context(messages):
    return len(messages) > CONTEXT_COMPRESSION_THRESHOLD

def compress_context(messages):
    return manage_context_window(messages, verbose=False)

def manage_context_window(messages, verbose):
    if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD:
        return messages

    if verbose:
        print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}")

    system_message = messages[0]
    recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:]
    middle_messages = messages[1:-RECENT_MESSAGES_TO_KEEP]

    if middle_messages:
        summary = summarize_messages(middle_messages)
        summary_message = {
            "role": "system",
            "content": f"[Previous conversation summary: {summary}]"
        }

        new_messages = [system_message, summary_message] + recent_messages

        if verbose:
            print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}")

        return new_messages

    return messages

def summarize_messages(messages):
    summary_parts = []

    for msg in messages:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")

        if role == "tool":
            continue

        if isinstance(content, str) and len(content) > 200:
            content = content[:200] + "..."

        summary_parts.append(f"{role}: {content}")

    return " | ".join(summary_parts[:10])

def estimate_tokens(messages):
    total_chars = 0

    for msg in messages:
        msg_json = json.dumps(msg)
        total_chars += len(msg_json)

    estimated_tokens = total_chars / CHARS_PER_TOKEN

    overhead_multiplier = 1.3

    return int(estimated_tokens * overhead_multiplier)
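
# Worked example (editorial note, not in the original commit): assuming
# CHARS_PER_TOKEN = 4 (the real value lives in pr.config), 40,000 characters
# of serialized messages estimate to 40000 / 4 * 1.3 = 13,000 tokens. The 1.3
# multiplier pads for chat framing overhead that raw character counting of
# the JSON underestimates.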

def trim_message_content(message, max_length):
    trimmed_msg = message.copy()

    if "content" in trimmed_msg:
        content = trimmed_msg["content"]

        if isinstance(content, str) and len(content) > max_length:
            trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
        elif isinstance(content, list):
            trimmed_content = []
            for item in content:
                if isinstance(item, dict):
                    trimmed_item = item.copy()
                    if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
                        trimmed_item["text"] = trimmed_item["text"][:max_length] + "\n... [trimmed]"
                    trimmed_content.append(trimmed_item)
                else:
                    trimmed_content.append(item)
            trimmed_msg["content"] = trimmed_content

    if trimmed_msg.get("role") == "tool":
        if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str):
            content = trimmed_msg["content"]
            if len(content) > MAX_TOOL_RESULT_LENGTH:
                trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"

            try:
                parsed = json.loads(content)
                if isinstance(parsed, dict):
                    if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2:
                        parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + "\n... [truncated]"
                    if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2:
                        parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + "\n... [truncated]"
                    trimmed_msg["content"] = json.dumps(parsed)
            except (json.JSONDecodeError, TypeError):
                pass

    return trimmed_msg

def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
    if estimate_tokens(messages) <= target_tokens:
        return messages

    system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
    start_idx = 1 if system_msg else 0

    recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
    middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []

    trimmed_middle = []
    for msg in middle_messages:
        if msg.get("role") == "tool":
            trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
        elif msg.get("role") in ["user", "assistant"]:
            trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
        else:
            trimmed_middle.append(msg)

    result = ([system_msg] if system_msg else []) + trimmed_middle + recent_messages

    if estimate_tokens(result) <= target_tokens:
        return result

    step_size = len(trimmed_middle) // 4 if len(trimmed_middle) >= 4 else 1
    while len(trimmed_middle) > 0 and estimate_tokens(result) > target_tokens:
        remove_count = min(step_size, len(trimmed_middle))
        trimmed_middle = trimmed_middle[remove_count:]
        result = ([system_msg] if system_msg else []) + trimmed_middle + recent_messages

        if estimate_tokens(result) <= target_tokens:
            return result

    keep_recent -= 1
    if keep_recent > 0:
        return intelligently_trim_messages(messages, target_tokens, keep_recent)

    return ([system_msg] if system_msg else []) + messages[-1:]

def auto_slim_messages(messages, verbose=False):
    estimated_tokens = estimate_tokens(messages)

    if estimated_tokens <= MAX_TOKENS_LIMIT:
        return messages

    if verbose:
        print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}")
        print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")

    result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP)
    final_tokens = estimate_tokens(result)

    if final_tokens > MAX_TOKENS_LIMIT:
        if verbose:
            print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}")
        result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose)
        final_tokens = estimate_tokens(result)

    if verbose:
        removed_count = len(messages) - len(result)
        print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}")
        print(f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}")
        if removed_count > 0:
            print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")

    return result

def emergency_reduce_messages(messages, target_tokens, verbose=False):
    system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
    start_idx = 1 if system_msg else 0

    keep_count = 2
    while estimate_tokens(messages) > target_tokens and keep_count >= 1:
        if len(messages[start_idx:]) <= keep_count:
            break

        result = ([system_msg] if system_msg else []) + messages[-keep_count:]

        for i in range(len(result)):
            result[i] = trim_message_content(result[i], CONTENT_TRIM_LENGTH // 2)

        if estimate_tokens(result) <= target_tokens:
            return result

        keep_count -= 1

    final = ([system_msg] if system_msg else []) + messages[-1:]

    for i in range(len(final)):
        if final[i].get("role") != "system":
            final[i] = trim_message_content(final[i], 100)

    return final
278
pr/core/enhanced_assistant.py
Normal file
@ -0,0 +1,278 @@
import logging
import json
import uuid
from typing import Optional, Dict, Any, List
from pr.config import (
    DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
    WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
    KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
    MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
)
from pr.cache import APICache, ToolCache
from pr.workflows import WorkflowEngine, WorkflowStorage
from pr.agents import AgentManager
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
from pr.tools.base import get_tools_definition

logger = logging.getLogger('pr')

class EnhancedAssistant:
    def __init__(self, base_assistant):
        self.base = base_assistant

        if CACHE_ENABLED:
            self.api_cache = APICache(DB_PATH, API_CACHE_TTL)
            self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL)
        else:
            self.api_cache = None
            self.tool_cache = None

        self.workflow_storage = WorkflowStorage(DB_PATH)
        self.workflow_engine = WorkflowEngine(
            tool_executor=self._execute_tool_for_workflow,
            max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS
        )

        self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)

        self.knowledge_store = KnowledgeStore(DB_PATH)
        self.conversation_memory = ConversationMemory(DB_PATH)
        self.fact_extractor = FactExtractor()

        if ADVANCED_CONTEXT_ENABLED:
            self.context_manager = AdvancedContextManager(
                knowledge_store=self.knowledge_store,
                conversation_memory=self.conversation_memory
            )
        else:
            self.context_manager = None

        self.current_conversation_id = str(uuid.uuid4())[:16]
        self.conversation_memory.create_conversation(
            self.current_conversation_id,
            session_id=str(uuid.uuid4())[:16]
        )

        logger.info("Enhanced Assistant initialized with all features")

    def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        if self.tool_cache:
            cached_result = self.tool_cache.get(tool_name, arguments)
            if cached_result is not None:
                logger.debug(f"Tool cache hit for {tool_name}")
                return cached_result

        func_map = {
            'read_file': lambda **kw: self.base.execute_tool_calls([{
                'id': 'temp',
                'function': {'name': 'read_file', 'arguments': json.dumps(kw)}
            }])[0],
            'write_file': lambda **kw: self.base.execute_tool_calls([{
                'id': 'temp',
                'function': {'name': 'write_file', 'arguments': json.dumps(kw)}
            }])[0],
            'list_directory': lambda **kw: self.base.execute_tool_calls([{
                'id': 'temp',
                'function': {'name': 'list_directory', 'arguments': json.dumps(kw)}
            }])[0],
            'run_command': lambda **kw: self.base.execute_tool_calls([{
                'id': 'temp',
                'function': {'name': 'run_command', 'arguments': json.dumps(kw)}
            }])[0],
        }

        if tool_name in func_map:
            result = func_map[tool_name](**arguments)

            if self.tool_cache:
                content = result.get('content', '')
                try:
                    parsed_content = json.loads(content) if isinstance(content, str) else content
                    self.tool_cache.set(tool_name, arguments, parsed_content)
                except Exception:
                    pass

            return result

        return {'error': f'Unknown tool: {tool_name}'}
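
    # Editorial note (not part of the original commit): each entry above
    # re-enters the base assistant's tool path by wrapping the call in a
    # synthetic tool_call dict with id 'temp', so workflow steps reuse the
    # same execution, truncation, and db_conn plumbing as live chat turns.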

    def _api_caller_for_agent(self, messages: List[Dict[str, Any]],
                              temperature: float, max_tokens: int) -> Dict[str, Any]:
        return call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            use_tools=False,
            tools=None,
            temperature=temperature,
            max_tokens=max_tokens,
            verbose=self.base.verbose
        )

    def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        if self.api_cache and CACHE_ENABLED:
            cached_response = self.api_cache.get(
                self.base.model, messages,
                0.7, 4096
            )
            if cached_response:
                logger.debug("API cache hit")
                return cached_response

        response = call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            self.base.use_tools,
            get_tools_definition(),
            verbose=self.base.verbose
        )

        if self.api_cache and CACHE_ENABLED and 'error' not in response:
            token_count = response.get('usage', {}).get('total_tokens', 0)
            self.api_cache.set(
                self.base.model, messages,
                0.7, 4096,
                response, token_count
            )

        return response

    def process_with_enhanced_context(self, user_message: str) -> str:
        self.base.messages.append({"role": "user", "content": user_message})

        self.conversation_memory.add_message(
            self.current_conversation_id,
            str(uuid.uuid4())[:16],
            'user',
            user_message
        )

        if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
            facts = self.fact_extractor.extract_facts(user_message)
            for fact in facts[:3]:
                entry_id = str(uuid.uuid4())[:16]
                from pr.memory import KnowledgeEntry
                import time

                categories = self.fact_extractor.categorize_content(fact['text'])
                entry = KnowledgeEntry(
                    entry_id=entry_id,
                    category=categories[0] if categories else 'general',
                    content=fact['text'],
                    metadata={'type': fact['type'], 'confidence': fact['confidence']},
                    created_at=time.time(),
                    updated_at=time.time()
                )
                self.knowledge_store.add_entry(entry)

        if self.context_manager and ADVANCED_CONTEXT_ENABLED:
            enhanced_messages, context_info = self.context_manager.create_enhanced_context(
                self.base.messages,
                user_message,
                include_knowledge=True
            )

            if self.base.verbose:
                logger.info(f"Enhanced context: {context_info}")

            working_messages = enhanced_messages
        else:
            working_messages = self.base.messages

        response = self.enhanced_call_api(working_messages)

        result = self.base.process_response(response)

        if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
            summary = self.context_manager.advanced_summarize_messages(
                self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
            ) if self.context_manager else "Conversation in progress"

            topics = self.fact_extractor.categorize_content(summary)
            self.conversation_memory.update_conversation_summary(
                self.current_conversation_id,
                summary,
                topics
            )

        return result

    def execute_workflow(self, workflow_name: str,
                         initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        workflow = self.workflow_storage.load_workflow_by_name(workflow_name)

        if not workflow:
            return {'error': f'Workflow "{workflow_name}" not found'}

        context = self.workflow_engine.execute_workflow(workflow, initial_variables)

        execution_id = self.workflow_storage.save_execution(
            workflow.name,  # reuse the workflow loaded above instead of reloading it
            context
        )

        return {
            'success': True,
            'execution_id': execution_id,
            'results': context.step_results,
            'execution_log': context.execution_log
        }

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        return self.agent_manager.create_agent(role_name, agent_id)

    def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
        return self.agent_manager.execute_agent_task(agent_id, task)

    def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
        orchestrator_id = self.agent_manager.create_agent('orchestrator')
        return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)

    def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
        return self.knowledge_store.search_entries(query, top_k=limit)

    def get_cache_statistics(self) -> Dict[str, Any]:
        stats = {}

        if self.api_cache:
            stats['api_cache'] = self.api_cache.get_statistics()

        if self.tool_cache:
            stats['tool_cache'] = self.tool_cache.get_statistics()

        return stats

    def get_workflow_list(self) -> List[Dict[str, Any]]:
        return self.workflow_storage.list_workflows()

    def get_agent_summary(self) -> Dict[str, Any]:
        return self.agent_manager.get_session_summary()

    def get_knowledge_statistics(self) -> Dict[str, Any]:
        return self.knowledge_store.get_statistics()

    def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        return self.conversation_memory.get_recent_conversations(limit=limit)

    def clear_caches(self):
        if self.api_cache:
            self.api_cache.clear_all()

        if self.tool_cache:
            self.tool_cache.clear_all()

        logger.info("All caches cleared")

    def cleanup(self):
        if self.api_cache:
            self.api_cache.clear_expired()

        if self.tool_cache:
            self.tool_cache.clear_expired()

        self.agent_manager.clear_session()
44
pr/core/exceptions.py
Normal file
@ -0,0 +1,44 @@
class PRException(Exception):
    pass


class APIException(PRException):
    pass


class APIConnectionError(APIException):
    pass


class APITimeoutError(APIException):
    pass


class APIResponseError(APIException):
    pass


class ConfigurationError(PRException):
    pass


class ToolExecutionError(PRException):
    def __init__(self, tool_name: str, message: str):
        self.tool_name = tool_name
        super().__init__(f"Error executing tool '{tool_name}': {message}")


class FileSystemError(PRException):
    pass


class SessionError(PRException):
    pass


class ContextError(PRException):
    pass


class ValidationError(PRException):
    pass
46
pr/core/logging.py
Normal file
@ -0,0 +1,46 @@
import logging
import os
from logging.handlers import RotatingFileHandler
from pr.config import LOG_FILE


def setup_logging(verbose=False):
    log_dir = os.path.dirname(LOG_FILE)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger('pr')
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    if logger.handlers:
        logger.handlers.clear()

    file_handler = RotatingFileHandler(
        LOG_FILE,
        maxBytes=10 * 1024 * 1024,
        backupCount=5
    )
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    if verbose:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_formatter = logging.Formatter(
            '%(levelname)s: %(message)s'
        )
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

    return logger


def get_logger(name=None):
    if name:
        return logging.getLogger(f'pr.{name}')
    return logging.getLogger('pr')
146
pr/core/session.py
Normal file
@ -0,0 +1,146 @@
import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from pr.core.logging import get_logger

logger = get_logger('session')

SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions")


class SessionManager:

    def __init__(self):
        os.makedirs(SESSIONS_DIR, exist_ok=True)

    def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool:
        try:
            session_file = os.path.join(SESSIONS_DIR, f"{name}.json")

            session_data = {
                'name': name,
                'created_at': datetime.now().isoformat(),
                'messages': messages,
                'metadata': metadata or {}
            }

            with open(session_file, 'w') as f:
                json.dump(session_data, f, indent=2)

            logger.info(f"Session saved: {name}")
            return True

        except Exception as e:
            logger.error(f"Error saving session {name}: {e}")
            return False

    def load_session(self, name: str) -> Optional[Dict]:
        try:
            session_file = os.path.join(SESSIONS_DIR, f"{name}.json")

            if not os.path.exists(session_file):
                logger.warning(f"Session not found: {name}")
                return None

            with open(session_file, 'r') as f:
                session_data = json.load(f)

            logger.info(f"Session loaded: {name}")
            return session_data

        except Exception as e:
            logger.error(f"Error loading session {name}: {e}")
            return None

    def list_sessions(self) -> List[Dict]:
        sessions = []

        try:
            for filename in os.listdir(SESSIONS_DIR):
                if filename.endswith('.json'):
                    filepath = os.path.join(SESSIONS_DIR, filename)
                    try:
                        with open(filepath, 'r') as f:
                            data = json.load(f)

                        sessions.append({
                            'name': data.get('name', filename[:-5]),
                            'created_at': data.get('created_at', 'unknown'),
                            'message_count': len(data.get('messages', [])),
                            'metadata': data.get('metadata', {})
                        })
                    except Exception as e:
                        logger.warning(f"Error reading session file {filename}: {e}")

            sessions.sort(key=lambda x: x['created_at'], reverse=True)

        except Exception as e:
            logger.error(f"Error listing sessions: {e}")

        return sessions

    def delete_session(self, name: str) -> bool:
        try:
            session_file = os.path.join(SESSIONS_DIR, f"{name}.json")

            if not os.path.exists(session_file):
                logger.warning(f"Session not found: {name}")
                return False

            os.remove(session_file)
            logger.info(f"Session deleted: {name}")
            return True

        except Exception as e:
            logger.error(f"Error deleting session {name}: {e}")
            return False

    def export_session(self, name: str, output_path: str, format: str = 'json') -> bool:
        session_data = self.load_session(name)
        if not session_data:
            return False

        try:
            if format == 'json':
                with open(output_path, 'w') as f:
                    json.dump(session_data, f, indent=2)

            elif format == 'markdown':
                with open(output_path, 'w') as f:
                    f.write(f"# Session: {name}\n\n")
                    f.write(f"Created: {session_data['created_at']}\n\n")
                    f.write("---\n\n")

                    for msg in session_data['messages']:
                        role = msg.get('role', 'unknown')
                        content = msg.get('content', '')

                        f.write(f"## {role.capitalize()}\n\n")
                        f.write(f"{content}\n\n")
                        f.write("---\n\n")

            elif format == 'txt':
                with open(output_path, 'w') as f:
                    f.write(f"Session: {name}\n")
                    f.write(f"Created: {session_data['created_at']}\n")
                    f.write("=" * 80 + "\n\n")

                    for msg in session_data['messages']:
                        role = msg.get('role', 'unknown')
                        content = msg.get('content', '')

                        f.write(f"[{role.upper()}]\n")
                        f.write(f"{content}\n")
                        f.write("-" * 80 + "\n\n")

            else:
                logger.error(f"Unsupported export format: {format}")
                return False

            logger.info(f"Session exported to {output_path}")
            return True

        except Exception as e:
            logger.error(f"Error exporting session: {e}")
            return False
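
# Example usage (editorial sketch, not in the original commit):
#   manager = SessionManager()
#   manager.save_session("demo", [{"role": "user", "content": "hi"}])
#   manager.export_session("demo", "/tmp/demo.md", format="markdown")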
162
pr/core/usage_tracker.py
Normal file
@ -0,0 +1,162 @@
import json
import os
from datetime import datetime
from typing import Dict, Optional
from pr.core.logging import get_logger

logger = get_logger('usage')

USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")

MODEL_COSTS = {
    'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0},
    'gpt-4': {'input': 0.03, 'output': 0.06},
    'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
    'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
    'claude-3-opus': {'input': 0.015, 'output': 0.075},
    'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
    'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
}


class UsageTracker:

    def __init__(self):
        self.session_usage = {
            'requests': 0,
            'total_tokens': 0,
            'input_tokens': 0,
            'output_tokens': 0,
            'estimated_cost': 0.0,
            'models_used': {}
        }

    def track_request(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        total_tokens: Optional[int] = None
    ):
        if total_tokens is None:
            total_tokens = input_tokens + output_tokens

        self.session_usage['requests'] += 1
        self.session_usage['total_tokens'] += total_tokens
        self.session_usage['input_tokens'] += input_tokens
        self.session_usage['output_tokens'] += output_tokens

        if model not in self.session_usage['models_used']:
            self.session_usage['models_used'][model] = {
                'requests': 0,
                'tokens': 0,
                'cost': 0.0
            }

        model_usage = self.session_usage['models_used'][model]
        model_usage['requests'] += 1
        model_usage['tokens'] += total_tokens

        cost = self._calculate_cost(model, input_tokens, output_tokens)
        model_usage['cost'] += cost
        self.session_usage['estimated_cost'] += cost

        self._save_to_history(model, input_tokens, output_tokens, cost)

        logger.debug(
            f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
        )

    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        if model not in MODEL_COSTS:
            # Strip a provider prefix such as 'anthropic/claude-3-opus' down to
            # the model segment; the original used split('/')[0], which kept
            # the provider name and could never match the table.
            base_model = model.split('/')[-1] if '/' in model else model
            if base_model not in MODEL_COSTS:
                logger.warning(f"Unknown model for cost calculation: {model}")
                return 0.0
            costs = MODEL_COSTS[base_model]
        else:
            costs = MODEL_COSTS[model]

        input_cost = (input_tokens / 1000) * costs['input']
        output_cost = (output_tokens / 1000) * costs['output']

        return input_cost + output_cost
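
    # Worked example (editorial note, not in the original commit): for
    # 'gpt-4' with 1,000 input and 500 output tokens,
    # (1000 / 1000) * 0.03 + (500 / 1000) * 0.06 = $0.06.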

    def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
        try:
            history = []
            if os.path.exists(USAGE_DB_FILE):
                with open(USAGE_DB_FILE, 'r') as f:
                    history = json.load(f)

            history.append({
                'timestamp': datetime.now().isoformat(),
                'model': model,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': input_tokens + output_tokens,
                'cost': cost
            })

            if len(history) > 10000:
                history = history[-10000:]

            with open(USAGE_DB_FILE, 'w') as f:
                json.dump(history, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving usage history: {e}")

    def get_session_summary(self) -> Dict:
        return self.session_usage.copy()

    def get_formatted_summary(self) -> str:
        usage = self.session_usage
        lines = [
            "\n=== Session Usage Summary ===",
            f"Total Requests: {usage['requests']}",
            f"Total Tokens: {usage['total_tokens']:,}",
            f" Input: {usage['input_tokens']:,}",
            f" Output: {usage['output_tokens']:,}",
            f"Estimated Cost: ${usage['estimated_cost']:.4f}",
        ]

        if usage['models_used']:
            lines.append("\nModels Used:")
            for model, stats in usage['models_used'].items():
                lines.append(
                    f" {model}: {stats['requests']} requests, "
                    f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}"
                )

        return '\n'.join(lines)

    @staticmethod
    def get_total_usage() -> Dict:
        if not os.path.exists(USAGE_DB_FILE):
            return {
                'total_requests': 0,
                'total_tokens': 0,
                'total_cost': 0.0
            }

        try:
            with open(USAGE_DB_FILE, 'r') as f:
                history = json.load(f)

            total_tokens = sum(entry['total_tokens'] for entry in history)
            total_cost = sum(entry['cost'] for entry in history)

            return {
                'total_requests': len(history),
                'total_tokens': total_tokens,
                'total_cost': total_cost
            }

        except Exception as e:
            logger.error(f"Error loading usage history: {e}")
            return {
                'total_requests': 0,
                'total_tokens': 0,
                'total_cost': 0.0
            }
86
pr/core/validation.py
Normal file
@ -0,0 +1,86 @@
import os
from pr.core.exceptions import ValidationError


def validate_file_path(path: str, must_exist: bool = False) -> str:
    if not path:
        raise ValidationError("File path cannot be empty")

    if must_exist and not os.path.exists(path):
        raise ValidationError(f"File does not exist: {path}")

    if must_exist and os.path.isdir(path):
        raise ValidationError(f"Path is a directory, not a file: {path}")

    return os.path.abspath(path)


def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
    if not path:
        raise ValidationError("Directory path cannot be empty")

    abs_path = os.path.abspath(path)

    if must_exist and not os.path.exists(abs_path):
        if create:
            os.makedirs(abs_path, exist_ok=True)
        else:
            raise ValidationError(f"Directory does not exist: {abs_path}")

    if os.path.exists(abs_path) and not os.path.isdir(abs_path):
        raise ValidationError(f"Path is not a directory: {abs_path}")

    return abs_path


def validate_model_name(model: str) -> str:
    if not model:
        raise ValidationError("Model name cannot be empty")

    if len(model) < 2:
        raise ValidationError("Model name too short")

    return model


def validate_api_url(url: str) -> str:
    if not url:
        raise ValidationError("API URL cannot be empty")

    if not url.startswith(('http://', 'https://')):
        raise ValidationError("API URL must start with http:// or https://")

    return url


def validate_session_name(name: str) -> str:
    if not name:
        raise ValidationError("Session name cannot be empty")

    invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
    for char in invalid_chars:
        if char in name:
            raise ValidationError(f"Session name contains invalid character: {char}")

    if len(name) > 255:
        raise ValidationError("Session name too long (max 255 characters)")

    return name


def validate_temperature(temp: float) -> float:
    if not 0.0 <= temp <= 2.0:
        raise ValidationError("Temperature must be between 0.0 and 2.0")

    return temp


def validate_max_tokens(tokens: int) -> int:
    if tokens < 1:
        raise ValidationError("Max tokens must be at least 1")

    if tokens > 100000:
        raise ValidationError("Max tokens too high (max 100000)")

    return tokens
994
pr/editor.py
Normal file
@ -0,0 +1,994 @@
#!/usr/bin/env python3
import curses
import threading
import sys
import os
import re
import socket
import pickle
import queue
import time
import atexit
import signal
import traceback
from contextlib import contextmanager

class RPEditor:
    def __init__(self, filename=None, auto_save=False, timeout=30):
        """
        Initialize RPEditor with enhanced robustness features.

        Args:
            filename: File to edit
            auto_save: Enable auto-save on exit
            timeout: Command timeout in seconds
        """
        self.filename = filename
        self.lines = [""]
        self.cursor_y = 0
        self.cursor_x = 0
        self.mode = 'normal'
        self.command = ""
        self.stdscr = None
        self.running = False
        self.thread = None
        self.socket_thread = None
        self.prev_key = None
        self.clipboard = ""
        self.undo_stack = []
        self.redo_stack = []
        self.selection_start = None
        self.selection_end = None
        self.max_undo = 100
        self.lock = threading.RLock()
        self.command_queue = queue.Queue()
        self.auto_save = auto_save
        self.timeout = timeout
        self._cleanup_registered = False
        self._original_terminal_state = None
        self._exception_occurred = False

        # Create socket pair with error handling; the attributes are pre-set
        # to None so _cleanup() stays safe if socketpair() itself fails
        self.client_sock = None
        self.server_sock = None
        try:
            self.client_sock, self.server_sock = socket.socketpair()
            self.client_sock.settimeout(self.timeout)
            self.server_sock.settimeout(self.timeout)
        except Exception as e:
            self._cleanup()
            raise RuntimeError(f"Failed to create socket pair: {e}")

        # Register cleanup handlers
        self._register_cleanup()

        if filename:
            self.load_file()
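
    # Minimal usage sketch (editorial note, not part of the original commit):
    #   editor = RPEditor("notes.txt", auto_save=True)
    #   editor.start()      # curses UI runs in a background thread
    #   editor.save_file()  # thread-safe command sent over the socket pair
    #   editor.stop()       # restores the terminal and joins the threads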
def _register_cleanup(self):
|
||||
"""Register cleanup handlers for proper shutdown."""
|
||||
if not self._cleanup_registered:
|
||||
atexit.register(self._cleanup)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
self._cleanup_registered = True
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle signals for clean shutdown."""
|
||||
self._cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
def _cleanup(self):
|
||||
"""Comprehensive cleanup of all resources."""
|
||||
try:
|
||||
# Stop the editor
|
||||
self.running = False
|
||||
|
||||
# Save if auto-save is enabled
|
||||
if self.auto_save and self.filename and not self._exception_occurred:
|
||||
try:
|
||||
self._save_file()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Clean up curses
|
||||
if self.stdscr:
|
||||
try:
|
||||
self.stdscr.keypad(False)
|
||||
curses.nocbreak()
|
||||
curses.echo()
|
||||
curses.curs_set(1)
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
curses.endwin()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Clear screen after curses cleanup
|
||||
try:
|
||||
os.system('clear' if os.name != 'nt' else 'cls')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Close sockets
|
||||
for sock in [self.client_sock, self.server_sock]:
|
||||
if sock:
|
||||
try:
|
||||
sock.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Wait for threads to finish
|
||||
for thread in [self.thread, self.socket_thread]:
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=1)
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
def load_file(self):
|
||||
"""Load file with enhanced error handling."""
|
||||
try:
|
||||
if os.path.exists(self.filename):
|
||||
with open(self.filename, 'r', encoding='utf-8', errors='replace') as f:
|
||||
content = f.read()
|
||||
self.lines = content.splitlines() if content else [""]
|
||||
else:
|
||||
self.lines = [""]
|
||||
except Exception as e:
|
||||
self.lines = [""]
|
||||
# Don't raise, just use empty content
|
||||
|
||||
def _save_file(self):
|
||||
"""Save file with enhanced error handling and backup."""
|
||||
with self.lock:
|
||||
if not self.filename:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create backup if file exists
|
||||
if os.path.exists(self.filename):
|
||||
backup_name = f"{self.filename}.bak"
|
||||
try:
|
||||
with open(self.filename, 'r', encoding='utf-8') as f:
|
||||
backup_content = f.read()
|
||||
with open(backup_name, 'w', encoding='utf-8') as f:
|
||||
f.write(backup_content)
|
||||
except:
|
||||
pass # Backup failed, but continue with save
|
||||
|
||||
# Save the file
|
||||
with open(self.filename, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(self.lines))
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def save_file(self):
|
||||
"""Thread-safe save file command."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
|
||||
except:
|
||||
return self._save_file() # Fallback to direct save
|
||||
|
||||
def start(self):
|
||||
"""Start the editor with enhanced error handling."""
|
||||
if self.running:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.running = True
|
||||
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
|
||||
self.socket_thread.start()
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
self.thread.start()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.running = False
|
||||
self._cleanup()
|
||||
raise RuntimeError(f"Failed to start editor: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the editor with proper cleanup."""
|
||||
try:
|
||||
if self.client_sock:
|
||||
self.client_sock.send(pickle.dumps({'command': 'stop'}))
|
||||
except:
|
||||
pass
|
||||
|
||||
self.running = False
|
||||
time.sleep(0.1) # Give threads time to finish
|
||||
self._cleanup()
|
||||
|
||||
def run(self):
|
||||
"""Run the main editor loop with exception handling."""
|
||||
try:
|
||||
curses.wrapper(self.main_loop)
|
||||
except Exception as e:
|
||||
self._exception_occurred = True
|
||||
self._cleanup()
|
||||
|
||||
def main_loop(self, stdscr):
|
||||
"""Main editor loop with enhanced error recovery."""
|
||||
self.stdscr = stdscr
|
||||
|
||||
try:
|
||||
# Configure curses
|
||||
curses.curs_set(1)
|
||||
self.stdscr.keypad(True)
|
||||
self.stdscr.timeout(100) # Non-blocking with timeout
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process queued commands
|
||||
while True:
|
||||
try:
|
||||
command = self.command_queue.get_nowait()
|
||||
with self.lock:
|
||||
self.execute_command(command)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Draw screen
|
||||
with self.lock:
|
||||
self.draw()
|
||||
|
||||
# Handle input
|
||||
try:
|
||||
key = self.stdscr.getch()
|
||||
if key != -1: # -1 means timeout/no input
|
||||
with self.lock:
|
||||
self.handle_key(key)
|
||||
except curses.error:
|
||||
pass # Ignore curses errors
|
||||
|
||||
except Exception as e:
|
||||
# Log error but continue running
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
self._exception_occurred = True
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def draw(self):
|
||||
"""Draw the editor screen with error handling."""
|
||||
try:
|
||||
self.stdscr.clear()
|
||||
height, width = self.stdscr.getmaxyx()
|
||||
|
||||
# Draw lines
|
||||
for i, line in enumerate(self.lines):
|
||||
if i >= height - 1:
|
||||
break
|
||||
try:
|
||||
# Handle long lines and special characters
|
||||
display_line = line[:width-1] if len(line) >= width else line
|
||||
self.stdscr.addstr(i, 0, display_line)
|
||||
except curses.error:
|
||||
pass # Skip lines that can't be displayed
|
||||
|
||||
# Draw status line
|
||||
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
|
||||
if self.mode == 'command':
|
||||
status = self.command[:width-1]
|
||||
|
||||
try:
|
||||
self.stdscr.addstr(height - 1, 0, status[:width-1])
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Position cursor
|
||||
cursor_x = min(self.cursor_x, width - 1)
|
||||
cursor_y = min(self.cursor_y, height - 2)
|
||||
try:
|
||||
self.stdscr.move(cursor_y, cursor_x)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
self.stdscr.refresh()
|
||||
except Exception:
|
||||
pass # Continue even if draw fails
|
||||
|
||||
def handle_key(self, key):
|
||||
"""Handle keyboard input with error recovery."""
|
||||
try:
|
||||
if self.mode == 'normal':
|
||||
self.handle_normal(key)
|
||||
elif self.mode == 'insert':
|
||||
self.handle_insert(key)
|
||||
elif self.mode == 'command':
|
||||
self.handle_command(key)
|
||||
except Exception:
|
||||
pass # Continue on error
|
||||
|
||||
def handle_normal(self, key):
|
||||
"""Handle normal mode keys."""
|
||||
try:
|
||||
if key == ord('h') or key == curses.KEY_LEFT:
|
||||
self.move_cursor(0, -1)
|
||||
elif key == ord('j') or key == curses.KEY_DOWN:
|
||||
self.move_cursor(1, 0)
|
||||
elif key == ord('k') or key == curses.KEY_UP:
|
||||
self.move_cursor(-1, 0)
|
||||
elif key == ord('l') or key == curses.KEY_RIGHT:
|
||||
self.move_cursor(0, 1)
|
||||
elif key == ord('i'):
|
||||
self.mode = 'insert'
|
||||
elif key == ord(':'):
|
||||
self.mode = 'command'
|
||||
self.command = ":"
|
||||
elif key == ord('x'):
|
||||
self._delete_char()
|
||||
elif key == ord('a'):
|
||||
self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y]))
|
||||
self.mode = 'insert'
|
||||
elif key == ord('A'):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
self.mode = 'insert'
|
||||
elif key == ord('o'):
|
||||
self._insert_line(self.cursor_y + 1, "")
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('O'):
|
||||
self._insert_line(self.cursor_y, "")
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('d') and self.prev_key == ord('d'):
|
||||
if self.cursor_y < len(self.lines):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
self._delete_line(self.cursor_y)
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.cursor_y = max(0, len(self.lines) - 1)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('y') and self.prev_key == ord('y'):
|
||||
if self.cursor_y < len(self.lines):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
elif key == ord('p'):
|
||||
self._insert_line(self.cursor_y + 1, self.clipboard)
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('P'):
|
||||
self._insert_line(self.cursor_y, self.clipboard)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('w'):
|
||||
self._move_word_forward()
|
||||
elif key == ord('b'):
|
||||
self._move_word_backward()
|
||||
elif key == ord('0'):
|
||||
self.cursor_x = 0
|
||||
elif key == ord('$'):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
elif key == ord('g'):
|
||||
if self.prev_key == ord('g'):
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
elif key == ord('G'):
|
||||
self.cursor_y = max(0, len(self.lines) - 1)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('u'):
|
||||
self.undo()
|
||||
elif key == ord('r') and self.prev_key == 18: # Ctrl-R
|
||||
self.redo()
|
||||
|
||||
self.prev_key = key
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _move_word_forward(self):
|
||||
"""Move cursor forward by word."""
|
||||
if self.cursor_y >= len(self.lines):
|
||||
return
|
||||
line = self.lines[self.cursor_y]
|
||||
i = self.cursor_x
|
||||
# Skip non-alphanumeric
|
||||
while i < len(line) and not line[i].isalnum():
|
||||
i += 1
|
||||
# Skip alphanumeric
|
||||
while i < len(line) and line[i].isalnum():
|
||||
i += 1
|
||||
self.cursor_x = i
|
||||
|
||||
def _move_word_backward(self):
|
||||
"""Move cursor backward by word."""
|
||||
if self.cursor_y >= len(self.lines):
|
||||
return
|
||||
line = self.lines[self.cursor_y]
|
||||
i = max(0, self.cursor_x - 1)
|
||||
# Skip non-alphanumeric
|
||||
while i >= 0 and not line[i].isalnum():
|
||||
i -= 1
|
||||
# Skip alphanumeric
|
||||
while i >= 0 and line[i].isalnum():
|
||||
i -= 1
|
||||
self.cursor_x = max(0, i + 1)
|
||||
|
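Two-keystroke commands (dd, yy, gg) are recognized by remembering the last key in prev_key instead of keeping a pending-operator state machine. The idiom in isolation:

    # Sketch of the prev_key pattern used by handle_normal above.
    prev_key = None

    def on_key(key):
        global prev_key
        if key == ord('d') and prev_key == ord('d'):
            print("dd: delete current line")  # second 'd' completes the command
        prev_key = key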
||||
def handle_insert(self, key):
|
||||
"""Handle insert mode keys."""
|
||||
try:
|
||||
if key == 27: # ESC
|
||||
self.mode = 'normal'
|
||||
if self.cursor_x > 0:
|
||||
self.cursor_x -= 1
|
||||
elif key == 10 or key == 13: # Enter
|
||||
self._split_line()
|
||||
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
|
||||
self._backspace()
|
||||
elif 32 <= key <= 126:
|
||||
char = chr(key)
|
||||
self._insert_char(char)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def handle_command(self, key):
|
||||
"""Handle command mode keys."""
|
||||
try:
|
||||
if key == 10 or key == 13: # Enter
|
||||
cmd = self.command[1:].strip()
|
||||
if cmd in ["q", "q!"]:
|
||||
self.running = False
|
||||
elif cmd == "w":
|
||||
self._save_file()
|
||||
elif cmd in ["wq", "wq!", "x", "xq", "x!"]:
|
||||
self._save_file()
|
||||
self.running = False
|
||||
elif cmd.startswith("w "):
|
||||
self.filename = cmd[2:].strip()
|
||||
self._save_file()
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
elif key == 27: # ESC
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
|
||||
if len(self.command) > 1:
|
||||
self.command = self.command[:-1]
|
||||
elif 32 <= key <= 126:
|
||||
self.command += chr(key)
|
||||
except Exception:
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
|
||||
def move_cursor(self, dy, dx):
|
||||
"""Move cursor with bounds checking."""
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
|
||||
new_y = self.cursor_y + dy
|
||||
new_x = self.cursor_x + dx
|
||||
|
||||
# Ensure valid Y position
|
||||
if 0 <= new_y < len(self.lines):
|
||||
self.cursor_y = new_y
|
||||
# Ensure valid X position for new line
|
||||
max_x = len(self.lines[self.cursor_y])
|
||||
self.cursor_x = max(0, min(new_x, max_x))
|
||||
elif new_y < 0:
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
elif new_y >= len(self.lines):
|
||||
self.cursor_y = max(0, len(self.lines) - 1)
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
|
||||
def save_state(self):
|
||||
"""Save current state for undo."""
|
||||
with self.lock:
|
||||
state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.undo_stack.append(state)
|
||||
if len(self.undo_stack) > self.max_undo:
|
||||
self.undo_stack.pop(0)
|
||||
self.redo_stack.clear()
|
||||
|
||||
def undo(self):
|
||||
"""Undo last change."""
|
||||
with self.lock:
|
||||
if self.undo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.redo_stack.append(current_state)
|
||||
state = self.undo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
|
||||
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
|
||||
|
||||
def redo(self):
|
||||
"""Redo last undone change."""
|
||||
with self.lock:
|
||||
if self.redo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.undo_stack.append(current_state)
|
||||
state = self.redo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
|
||||
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
|
||||
|
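Undo and redo are full-buffer snapshots: every mutation pushes a copy of the lines, capped at max_undo, and any new edit clears the redo stack. The core exchange reduced to its essentials:

    # Snapshot undo sketch mirroring the stacks above.
    undo_stack, redo_stack = [], []
    lines = ["hello"]

    def save_state():
        undo_stack.append(list(lines))   # copy, not an alias
        redo_stack.clear()               # a new edit invalidates redo history

    def undo():
        if undo_stack:
            redo_stack.append(list(lines))
            lines[:] = undo_stack.pop()

    save_state()
    lines[0] = "hello, world"
    undo()
    assert lines == ["hello"]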
||||
def _insert_text(self, text):
|
||||
"""Insert text at cursor position."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
self.save_state()
|
||||
lines = text.split('\n')
|
||||
|
||||
if len(lines) == 1:
|
||||
# Single line insert
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:]
|
||||
self.cursor_x += len(text)
|
||||
else:
|
||||
# Multi-line insert
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
|
||||
self.lines[self.cursor_y] = first
|
||||
for i in range(1, len(lines) - 1):
|
||||
self.lines.insert(self.cursor_y + i, lines[i])
|
||||
self.lines.insert(self.cursor_y + len(lines) - 1, last)
|
||||
|
||||
self.cursor_y += len(lines) - 1
|
||||
self.cursor_x = len(lines[-1])
|
||||
|
||||
def insert_text(self, text):
|
||||
"""Thread-safe text insertion."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
|
||||
except Exception:
|
||||
with self.lock:
|
||||
self._insert_text(text)
|
||||
|
||||
def _delete_char(self):
|
||||
"""Delete character at cursor."""
|
||||
self.save_state()
|
||||
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:]
|
||||
|
||||
def delete_char(self):
|
||||
"""Thread-safe character deletion."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
|
||||
except Exception:
|
||||
with self.lock:
|
||||
self._delete_char()
|
||||
|
||||
def _insert_char(self, char):
|
||||
"""Insert single character."""
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:]
|
||||
self.cursor_x += 1
|
||||
|
||||
def _split_line(self):
|
||||
"""Split line at cursor."""
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
|
||||
def _backspace(self):
|
||||
"""Handle backspace key."""
|
||||
if self.cursor_x > 0:
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:]
|
||||
self.cursor_x -= 1
|
||||
elif self.cursor_y > 0:
|
||||
prev_len = len(self.lines[self.cursor_y - 1])
|
||||
self.lines[self.cursor_y - 1] += self.lines[self.cursor_y]
|
||||
del self.lines[self.cursor_y]
|
||||
self.cursor_y -= 1
|
||||
self.cursor_x = prev_len
|
||||
|
||||
def _insert_line(self, line_num, text):
|
||||
"""Insert a new line."""
|
||||
self.save_state()
|
||||
line_num = max(0, min(line_num, len(self.lines)))
|
||||
self.lines.insert(line_num, text)
|
||||
|
||||
def _delete_line(self, line_num):
|
||||
"""Delete a line."""
|
||||
self.save_state()
|
||||
if 0 <= line_num < len(self.lines):
|
||||
if len(self.lines) > 1:
|
||||
del self.lines[line_num]
|
||||
else:
|
||||
self.lines = [""]
|
||||
|
||||
def _set_text(self, text):
|
||||
"""Set entire text content."""
|
||||
self.save_state()
|
||||
self.lines = text.splitlines() if text else [""]
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
|
||||
def set_text(self, text):
|
||||
"""Thread-safe text setting."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
|
||||
except Exception:
|
||||
with self.lock:
|
||||
self._set_text(text)
|
||||
|
||||
def _goto_line(self, line_num):
|
||||
"""Go to specific line."""
|
||||
line_num = max(0, min(line_num - 1, len(self.lines) - 1))
|
||||
self.cursor_y = line_num
|
||||
self.cursor_x = 0
|
||||
|
||||
def goto_line(self, line_num):
|
||||
"""Thread-safe goto line."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
|
||||
except Exception:
|
||||
with self.lock:
|
||||
self._goto_line(line_num)
|
||||
|
||||
def get_text(self):
|
||||
"""Get entire text content."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
|
||||
data = self.client_sock.recv(65536)
|
||||
return pickle.loads(data)
|
||||
except Exception:
|
||||
with self.lock:
|
||||
return '\n'.join(self.lines)
|
||||
|
||||
def get_cursor(self):
|
||||
"""Get cursor position."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
|
||||
data = self.client_sock.recv(4096)
|
||||
return pickle.loads(data)
|
||||
except Exception:
|
||||
with self.lock:
|
||||
return (self.cursor_y, self.cursor_x)
|
||||
|
||||
def get_file_info(self):
|
||||
"""Get file information."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
|
||||
data = self.client_sock.recv(4096)
|
||||
return pickle.loads(data)
|
||||
except Exception:
|
||||
with self.lock:
|
||||
return {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
}
|
||||
|
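The getters above double as a small request/response protocol: the caller pickles a command over the client end of the socketpair, and the curses thread pickles a reply back over the server end. The round trip in a standalone sketch:

    import pickle
    import socket

    client, server = socket.socketpair()
    client.send(pickle.dumps({'command': 'get_cursor'}))

    request = pickle.loads(server.recv(4096))   # editor side decodes the request
    if request['command'] == 'get_cursor':
        server.send(pickle.dumps((0, 0)))       # reply with (cursor_y, cursor_x)

    print(pickle.loads(client.recv(4096)))      # -> (0, 0)
    client.close()
    server.close()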
||||
def socket_listener(self):
|
||||
"""Listen for socket commands with error handling."""
|
||||
while self.running:
|
||||
try:
|
||||
data = self.server_sock.recv(65536)
|
||||
if not data:
|
||||
break
|
||||
command = pickle.loads(data)
|
||||
self.command_queue.put(command)
|
||||
except socket.timeout:
|
||||
continue
|
||||
except OSError:
|
||||
if self.running:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def execute_command(self, command):
|
||||
"""Execute command with error handling."""
|
||||
try:
|
||||
cmd = command.get('command')
|
||||
|
||||
if cmd == 'insert_text':
|
||||
self._insert_text(command.get('text', ''))
|
||||
elif cmd == 'delete_char':
|
||||
self._delete_char()
|
||||
elif cmd == 'save_file':
|
||||
self._save_file()
|
||||
elif cmd == 'set_text':
|
||||
self._set_text(command.get('text', ''))
|
||||
elif cmd == 'goto_line':
|
||||
self._goto_line(command.get('line_num', 1))
|
||||
elif cmd == 'get_text':
|
||||
result = '\n'.join(self.lines)
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'get_cursor':
|
||||
result = (self.cursor_y, self.cursor_x)
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'get_file_info':
|
||||
result = {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
}
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'stop':
|
||||
self.running = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Additional public methods for backwards compatibility
|
||||
|
||||
def move_cursor_to(self, y, x):
|
||||
"""Move cursor to specific position."""
|
||||
with self.lock:
|
||||
self.cursor_y = max(0, min(y, len(self.lines) - 1))
|
||||
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
|
||||
|
||||
def get_line(self, line_num):
|
||||
"""Get specific line."""
|
||||
with self.lock:
|
||||
if 0 <= line_num < len(self.lines):
|
||||
return self.lines[line_num]
|
||||
return None
|
||||
|
||||
def get_lines(self, start, end):
|
||||
"""Get range of lines."""
|
||||
with self.lock:
|
||||
start = max(0, start)
|
||||
end = min(end, len(self.lines))
|
||||
return self.lines[start:end]
|
||||
|
||||
def insert_at_line(self, line_num, text):
|
||||
"""Insert text at specific line."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
line_num = max(0, min(line_num, len(self.lines)))
|
||||
self.lines.insert(line_num, text)
|
||||
|
||||
def delete_lines(self, start, end):
|
||||
"""Delete range of lines."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
start = max(0, start)
|
||||
end = min(end, len(self.lines))
|
||||
if start < end:
|
||||
del self.lines[start:end]
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
|
||||
def replace_text(self, start_line, start_col, end_line, end_col, new_text):
|
||||
"""Replace text in range."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
|
||||
# Validate bounds
|
||||
start_line = max(0, min(start_line, len(self.lines) - 1))
|
||||
end_line = max(0, min(end_line, len(self.lines) - 1))
|
||||
|
||||
if start_line == end_line:
|
||||
line = self.lines[start_line]
|
||||
start_col = max(0, min(start_col, len(line)))
|
||||
end_col = max(0, min(end_col, len(line)))
|
||||
self.lines[start_line] = line[:start_col] + new_text + line[end_col:]
|
||||
else:
|
||||
first_part = self.lines[start_line][:start_col]
|
||||
last_part = self.lines[end_line][end_col:]
|
||||
new_lines = new_text.split('\n')
|
||||
self.lines[start_line] = first_part + new_lines[0]
|
||||
del self.lines[start_line + 1:end_line + 1]
|
||||
for i, new_line in enumerate(new_lines[1:], 1):
|
||||
self.lines.insert(start_line + i, new_line)
|
||||
if len(new_lines) > 1:
|
||||
self.lines[start_line + len(new_lines) - 1] += last_part
|
||||
else:
|
||||
self.lines[start_line] += last_part
|
||||
|
||||
def search(self, pattern, start_line=0):
|
||||
"""Search for pattern in text."""
|
||||
with self.lock:
|
||||
results = []
|
||||
try:
|
||||
for i in range(start_line, len(self.lines)):
|
||||
matches = re.finditer(pattern, self.lines[i])
|
||||
for match in matches:
|
||||
results.append((i, match.start(), match.end()))
|
||||
except re.error:
|
||||
pass
|
||||
return results
|
||||
|
||||
def replace_all(self, search_text, replace_text):
|
||||
"""Replace all occurrences of text."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
for i in range(len(self.lines)):
|
||||
self.lines[i] = self.lines[i].replace(search_text, replace_text)
|
||||
|
||||
def select_range(self, start_line, start_col, end_line, end_col):
|
||||
"""Select text range."""
|
||||
with self.lock:
|
||||
self.selection_start = (start_line, start_col)
|
||||
self.selection_end = (end_line, end_col)
|
||||
|
||||
def get_selection(self):
|
||||
"""Get selected text."""
|
||||
with self.lock:
|
||||
if not self.selection_start or not self.selection_end:
|
||||
return ""
|
||||
|
||||
sl, sc = self.selection_start
|
||||
el, ec = self.selection_end
|
||||
|
||||
# Validate bounds
|
||||
if sl < 0 or sl >= len(self.lines) or el < 0 or el >= len(self.lines):
|
||||
return ""
|
||||
|
||||
if sl == el:
|
||||
return self.lines[sl][sc:ec]
|
||||
|
||||
result = [self.lines[sl][sc:]]
|
||||
for i in range(sl + 1, el):
|
||||
if i < len(self.lines):
|
||||
result.append(self.lines[i])
|
||||
if el < len(self.lines):
|
||||
result.append(self.lines[el][:ec])
|
||||
return '\n'.join(result)
|
||||
|
||||
def delete_selection(self):
|
||||
"""Delete selected text."""
|
||||
with self.lock:
|
||||
if not self.selection_start or not self.selection_end:
|
||||
return
|
||||
self.save_state()
|
||||
sl, sc = self.selection_start
|
||||
el, ec = self.selection_end
|
||||
if 0 <= sl < len(self.lines) and 0 <= el < len(self.lines):
|
||||
self.replace_text(sl, sc, el, ec, "")
|
||||
self.selection_start = None
|
||||
self.selection_end = None
|
||||
|
||||
def apply_search_replace_block(self, search_block, replace_block):
|
||||
"""Apply search and replace on block."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
search_lines = search_block.splitlines()
|
||||
replace_lines = replace_block.splitlines()
|
||||
|
||||
for i in range(len(self.lines) - len(search_lines) + 1):
|
||||
match = True
|
||||
for j, search_line in enumerate(search_lines):
|
||||
if i + j >= len(self.lines):
|
||||
match = False
|
||||
break
|
||||
if self.lines[i + j].strip() != search_line.strip():
|
||||
match = False
|
||||
break
|
||||
|
||||
if match:
|
||||
# Preserve indentation
|
||||
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
|
||||
indented_replace = [' ' * indent + line for line in replace_lines]
|
||||
self.lines[i:i+len(search_lines)] = indented_replace
|
||||
return True
|
||||
return False
|
||||
|
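apply_search_replace_block compares lines with surrounding whitespace stripped and re-applies the first matched line's indentation to every replacement line. A usage sketch, calling the private _set_text directly for brevity and assuming construction without a filename yields an empty buffer, as in this module:

    editor = RPEditor()
    editor._set_text("def f():\n    return 1")
    changed = editor.apply_search_replace_block("return 1", "return 2")
    # changed is True; the buffer line is now "    return 2" (indent preserved)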
||||
def apply_diff(self, diff_text):
|
||||
"""Apply unified diff."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
try:
|
||||
lines = diff_text.split('\n')
|
||||
start_line = 0
|
||||
|
||||
for line in lines:
|
||||
if line.startswith('@@'):
|
||||
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
|
||||
if match:
|
||||
start_line = int(match.group(1)) - 1
|
||||
elif line.startswith('-'):
|
||||
if start_line < len(self.lines):
|
||||
del self.lines[start_line]
|
||||
elif line.startswith('+'):
|
||||
self.lines.insert(start_line, line[1:])
|
||||
start_line += 1
|
||||
elif line and not line.startswith('\\'):
|
||||
start_line += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_context(self, line_num, context_lines=3):
|
||||
"""Get lines around specific line."""
|
||||
with self.lock:
|
||||
start = max(0, line_num - context_lines)
|
||||
end = min(len(self.lines), line_num + context_lines + 1)
|
||||
return self.get_lines(start, end)
|
||||
|
||||
def count_lines(self):
|
||||
"""Count total lines."""
|
||||
with self.lock:
|
||||
return len(self.lines)
|
||||
|
||||
def close(self):
|
||||
"""Close the editor."""
|
||||
self.stop()
|
||||
|
||||
def is_running(self):
|
||||
"""Check if editor is running."""
|
||||
return self.running
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""Wait for editor to finish."""
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=timeout)
|
||||
return not self.thread.is_alive()
|
||||
return True
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
if exc_type:
|
||||
self._exception_occurred = True
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor for cleanup."""
|
||||
self._cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point with error handling."""
|
||||
editor = None
|
||||
try:
|
||||
filename = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
|
||||
# Parse additional arguments
|
||||
auto_save = '--auto-save' in sys.argv
|
||||
|
||||
# Create and start editor
|
||||
editor = RPEditor(filename, auto_save=auto_save)
|
||||
editor.start()
|
||||
|
||||
# Wait for editor to finish
|
||||
if editor.thread:
|
||||
editor.thread.join()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
finally:
|
||||
if editor:
|
||||
editor.stop()
|
||||
# Ensure screen is cleared
|
||||
os.system('clear' if os.name != 'nt' else 'cls')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if "rpe" in sys.argv[0]:
|
||||
main()
|
||||
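Because the class defines __enter__ and __exit__, the whole lifecycle can be scoped with a with-block; a sketch using an assumed import path and a hypothetical filename:

    from pr.editor import RPEditor

    with RPEditor("notes.txt") as editor:
        editor.goto_line(1)
        editor.insert_text("TODO: review\n")
        editor.save_file()
    # __exit__ calls stop(), which tears down the threads and sockets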
587
pr/editor2.py
Normal file
@ -0,0 +1,587 @@
|
||||
#!/usr/bin/env python3
|
||||
import curses
|
||||
import threading
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import pickle
|
||||
import queue
|
||||
|
||||
class RPEditor:
|
||||
def __init__(self, filename=None):
|
||||
self.filename = filename
|
||||
self.lines = [""]
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
self.stdscr = None
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self.socket_thread = None
|
||||
self.prev_key = None
|
||||
self.clipboard = ""
|
||||
self.undo_stack = []
|
||||
self.redo_stack = []
|
||||
self.selection_start = None
|
||||
self.selection_end = None
|
||||
self.max_undo = 100
|
||||
self.lock = threading.RLock()
|
||||
self.client_sock, self.server_sock = socket.socketpair()
|
||||
self.command_queue = queue.Queue()
|
||||
if filename:
|
||||
self.load_file()
|
||||
|
||||
def load_file(self):
|
||||
try:
|
||||
with open(self.filename, 'r') as f:
|
||||
self.lines = f.read().splitlines()
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
except Exception:
|
||||
self.lines = [""]
|
||||
|
||||
def _save_file(self):
|
||||
with self.lock:
|
||||
if self.filename:
|
||||
with open(self.filename, 'w') as f:
|
||||
f.write('\n'.join(self.lines))
|
||||
|
||||
def save_file(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
|
||||
|
||||
def start(self):
|
||||
self.running = True
|
||||
self.socket_thread = threading.Thread(target=self.socket_listener)
|
||||
self.socket_thread.start()
|
||||
self.thread = threading.Thread(target=self.run)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'stop'}))
|
||||
self.running = False
|
||||
if self.stdscr:
|
||||
curses.endwin()
|
||||
if self.thread:
|
||||
self.thread.join()
|
||||
if self.socket_thread:
|
||||
self.socket_thread.join()
|
||||
self.client_sock.close()
|
||||
self.server_sock.close()
|
||||
|
||||
def run(self):
|
||||
curses.wrapper(self.main_loop)
|
||||
|
||||
def main_loop(self, stdscr):
|
||||
self.stdscr = stdscr
|
||||
curses.curs_set(1)
|
||||
self.stdscr.keypad(True)
|
||||
while self.running:
|
||||
with self.lock:
|
||||
self.draw()
|
||||
try:
|
||||
while True:
|
||||
command = self.command_queue.get_nowait()
|
||||
with self.lock:
|
||||
self.execute_command(command)
|
||||
except queue.Empty:
|
||||
pass
|
||||
key = self.stdscr.getch()
|
||||
with self.lock:
|
||||
self.handle_key(key)
|
||||
|
||||
def draw(self):
|
||||
self.stdscr.clear()
|
||||
height, width = self.stdscr.getmaxyx()
|
||||
for i, line in enumerate(self.lines):
|
||||
if i < height - 1:
|
||||
self.stdscr.addstr(i, 0, line[:width])
|
||||
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
|
||||
self.stdscr.addstr(height - 1, 0, status[:width])
|
||||
if self.mode == 'command':
|
||||
self.stdscr.addstr(height - 1, 0, self.command[:width])
|
||||
self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1))
|
||||
self.stdscr.refresh()
|
||||
|
||||
def handle_key(self, key):
|
||||
if self.mode == 'normal':
|
||||
self.handle_normal(key)
|
||||
elif self.mode == 'insert':
|
||||
self.handle_insert(key)
|
||||
elif self.mode == 'command':
|
||||
self.handle_command(key)
|
||||
|
||||
def handle_normal(self, key):
|
||||
if key == ord('h') or key == curses.KEY_LEFT:
|
||||
self.move_cursor(0, -1)
|
||||
elif key == ord('j') or key == curses.KEY_DOWN:
|
||||
self.move_cursor(1, 0)
|
||||
elif key == ord('k') or key == curses.KEY_UP:
|
||||
self.move_cursor(-1, 0)
|
||||
elif key == ord('l') or key == curses.KEY_RIGHT:
|
||||
self.move_cursor(0, 1)
|
||||
elif key == ord('i'):
|
||||
self.mode = 'insert'
|
||||
elif key == ord(':'):
|
||||
self.mode = 'command'
|
||||
self.command = ":"
|
||||
elif key == ord('x'):
|
||||
self._delete_char()
|
||||
elif key == ord('a'):
|
||||
self.cursor_x += 1
|
||||
self.mode = 'insert'
|
||||
elif key == ord('A'):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
self.mode = 'insert'
|
||||
elif key == ord('o'):
|
||||
self._insert_line(self.cursor_y + 1, "")
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('O'):
|
||||
self._insert_line(self.cursor_y, "")
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('d') and self.prev_key == ord('d'):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
self._delete_line(self.cursor_y)
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('y') and self.prev_key == ord('y'):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
elif key == ord('p'):
|
||||
self._insert_line(self.cursor_y + 1, self.clipboard)
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('P'):
|
||||
self._insert_line(self.cursor_y, self.clipboard)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('w'):
|
||||
line = self.lines[self.cursor_y]
|
||||
i = self.cursor_x
|
||||
while i < len(line) and not line[i].isalnum():
|
||||
i += 1
|
||||
while i < len(line) and line[i].isalnum():
|
||||
i += 1
|
||||
self.cursor_x = i
|
||||
elif key == ord('b'):
|
||||
line = self.lines[self.cursor_y]
|
||||
i = self.cursor_x - 1
|
||||
while i >= 0 and not line[i].isalnum():
|
||||
i -= 1
|
||||
while i >= 0 and line[i].isalnum():
|
||||
i -= 1
|
||||
self.cursor_x = i + 1
|
||||
elif key == ord('0'):
|
||||
self.cursor_x = 0
|
||||
elif key == ord('$'):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
elif key == ord('g'):
|
||||
if self.prev_key == ord('g'):
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
elif key == ord('G'):
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('u'):
|
||||
self.undo()
|
||||
elif key == ord('r') and self.prev_key == 18:
|
||||
self.redo()
|
||||
self.prev_key = key
|
||||
|
||||
def handle_insert(self, key):
|
||||
if key == 27:
|
||||
self.mode = 'normal'
|
||||
if self.cursor_x > 0:
|
||||
self.cursor_x -= 1
|
||||
elif key == 10:
|
||||
self._split_line()
|
||||
elif key == curses.KEY_BACKSPACE or key == 127:
|
||||
self._backspace()
|
||||
elif 32 <= key <= 126:
|
||||
char = chr(key)
|
||||
self._insert_char(char)
|
||||
|
||||
def handle_command(self, key):
|
||||
if key == 10:
|
||||
cmd = self.command[1:]
|
||||
if cmd == "q" or cmd == 'q!':
|
||||
self.running = False
|
||||
elif cmd == "w":
|
||||
self._save_file()
|
||||
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
|
||||
self._save_file()
|
||||
self.running = False
|
||||
elif cmd.startswith("w "):
|
||||
self.filename = cmd[2:]
|
||||
self._save_file()
|
||||
elif cmd == "wq":
|
||||
self._save_file()
|
||||
self.running = False
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
elif key == 27:
|
||||
self.mode = 'normal'
|
||||
self.command = ""
|
||||
elif key == curses.KEY_BACKSPACE or key == 127:
|
||||
if len(self.command) > 1:
|
||||
self.command = self.command[:-1]
|
||||
elif 32 <= key <= 126:
|
||||
self.command += chr(key)
|
||||
|
||||
def move_cursor(self, dy, dx):
|
||||
new_y = self.cursor_y + dy
|
||||
new_x = self.cursor_x + dx
|
||||
if 0 <= new_y < len(self.lines):
|
||||
self.cursor_y = new_y
|
||||
self.cursor_x = max(0, min(new_x, len(self.lines[self.cursor_y])))
|
||||
|
||||
def save_state(self):
|
||||
with self.lock:
|
||||
state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.undo_stack.append(state)
|
||||
if len(self.undo_stack) > self.max_undo:
|
||||
self.undo_stack.pop(0)
|
||||
self.redo_stack.clear()
|
||||
|
||||
def undo(self):
|
||||
with self.lock:
|
||||
if self.undo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.redo_stack.append(current_state)
|
||||
state = self.undo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = state['cursor_y']
|
||||
self.cursor_x = state['cursor_x']
|
||||
|
||||
def redo(self):
|
||||
with self.lock:
|
||||
if self.redo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
}
|
||||
self.undo_stack.append(current_state)
|
||||
state = self.redo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = state['cursor_y']
|
||||
self.cursor_x = state['cursor_x']
|
||||
|
||||
def _insert_text(self, text):
|
||||
self.save_state()
|
||||
lines = text.split('\n')
|
||||
if len(lines) == 1:
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.cursor_x += len(text)
|
||||
else:
|
||||
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.lines[self.cursor_y] = first
|
||||
for i in range(1, len(lines)-1):
|
||||
self.lines.insert(self.cursor_y + i, lines[i])
|
||||
self.lines.insert(self.cursor_y + len(lines) - 1, last)
|
||||
self.cursor_y += len(lines) - 1
|
||||
self.cursor_x = len(lines[-1])
|
||||
|
||||
def insert_text(self, text):
|
||||
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
|
||||
|
||||
def _delete_char(self):
|
||||
self.save_state()
|
||||
if self.cursor_x < len(self.lines[self.cursor_y]):
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:]
|
||||
|
||||
def delete_char(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
|
||||
|
||||
def _insert_char(self, char):
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.cursor_x += 1
|
||||
|
||||
def _split_line(self):
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
|
||||
def _backspace(self):
|
||||
if self.cursor_x > 0:
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.cursor_x -= 1
|
||||
elif self.cursor_y > 0:
|
||||
prev_len = len(self.lines[self.cursor_y - 1])
|
||||
self.lines[self.cursor_y - 1] += self.lines[self.cursor_y]
|
||||
del self.lines[self.cursor_y]
|
||||
self.cursor_y -= 1
|
||||
self.cursor_x = prev_len
|
||||
|
||||
def _insert_line(self, line_num, text):
|
||||
self.save_state()
|
||||
line_num = max(0, min(line_num, len(self.lines)))
|
||||
self.lines.insert(line_num, text)
|
||||
|
||||
def _delete_line(self, line_num):
|
||||
self.save_state()
|
||||
if 0 <= line_num < len(self.lines):
|
||||
if len(self.lines) > 1:
|
||||
del self.lines[line_num]
|
||||
else:
|
||||
self.lines = [""]
|
||||
|
||||
def _set_text(self, text):
|
||||
self.save_state()
|
||||
self.lines = text.splitlines() if text else [""]
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
|
||||
def set_text(self, text):
|
||||
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
|
||||
|
||||
def _goto_line(self, line_num):
|
||||
line_num = max(0, min(line_num, len(self.lines) - 1))
|
||||
self.cursor_y = line_num
|
||||
self.cursor_x = 0
|
||||
|
||||
def goto_line(self, line_num):
|
||||
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
|
||||
|
||||
def get_text(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
def get_cursor(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except Exception:
|
||||
return (0, 0)
|
||||
|
||||
def get_file_info(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def socket_listener(self):
|
||||
while self.running:
|
||||
try:
|
||||
data = self.server_sock.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
command = pickle.loads(data)
|
||||
self.command_queue.put(command)
|
||||
except OSError:
|
||||
break
|
||||
|
||||
def execute_command(self, command):
|
||||
cmd = command.get('command')
|
||||
if cmd == 'insert_text':
|
||||
self._insert_text(command['text'])
|
||||
elif cmd == 'delete_char':
|
||||
self._delete_char()
|
||||
elif cmd == 'save_file':
|
||||
self._save_file()
|
||||
elif cmd == 'set_text':
|
||||
self._set_text(command['text'])
|
||||
elif cmd == 'goto_line':
|
||||
self._goto_line(command['line_num'])
|
||||
elif cmd == 'get_text':
|
||||
result = '\n'.join(self.lines)
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except Exception:
|
||||
pass
|
||||
elif cmd == 'get_cursor':
|
||||
result = (self.cursor_y, self.cursor_x)
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except Exception:
|
||||
pass
|
||||
elif cmd == 'get_file_info':
|
||||
result = {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
}
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except Exception:
|
||||
pass
|
||||
elif cmd == 'stop':
|
||||
self.running = False
|
||||
|
||||
def move_cursor_to(self, y, x):
|
||||
with self.lock:
|
||||
self.cursor_y = max(0, min(y, len(self.lines)-1))
|
||||
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
|
||||
|
||||
def get_line(self, line_num):
|
||||
with self.lock:
|
||||
if 0 <= line_num < len(self.lines):
|
||||
return self.lines[line_num]
|
||||
return None
|
||||
|
||||
def get_lines(self, start, end):
|
||||
with self.lock:
|
||||
start = max(0, start)
|
||||
end = min(end, len(self.lines))
|
||||
return self.lines[start:end]
|
||||
|
||||
def insert_at_line(self, line_num, text):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
line_num = max(0, min(line_num, len(self.lines)))
|
||||
self.lines.insert(line_num, text)
|
||||
|
||||
def delete_lines(self, start, end):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
start = max(0, start)
|
||||
end = min(end, len(self.lines))
|
||||
if start < end:
|
||||
del self.lines[start:end]
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
|
||||
def replace_text(self, start_line, start_col, end_line, end_col, new_text):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
if start_line == end_line:
|
||||
line = self.lines[start_line]
|
||||
self.lines[start_line] = line[:start_col] + new_text + line[end_col:]
|
||||
else:
|
||||
first_part = self.lines[start_line][:start_col]
|
||||
last_part = self.lines[end_line][end_col:]
|
||||
new_lines = new_text.split('\n')
|
||||
self.lines[start_line] = first_part + new_lines[0]
|
||||
del self.lines[start_line + 1:end_line + 1]
|
||||
for i, new_line in enumerate(new_lines[1:], 1):
|
||||
self.lines.insert(start_line + i, new_line)
|
||||
if len(new_lines) > 1:
|
||||
self.lines[start_line + len(new_lines) - 1] += last_part
|
||||
else:
|
||||
self.lines[start_line] += last_part
|
||||
|
||||
def search(self, pattern, start_line=0):
|
||||
with self.lock:
|
||||
results = []
|
||||
for i in range(start_line, len(self.lines)):
|
||||
matches = re.finditer(pattern, self.lines[i])
|
||||
for match in matches:
|
||||
results.append((i, match.start(), match.end()))
|
||||
return results
|
||||
|
||||
def replace_all(self, search_text, replace_text):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
for i in range(len(self.lines)):
|
||||
self.lines[i] = self.lines[i].replace(search_text, replace_text)
|
||||
|
||||
def select_range(self, start_line, start_col, end_line, end_col):
|
||||
with self.lock:
|
||||
self.selection_start = (start_line, start_col)
|
||||
self.selection_end = (end_line, end_col)
|
||||
|
||||
def get_selection(self):
|
||||
with self.lock:
|
||||
if not self.selection_start or not self.selection_end:
|
||||
return ""
|
||||
sl, sc = self.selection_start
|
||||
el, ec = self.selection_end
|
||||
if sl == el:
|
||||
return self.lines[sl][sc:ec]
|
||||
result = [self.lines[sl][sc:]]
|
||||
for i in range(sl + 1, el):
|
||||
result.append(self.lines[i])
|
||||
result.append(self.lines[el][:ec])
|
||||
return '\n'.join(result)
|
||||
|
||||
def delete_selection(self):
|
||||
with self.lock:
|
||||
if not self.selection_start or not self.selection_end:
|
||||
return
|
||||
self.save_state()
|
||||
sl, sc = self.selection_start
|
||||
el, ec = self.selection_end
|
||||
self.replace_text(sl, sc, el, ec, "")
|
||||
self.selection_start = None
|
||||
self.selection_end = None
|
||||
|
||||
def apply_search_replace_block(self, search_block, replace_block):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
search_lines = search_block.splitlines()
|
||||
replace_lines = replace_block.splitlines()
|
||||
for i in range(len(self.lines) - len(search_lines) + 1):
|
||||
match = True
|
||||
for j, search_line in enumerate(search_lines):
|
||||
if self.lines[i + j].strip() != search_line.strip():
|
||||
match = False
|
||||
break
|
||||
if match:
|
||||
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
|
||||
indented_replace = [' ' * indent + line for line in replace_lines]
|
||||
self.lines[i:i+len(search_lines)] = indented_replace
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_diff(self, diff_text):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
lines = diff_text.split('\n')
|
||||
start_line = 0
|
||||
for line in lines:
|
||||
if line.startswith('@@'):
|
||||
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
|
||||
if match:
|
||||
start_line = int(match.group(1)) - 1
|
||||
elif line.startswith('-'):
|
||||
if start_line < len(self.lines):
|
||||
del self.lines[start_line]
|
||||
elif line.startswith('+'):
|
||||
self.lines.insert(start_line, line[1:])
|
||||
start_line += 1
|
||||
|
||||
def get_context(self, line_num, context_lines=3):
|
||||
with self.lock:
|
||||
start = max(0, line_num - context_lines)
|
||||
end = min(len(self.lines), line_num + context_lines + 1)
|
||||
return self.get_lines(start, end)
|
||||
|
||||
def count_lines(self):
|
||||
with self.lock:
|
||||
return len(self.lines)
|
||||
|
||||
def close(self):
|
||||
self.running = False
|
||||
self.stop()
|
||||
if self.thread:
|
||||
self.thread.join()
|
||||
|
||||
def main():
|
||||
filename = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
editor = RPEditor(filename)
|
||||
editor.start()
|
||||
editor.thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
7
pr/memory/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from .knowledge_store import KnowledgeStore, KnowledgeEntry
|
||||
from .semantic_index import SemanticIndex
|
||||
from .conversation_memory import ConversationMemory
|
||||
from .fact_extractor import FactExtractor
|
||||
|
||||
__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex',
|
||||
'ConversationMemory', 'FactExtractor']
|
||||
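The package re-exports its four classes, so consumers can import everything from pr.memory directly, e.g.:

    from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor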
259
pr/memory/conversation_memory.py
Normal file
@ -0,0 +1,259 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
class ConversationMemory:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self._initialize_memory()
|
||||
|
||||
def _initialize_memory(self):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS conversation_history (
|
||||
conversation_id TEXT PRIMARY KEY,
|
||||
session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
summary TEXT,
|
||||
topics TEXT,
|
||||
metadata TEXT
|
||||
)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS conversation_messages (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp REAL NOT NULL,
|
||||
tool_calls TEXT,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
|
||||
)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO conversation_history
|
||||
(conversation_id, session_id, started_at, metadata)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (
|
||||
conversation_id,
|
||||
session_id,
|
||||
time.time(),
|
||||
json.dumps(metadata) if metadata else None
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def add_message(self, conversation_id: str, message_id: str, role: str,
|
||||
content: str, tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO conversation_messages
|
||||
(message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
message_id,
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
time.time(),
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
json.dumps(metadata) if metadata else None
|
||||
))
|
||||
|
||||
cursor.execute('''
|
||||
UPDATE conversation_history
|
||||
SET message_count = message_count + 1
|
||||
WHERE conversation_id = ?
|
||||
''', (conversation_id,))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def get_conversation_messages(self, conversation_id: str,
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if limit:
|
||||
cursor.execute('''
|
||||
SELECT message_id, role, content, timestamp, tool_calls, metadata
|
||||
FROM conversation_messages
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (conversation_id, limit))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT message_id, role, content, timestamp, tool_calls, metadata
|
||||
FROM conversation_messages
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY timestamp ASC
|
||||
''', (conversation_id,))
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
messages.append({
|
||||
'message_id': row[0],
|
||||
'role': row[1],
|
||||
'content': row[2],
|
||||
'timestamp': row[3],
|
||||
'tool_calls': json.loads(row[4]) if row[4] else None,
|
||||
'metadata': json.loads(row[5]) if row[5] else None
|
||||
})
|
||||
|
||||
conn.close()
|
||||
return messages
|
||||
|
||||
def update_conversation_summary(self, conversation_id: str, summary: str,
|
||||
topics: Optional[List[str]] = None):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
UPDATE conversation_history
|
||||
SET summary = ?, topics = ?, ended_at = ?
|
||||
WHERE conversation_id = ?
|
||||
''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def search_conversations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT DISTINCT h.conversation_id, h.session_id, h.started_at,
|
||||
h.message_count, h.summary, h.topics
|
||||
FROM conversation_history h
|
||||
LEFT JOIN conversation_messages m ON h.conversation_id = m.conversation_id
|
||||
WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ?
|
||||
ORDER BY h.started_at DESC
|
||||
LIMIT ?
|
||||
''', (f'%{query}%', f'%{query}%', f'%{query}%', limit))
|
||||
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
conversations.append({
|
||||
'conversation_id': row[0],
|
||||
'session_id': row[1],
|
||||
'started_at': row[2],
|
||||
'message_count': row[3],
|
||||
'summary': row[4],
|
||||
'topics': json.loads(row[5]) if row[5] else []
|
||||
})
|
||||
|
||||
conn.close()
|
||||
return conversations
|
||||
|
||||
def get_recent_conversations(self, limit: int = 10,
|
||||
session_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if session_id:
|
||||
cursor.execute('''
|
||||
SELECT conversation_id, session_id, started_at, ended_at,
|
||||
message_count, summary, topics
|
||||
FROM conversation_history
|
||||
WHERE session_id = ?
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
''', (session_id, limit))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT conversation_id, session_id, started_at, ended_at,
|
||||
message_count, summary, topics
|
||||
FROM conversation_history
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
conversations.append({
|
||||
'conversation_id': row[0],
|
||||
'session_id': row[1],
|
||||
'started_at': row[2],
|
||||
'ended_at': row[3],
|
||||
'message_count': row[4],
|
||||
'summary': row[5],
|
||||
'topics': json.loads(row[6]) if row[6] else []
|
||||
})
|
||||
|
||||
conn.close()
|
||||
return conversations
|
||||
|
||||
def delete_conversation(self, conversation_id: str) -> bool:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
|
||||
(conversation_id,))
|
||||
cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
|
||||
(conversation_id,))
|
||||
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return deleted
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM conversation_history')
|
||||
total_conversations = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM conversation_messages')
|
||||
total_messages = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT SUM(message_count) FROM conversation_history')
|
||||
total_message_count = cursor.fetchone()[0] or 0
|
||||
|
||||
cursor.execute('''
|
||||
SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0
|
||||
''')
|
||||
avg_messages = cursor.fetchone()[0] or 0
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
'total_conversations': total_conversations,
|
||||
'total_messages': total_messages,
|
||||
'average_messages_per_conversation': round(avg_messages, 2)
|
||||
}
|
||||
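A minimal round trip through ConversationMemory; the database path and IDs below are illustrative:

    from pr.memory.conversation_memory import ConversationMemory

    mem = ConversationMemory("/tmp/conv.db")
    mem.create_conversation("conv-1", session_id="sess-1")
    mem.add_message("conv-1", "msg-1", "user", "How do I undo in the editor?")
    mem.add_message("conv-1", "msg-2", "assistant", "Press u in normal mode.")
    mem.update_conversation_summary("conv-1", "Editor undo question",
                                    topics=["editor", "undo"])
    print(mem.get_statistics())   # counts and average messages per conversation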
146
pr/memory/fact_extractor.py
Normal file
@ -0,0 +1,146 @@
|
||||
import re
|
||||
import json
|
||||
from typing import List, Dict, Any, Set
|
||||
from collections import defaultdict
|
||||
|
||||
class FactExtractor:
|
||||
def __init__(self):
|
||||
self.fact_patterns = [
|
||||
(r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'),
|
||||
(r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'),
|
||||
(r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'),
|
||||
(r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'),
|
||||
(r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'),
|
||||
]
|
||||
|
||||
def extract_facts(self, text: str) -> List[Dict[str, Any]]:
|
||||
facts = []
|
||||
|
||||
for pattern, fact_type in self.fact_patterns:
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
facts.append({
|
||||
'type': fact_type,
|
||||
'text': match.group(0),
|
||||
'components': match.groups(),
|
||||
'confidence': 0.7
|
||||
})
|
||||
|
||||
noun_phrases = self._extract_noun_phrases(text)
|
||||
for phrase in noun_phrases:
|
||||
if len(phrase.split()) >= 2:
|
||||
facts.append({
|
||||
'type': 'entity',
|
||||
'text': phrase,
|
||||
'components': [phrase],
|
||||
'confidence': 0.5
|
||||
})
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_noun_phrases(self, text: str) -> List[str]:
|
||||
sentences = re.split(r'[.!?]', text)
|
||||
phrases = []
|
||||
|
||||
for sentence in sentences:
|
||||
words = sentence.split()
|
||||
current_phrase = []
|
||||
|
||||
for word in words:
|
||||
if word and word[0].isupper() and len(word) > 1:
|
||||
current_phrase.append(word)
|
||||
else:
|
||||
if len(current_phrase) >= 2:
|
||||
phrases.append(' '.join(current_phrase))
|
||||
current_phrase = []
|
||||
|
||||
if len(current_phrase) >= 2:
|
||||
phrases.append(' '.join(current_phrase))
|
||||
|
||||
return list(set(phrases))
|
||||
|
||||
def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]:
|
||||
words = re.findall(r'\b[a-z]{4,}\b', text.lower())
|
||||
|
||||
stopwords = {
|
||||
'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when',
|
||||
'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could',
|
||||
'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them',
|
||||
'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before',
|
||||
'after', 'above', 'below', 'between', 'under', 'again', 'further',
|
||||
'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same',
|
||||
'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like'
|
||||
}
|
||||
|
||||
filtered_words = [w for w in words if w not in stopwords]
|
||||
|
||||
word_freq = defaultdict(int)
|
||||
for word in filtered_words:
|
||||
word_freq[word] += 1
|
||||
|
||||
sorted_terms = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_terms[:top_k]
|
||||
|
||||
def extract_relationships(self, text: str) -> List[Dict[str, Any]]:
|
||||
relationships = []
|
||||
|
||||
relationship_patterns = [
|
||||
(r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'),
|
||||
(r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'),
|
||||
(r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'),
|
||||
(r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'),
|
||||
]
|
||||
|
||||
for pattern, rel_type in relationship_patterns:
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
relationships.append({
|
||||
'type': rel_type,
|
||||
'subject': match.group(1),
|
||||
'predicate': match.group(2),
|
||||
'object': match.group(3),
|
||||
'confidence': 0.6
|
||||
})
|
||||
|
||||
return relationships
|
||||
|
||||
def extract_metadata(self, text: str) -> Dict[str, Any]:
|
||||
word_count = len(text.split())
|
||||
sentence_count = len(re.split(r'[.!?]', text))
|
||||
|
||||
urls = re.findall(r'https?://[^\s]+', text)
|
||||
email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text)
|
||||
dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text)
|
||||
numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text)
|
||||
|
||||
return {
|
||||
'word_count': word_count,
|
||||
'sentence_count': sentence_count,
|
||||
'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2),
|
||||
'urls': urls,
|
||||
'email_addresses': email_addresses,
|
||||
'dates': dates,
|
||||
'numeric_values': numbers,
|
||||
'has_code': bool(re.search(r'```|def |class |import |function ', text)),
|
||||
'has_questions': bool(re.search(r'\?', text))
|
||||
}
|
||||
|
||||
def categorize_content(self, text: str) -> List[str]:
|
||||
categories = []
|
||||
|
||||
category_keywords = {
|
||||
'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'],
|
||||
'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'],
|
||||
'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'],
|
||||
'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'],
|
||||
'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'],
|
||||
'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'],
|
||||
'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'],
|
||||
}
|
||||
|
||||
text_lower = text.lower()
|
||||
for category, keywords in category_keywords.items():
|
||||
if any(keyword in text_lower for keyword in keywords):
|
||||
categories.append(category)
|
||||
|
||||
return categories if categories else ['general']
|
||||
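FactExtractor is purely regex- and heuristic-driven, so it needs no setup; a quick illustration with made-up text:

    from pr.memory.fact_extractor import FactExtractor

    fx = FactExtractor()
    text = "Ada Lovelace is a pioneer of computing."
    print(fx.extract_facts(text)[0]['type'])            # 'definition'
    print(fx.extract_key_terms(text, top_k=3))          # [('lovelace', 1), ...]
    print(fx.categorize_content("debug this function")) # ['programming']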
265
pr/memory/knowledge_store.py
Normal file
@ -0,0 +1,265 @@
import json
import sqlite3
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .semantic_index import SemanticIndex


@dataclass
class KnowledgeEntry:
    entry_id: str
    category: str
    content: str
    metadata: Dict[str, Any]
    created_at: float
    updated_at: float
    access_count: int = 0
    importance_score: float = 1.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            'entry_id': self.entry_id,
            'category': self.category,
            'content': self.content,
            'metadata': self.metadata,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
            'access_count': self.access_count,
            'importance_score': self.importance_score
        }


class KnowledgeStore:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.semantic_index = SemanticIndex()
        self._initialize_store()
        self._load_index()

    def _initialize_store(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS knowledge_entries (
                entry_id TEXT PRIMARY KEY,
                category TEXT NOT NULL,
                content TEXT NOT NULL,
                metadata TEXT,
                created_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                access_count INTEGER DEFAULT 0,
                importance_score REAL DEFAULT 1.0
            )
        ''')

        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category)
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC)
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
        ''')

        conn.commit()
        conn.close()

    def _load_index(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT entry_id, content FROM knowledge_entries')
        for row in cursor.fetchall():
            self.semantic_index.add_document(row[0], row[1])

        conn.close()

    def add_entry(self, entry: KnowledgeEntry):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            INSERT OR REPLACE INTO knowledge_entries
            (entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            entry.entry_id,
            entry.category,
            entry.content,
            json.dumps(entry.metadata),
            entry.created_at,
            entry.updated_at,
            entry.access_count,
            entry.importance_score
        ))

        conn.commit()
        conn.close()

        self.semantic_index.add_document(entry.entry_id, entry.content)

    def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
            FROM knowledge_entries
            WHERE entry_id = ?
        ''', (entry_id,))

        row = cursor.fetchone()

        if row:
            cursor.execute('''
                UPDATE knowledge_entries
                SET access_count = access_count + 1
                WHERE entry_id = ?
            ''', (entry_id,))
            conn.commit()

            conn.close()

            return KnowledgeEntry(
                entry_id=row[0],
                category=row[1],
                content=row[2],
                metadata=json.loads(row[3]) if row[3] else {},
                created_at=row[4],
                updated_at=row[5],
                access_count=row[6] + 1,
                importance_score=row[7]
            )

        conn.close()
        return None

    def search_entries(self, query: str, category: Optional[str] = None,
                       top_k: int = 5) -> List[KnowledgeEntry]:
        search_results = self.semantic_index.search(query, top_k * 2)

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        entries = []
        for entry_id, score in search_results:
            if category:
                cursor.execute('''
                    SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
                    FROM knowledge_entries
                    WHERE entry_id = ? AND category = ?
                ''', (entry_id, category))
            else:
                cursor.execute('''
                    SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
                    FROM knowledge_entries
                    WHERE entry_id = ?
                ''', (entry_id,))

            row = cursor.fetchone()
            if row:
                entry = KnowledgeEntry(
                    entry_id=row[0],
                    category=row[1],
                    content=row[2],
                    metadata=json.loads(row[3]) if row[3] else {},
                    created_at=row[4],
                    updated_at=row[5],
                    access_count=row[6],
                    importance_score=row[7]
                )
                entries.append(entry)

                if len(entries) >= top_k:
                    break

        conn.close()
        return entries

    def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
            FROM knowledge_entries
            WHERE category = ?
            ORDER BY importance_score DESC, created_at DESC
            LIMIT ?
        ''', (category, limit))

        entries = []
        for row in cursor.fetchall():
            entries.append(KnowledgeEntry(
                entry_id=row[0],
                category=row[1],
                content=row[2],
                metadata=json.loads(row[3]) if row[3] else {},
                created_at=row[4],
                updated_at=row[5],
                access_count=row[6],
                importance_score=row[7]
            ))

        conn.close()
        return entries

    def update_importance(self, entry_id: str, importance_score: float):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            UPDATE knowledge_entries
            SET importance_score = ?, updated_at = ?
            WHERE entry_id = ?
        ''', (importance_score, time.time(), entry_id))

        conn.commit()
        conn.close()

    def delete_entry(self, entry_id: str) -> bool:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
        deleted = cursor.rowcount > 0

        conn.commit()
        conn.close()

        if deleted:
            self.semantic_index.remove_document(entry_id)

        return deleted

    def get_statistics(self) -> Dict[str, Any]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
        total_entries = cursor.fetchone()[0]

        cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries')
        total_categories = cursor.fetchone()[0]

        cursor.execute('''
            SELECT category, COUNT(*) as count
            FROM knowledge_entries
            GROUP BY category
            ORDER BY count DESC
        ''')
        category_counts = {row[0]: row[1] for row in cursor.fetchall()}

        cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
        total_accesses = cursor.fetchone()[0] or 0

        conn.close()

        return {
            'total_entries': total_entries,
            'total_categories': total_categories,
            'category_distribution': category_counts,
            'total_accesses': total_accesses,
            'vocabulary_size': len(self.semantic_index.vocabulary)
        }
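Usage sketch for the store above (illustrative, not part of the commit): entries persist in SQLite while their content is mirrored into the in-memory SemanticIndex, which _load_index() rebuilds on startup. The entry id, category, and content below are hypothetical.

import time
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry

store = KnowledgeStore("/tmp/knowledge.db")  # any writable path works
now = time.time()
store.add_entry(KnowledgeEntry(
    entry_id="kb-001",                       # hypothetical id
    category="programming",
    content="TF-IDF weighs terms by in-document frequency and corpus rarity.",
    metadata={"source": "notes"},
    created_at=now,
    updated_at=now,
))
for entry in store.search_entries("tf-idf weighting", top_k=3):
    print(entry.entry_id, entry.category)
print(store.get_statistics()["total_entries"])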
87
pr/memory/semantic_index.py
Normal file
@@ -0,0 +1,87 @@
import math
import re
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Set


class SemanticIndex:
    def __init__(self):
        self.documents: Dict[str, str] = {}
        self.vocabulary: Set[str] = set()
        self.idf_scores: Dict[str, float] = {}
        self.doc_vectors: Dict[str, Dict[str, float]] = {}

    def _tokenize(self, text: str) -> List[str]:
        text = text.lower()
        text = re.sub(r'[^a-z0-9\s]', ' ', text)
        tokens = text.split()
        return tokens

    def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
        term_count = Counter(tokens)
        total_terms = len(tokens)
        return {term: count / total_terms for term, count in term_count.items()}

    def _compute_idf(self):
        doc_count = len(self.documents)
        if doc_count == 0:
            return

        token_doc_count = defaultdict(int)

        for doc_id, doc_text in self.documents.items():
            tokens = set(self._tokenize(doc_text))
            for token in tokens:
                token_doc_count[token] += 1

        if doc_count == 1:
            self.idf_scores = {token: 1.0 for token in token_doc_count}
        else:
            self.idf_scores = {
                token: math.log(doc_count / count)
                for token, count in token_doc_count.items()
            }

    def add_document(self, doc_id: str, text: str):
        self.documents[doc_id] = text
        tokens = self._tokenize(text)
        self.vocabulary.update(tokens)

        self._compute_idf()

        tf_scores = self._compute_tf(tokens)
        self.doc_vectors[doc_id] = {
            token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0)
            for token in tokens
        }

    def remove_document(self, doc_id: str):
        if doc_id in self.documents:
            del self.documents[doc_id]
        if doc_id in self.doc_vectors:
            del self.doc_vectors[doc_id]
        self._compute_idf()

    def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        query_tokens = self._tokenize(query)
        query_tf = self._compute_tf(query_tokens)

        query_vector = {
            token: query_tf.get(token, 0) * self.idf_scores.get(token, 0)
            for token in query_tokens
        }

        scores = []
        for doc_id, doc_vector in self.doc_vectors.items():
            similarity = self._cosine_similarity(query_vector, doc_vector)
            scores.append((doc_id, similarity))

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]

    def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2))
        norm1 = math.sqrt(sum(val**2 for val in vec1.values()))
        norm2 = math.sqrt(sum(val**2 for val in vec2.values()))
        if norm1 == 0 or norm2 == 0:
            return 0
        return dot_product / (norm1 * norm2)
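A minimal sketch of the TF-IDF index above (illustrative; the document ids and texts are made up). One design note visible in the code: add_document recomputes IDF over the whole corpus but does not re-weight vectors of previously indexed documents, so older vectors keep the IDF that was current when they were added.

from pr.memory.semantic_index import SemanticIndex

index = SemanticIndex()
index.add_document("doc-1", "Python code execution and debugging tools")
index.add_document("doc-2", "Weather data feeds and stock prices")

# Tokens unique to doc-1 ("python", "code") carry positive IDF, so doc-1 ranks first.
for doc_id, score in index.search("python code", top_k=2):
    print(f"{doc_id}: {score:.3f}")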
98
pr/multiplexer.py
Normal file
@@ -0,0 +1,98 @@
import threading
import queue
import time
import sys
from pr.ui import Colors


class TerminalMultiplexer:
    def __init__(self, name, show_output=True):
        self.name = name
        self.show_output = show_output
        self.stdout_buffer = []
        self.stderr_buffer = []
        self.stdout_queue = queue.Queue()
        self.stderr_queue = queue.Queue()
        self.active = True
        self.lock = threading.Lock()

        if self.show_output:
            self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
            self.display_thread.start()

    def _display_worker(self):
        while self.active:
            try:
                line = self.stdout_queue.get(timeout=0.1)
                if line:
                    sys.stdout.write(f"{Colors.GRAY}[{self.name}]{Colors.RESET} {line}")
                    sys.stdout.flush()
            except queue.Empty:
                pass

            try:
                line = self.stderr_queue.get(timeout=0.1)
                if line:
                    sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
                    sys.stderr.flush()
            except queue.Empty:
                pass

    def write_stdout(self, data):
        with self.lock:
            self.stdout_buffer.append(data)
            if self.show_output:
                self.stdout_queue.put(data)

    def write_stderr(self, data):
        with self.lock:
            self.stderr_buffer.append(data)
            if self.show_output:
                self.stderr_queue.put(data)

    def get_stdout(self):
        with self.lock:
            return ''.join(self.stdout_buffer)

    def get_stderr(self):
        with self.lock:
            return ''.join(self.stderr_buffer)

    def get_all_output(self):
        with self.lock:
            return {
                'stdout': ''.join(self.stdout_buffer),
                'stderr': ''.join(self.stderr_buffer)
            }

    def close(self):
        self.active = False
        if hasattr(self, 'display_thread'):
            self.display_thread.join(timeout=1)


_multiplexers = {}
_mux_counter = 0
_mux_lock = threading.Lock()


def create_multiplexer(name=None, show_output=True):
    global _mux_counter
    with _mux_lock:
        if name is None:
            _mux_counter += 1
            name = f"process-{_mux_counter}"
        mux = TerminalMultiplexer(name, show_output)
        _multiplexers[name] = mux
        return name, mux


def get_multiplexer(name):
    return _multiplexers.get(name)


def close_multiplexer(name):
    mux = _multiplexers.get(name)
    if mux:
        mux.close()
        del _multiplexers[name]


def cleanup_all_multiplexers():
    for mux in list(_multiplexers.values()):
        mux.close()
    _multiplexers.clear()
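Illustrative usage of the multiplexer registry (not part of the commit): output is buffered under a label and, with show_output=True, echoed to the terminal by the background display thread.

from pr.multiplexer import create_multiplexer, close_multiplexer

name, mux = create_multiplexer("demo", show_output=True)
mux.write_stdout("starting job\n")
mux.write_stderr("a warning\n")
print(mux.get_all_output())   # {'stdout': 'starting job\n', 'stderr': 'a warning\n'}
close_multiplexer(name)       # stops the display thread and drops the registry entry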
0
pr/plugins/__init__.py
Normal file
128
pr/plugins/loader.py
Normal file
@@ -0,0 +1,128 @@
import os
import sys
import importlib.util
from typing import List, Dict, Callable, Any
from pr.core.logging import get_logger

logger = get_logger('plugins')

PLUGINS_DIR = os.path.expanduser("~/.pr/plugins")


class PluginLoader:

    def __init__(self):
        self.loaded_plugins = {}
        self.plugin_tools = []
        os.makedirs(PLUGINS_DIR, exist_ok=True)

    def load_plugins(self) -> List[Dict]:
        if not os.path.exists(PLUGINS_DIR):
            logger.info("No plugins directory found")
            return []

        plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')]

        for plugin_file in plugin_files:
            try:
                self._load_plugin_file(plugin_file)
            except Exception as e:
                logger.error(f"Error loading plugin {plugin_file}: {e}")

        return self.plugin_tools

    def _load_plugin_file(self, filename: str):
        plugin_path = os.path.join(PLUGINS_DIR, filename)
        plugin_name = filename[:-3]

        spec = importlib.util.spec_from_file_location(plugin_name, plugin_path)
        if spec is None or spec.loader is None:
            logger.error(f"Could not load spec for {filename}")
            return

        module = importlib.util.module_from_spec(spec)
        sys.modules[plugin_name] = module
        spec.loader.exec_module(module)

        if hasattr(module, 'register_tools'):
            tools = module.register_tools()
            if isinstance(tools, list):
                self.plugin_tools.extend(tools)
                self.loaded_plugins[plugin_name] = module
                logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
            else:
                logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
        else:
            logger.warning(f"Plugin {plugin_name} does not have register_tools() function")

    def get_plugin_function(self, tool_name: str) -> Callable:
        for plugin_name, module in self.loaded_plugins.items():
            if hasattr(module, tool_name):
                return getattr(module, tool_name)

        raise ValueError(f"Plugin function not found: {tool_name}")

    def list_loaded_plugins(self) -> List[str]:
        return list(self.loaded_plugins.keys())


def create_example_plugin():
    example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py')

    if os.path.exists(example_plugin):
        return

    example_code = '''"""
Example plugin for PR Assistant

This plugin demonstrates how to create custom tools.
"""

def my_custom_tool(argument: str) -> str:
    """
    A custom tool that does something useful.

    Args:
        argument: Some input

    Returns:
        A result string
    """
    return f"Custom tool processed: {argument}"


def register_tools():
    """
    Register tools with the PR assistant.

    Returns:
        List of tool definitions
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "my_custom_tool",
                "description": "A custom tool that processes input",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "argument": {
                            "type": "string",
                            "description": "The input to process"
                        }
                    },
                    "required": ["argument"]
                }
            }
        }
    ]
'''

    try:
        os.makedirs(PLUGINS_DIR, exist_ok=True)
        with open(example_plugin, 'w') as f:
            f.write(example_code)
        logger.info(f"Created example plugin at {example_plugin}")
    except Exception as e:
        logger.error(f"Error creating example plugin: {e}")
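A short sketch of the loading flow (illustrative): drop a .py file with a register_tools() function into ~/.pr/plugins, load everything, then resolve one tool by name.

from pr.plugins.loader import PluginLoader, create_example_plugin

create_example_plugin()            # writes example_plugin.py on first run
loader = PluginLoader()
tools = loader.load_plugins()      # list of OpenAI-style tool definitions
print(loader.list_loaded_plugins())
fn = loader.get_plugin_function("my_custom_tool")
print(fn("hello"))                 # "Custom tool processed: hello"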
43
pr/research.md
Normal file
@@ -0,0 +1,43 @@
# Research Overview: Additional Functionality for PR Assistant

## Overview of Current Application
The PR Assistant is a professional CLI AI assistant designed for autonomous execution of tasks. It integrates various tools including command execution, web fetching, database operations, filesystem management, and Python code execution. It features session management, logging, usage tracking, and a plugin system for extensibility.

## Potential New Features
Based on analysis of similar AI assistants and tool-using agents, here are researched ideas for additional functionality:

### 1. Multi-Modal Interfaces
- **Graphical User Interface (GUI)**: Develop a desktop app using frameworks like Electron or Tkinter to provide a user-friendly interface beyond the CLI.
- **Web Interface**: Create a web-based dashboard for remote access and visualization of results.
- **Voice Input/Output**: Integrate speech recognition (e.g., via the Google Speech API) and text-to-speech for hands-free interaction.

### 2. Enhanced Tool Ecosystem
- **Additional Built-in Tools**: Add tools for Git operations, email handling, calendar integration, image processing (e.g., OCR, generation via DALL-E), and real-time data feeds (weather, stocks).
- **API Integrations**: Connect to popular services like GitHub for repository management, Slack/Discord for notifications, or cloud storage (AWS S3, Google Drive).
- **Workflow Automation**: Implement chaining of tools for complex workflows, similar to Zapier or LangChain agents.

### 3. Advanced AI Capabilities
- **Multi-Agent Systems**: Allow multiple AI agents to collaborate on tasks, with role specialization (e.g., one for coding, one for research).
- **Long-Term Memory and Learning**: Implement persistent knowledge bases and fine-tuning on user interactions to improve responses over time.
- **Context Awareness**: Enhance context management with better summarization and retrieval of past conversations.

### 4. Productivity and Usability Enhancements
- **Export and Sharing**: Add options to export session results to formats like PDF or Markdown, or to integrate with documentation tools (e.g., Notion, Confluence).
- **Scheduled Tasks**: Enable cron-like scheduling for autonomous task execution.
- **Multi-User Support**: Implement user accounts for shared access and collaboration features.

### 5. Security and Reliability
- **Sandboxing and Permissions**: Improve security with containerized tool execution and user-defined permission levels.
- **Error Recovery**: Add automatic retry mechanisms, fallback strategies, and detailed error reporting.
- **Audit Logging**: Enhance logging for compliance and debugging.

### 6. Plugin Ecosystem Expansion
- **Community Plugin Repository**: Create an online hub for user-contributed plugins.
- **Plugin Marketplace**: Allow users to rate and install plugins easily, with dependency management.

### 7. Performance Optimizations
- **Caching**: Implement caching for API calls and tool results to reduce latency.
- **Parallel Execution**: Enable concurrent tool usage for faster task completion.
- **Model Selection**: Expand support for multiple AI models and allow dynamic switching.

These features would position the PR Assistant as a more versatile and powerful tool, appealing to developers, researchers, and productivity enthusiasts. Implementation should prioritize backward compatibility and maintain the CLI-first approach while adding optional interfaces.
21
pr/tools/__init__.py
Normal file
@@ -0,0 +1,21 @@
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import (
    read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace
)
from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process
from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor
from pr.tools.database import db_set, db_get, db_query
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.python_exec import python_exec
from pr.tools.patch import apply_patch, create_diff

__all__ = [
    'get_tools_definition',
    'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace',
    'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search', 'close_editor',
    'run_command', 'run_command_interactive',
    'db_set', 'db_get', 'db_query',
    'http_fetch', 'web_search', 'web_search_news',
    'python_exec', 'tail_process', 'kill_process',
    'apply_patch', 'create_diff'
]
||||
444
pr/tools/base.py
Normal file
444
pr/tools/base.py
Normal file
@ -0,0 +1,444 @@
|
||||
def get_tools_definition():
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "kill_process",
|
||||
"description": "Terminate a background process by its PID. Use this to stop processes started with run_command that exceeded their timeout.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pid": {
|
||||
"type": "integer",
|
||||
"description": "The process ID returned by run_command when status is 'running'."
|
||||
}
|
||||
},
|
||||
"required": ["pid"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tail_process",
|
||||
"description": "Monitor and retrieve output from a background process by its PID. Use this to check on processes started with run_command that exceeded their timeout.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pid": {
|
||||
"type": "integer",
|
||||
"description": "The process ID returned by run_command when status is 'running'."
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["pid"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "http_fetch",
|
||||
"description": "Fetch content from an HTTP URL",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "The URL to fetch"},
|
||||
"headers": {"type": "object", "description": "Optional HTTP headers"}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_command",
|
||||
"description": "Execute a shell command and capture output. Returns immediately after timeout with PID if still running. Use tail_process to monitor or kill_process to terminate long-running commands.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "The shell command to execute"},
|
||||
"timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_command_interactive",
|
||||
"description": "Execute an interactive terminal command that requires user input or displays UI. The command runs in the user's terminal. Returns exit code only.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read contents of a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"content": {"type": "string", "description": "Content to write"}
|
||||
},
|
||||
"required": ["filepath", "content"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_directory",
|
||||
"description": "List directory contents",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Directory path", "default": "."},
|
||||
"recursive": {"type": "boolean", "description": "List recursively", "default": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mkdir",
|
||||
"description": "Create a new directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path of the directory to create"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "chdir",
|
||||
"description": "Change the current working directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to change to"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "getpwd",
|
||||
"description": "Get the current working directory",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "db_set",
|
||||
"description": "Set a key-value pair in the database",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {"type": "string", "description": "The key"},
|
||||
"value": {"type": "string", "description": "The value"}
|
||||
},
|
||||
"required": ["key", "value"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "db_get",
|
||||
"description": "Get a value from the database",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {"type": "string", "description": "The key"}
|
||||
},
|
||||
"required": ["key"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "db_query",
|
||||
"description": "Execute a database query",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "SQL query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Perform a web search",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search_news",
|
||||
"description": "Perform a web search for news",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query for news"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "python_exec",
|
||||
"description": "Execute Python code",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "Python code to execute"}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "index_source_directory",
|
||||
"description": "Index directory recursively and read all source files.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to index"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_replace",
|
||||
"description": "Search and replace text in a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"old_string": {"type": "string", "description": "String to replace"},
|
||||
"new_string": {"type": "string", "description": "Replacement string"}
|
||||
},
|
||||
"required": ["filepath", "old_string", "new_string"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"description": "Apply a patch to a file, especially for source code",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file to patch"},
|
||||
"patch_content": {"type": "string", "description": "The patch content as a string"}
|
||||
},
|
||||
"required": ["filepath", "patch_content"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_diff",
|
||||
"description": "Create a unified diff between two files",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file1": {"type": "string", "description": "Path to the first file"},
|
||||
"file2": {"type": "string", "description": "Path to the second file"},
|
||||
"fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"},
|
||||
"tofile": {"type": "string", "description": "Label for the second file", "default": "file2"}
|
||||
},
|
||||
"required": ["file1", "file2"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "open_editor",
|
||||
"description": "Open the RPEditor for a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "close_editor",
|
||||
"description": "Close the RPEditor. Always close files when finished editing.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "editor_insert_text",
|
||||
"description": "Insert text at cursor position in the editor",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"text": {"type": "string", "description": "Text to insert"},
|
||||
"line": {"type": "integer", "description": "Line number (optional)"},
|
||||
"col": {"type": "integer", "description": "Column number (optional)"}
|
||||
},
|
||||
"required": ["filepath", "text"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "editor_replace_text",
|
||||
"description": "Replace text in a range",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"start_line": {"type": "integer", "description": "Start line"},
|
||||
"start_col": {"type": "integer", "description": "Start column"},
|
||||
"end_line": {"type": "integer", "description": "End line"},
|
||||
"end_col": {"type": "integer", "description": "End column"},
|
||||
"new_text": {"type": "string", "description": "New text"}
|
||||
},
|
||||
"required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "editor_search",
|
||||
"description": "Search for a pattern in the file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"pattern": {"type": "string", "description": "Regex pattern"},
|
||||
"start_line": {"type": "integer", "description": "Start line", "default": 0}
|
||||
},
|
||||
"required": ["filepath", "pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "display_file_diff",
|
||||
"description": "Display a visual colored diff between two files with syntax highlighting and statistics",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath1": {"type": "string", "description": "Path to the original file"},
|
||||
"filepath2": {"type": "string", "description": "Path to the modified file"},
|
||||
"format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"}
|
||||
},
|
||||
"required": ["filepath1", "filepath2"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "display_edit_summary",
|
||||
"description": "Display a summary of all edit operations performed during the session",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "display_edit_timeline",
|
||||
"description": "Display a timeline of all edit operations with details",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"show_content": {"type": "boolean", "description": "Show content previews", "default": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "clear_edit_tracker",
|
||||
"description": "Clear the edit tracker to start fresh",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
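These definitions follow the OpenAI-style function-calling schema. A hedged sketch of how a client might dispatch a model's tool call against them; the tool_call shape and the dispatch table are illustrative assumptions, not code from this commit:

import json
from pr.tools import run_command, read_file, write_file

# Hypothetical dispatch table: maps names from get_tools_definition()
# to the implementations exported by pr.tools.
DISPATCH = {
    "run_command": run_command,
    "read_file": read_file,
    "write_file": write_file,
}

def execute_tool_call(tool_call: dict) -> dict:
    # tool_call follows the OpenAI convention:
    # {"function": {"name": "...", "arguments": "<json string>"}}
    name = tool_call["function"]["name"]
    args = json.loads(tool_call["function"]["arguments"])
    return DISPATCH[name](**args)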
164
pr/tools/command.py
Normal file
@@ -0,0 +1,164 @@
import os
import subprocess
import time
import select
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer

_processes = {}


def _register_process(pid: int, process):
    _processes[pid] = process
    return _processes


def _get_process(pid: int):
    return _processes.get(pid)


def kill_process(pid: int):
    try:
        process = _get_process(pid)
        if process:
            process.kill()
            _processes.pop(pid)

            mux_name = f"cmd-{pid}"
            if get_multiplexer(mux_name):
                close_multiplexer(mux_name)

            return {"status": "success", "message": f"Process {pid} has been killed"}
        else:
            return {"status": "error", "error": f"Process {pid} not found"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def tail_process(pid: int, timeout: int = 30):
    process = _get_process(pid)
    if process:
        mux_name = f"cmd-{pid}"
        mux = get_multiplexer(mux_name)

        if not mux:
            mux_name, mux = create_multiplexer(mux_name, show_output=True)

        try:
            start_time = time.time()
            timeout_duration = timeout
            stdout_content = ""
            stderr_content = ""

            while True:
                if process.poll() is not None:
                    remaining_stdout, remaining_stderr = process.communicate()
                    if remaining_stdout:
                        mux.write_stdout(remaining_stdout)
                        stdout_content += remaining_stdout
                    if remaining_stderr:
                        mux.write_stderr(remaining_stderr)
                        stderr_content += remaining_stderr

                    if pid in _processes:
                        _processes.pop(pid)

                    close_multiplexer(mux_name)

                    return {
                        "status": "success",
                        "stdout": stdout_content,
                        "stderr": stderr_content,
                        "returncode": process.returncode
                    }

                if time.time() - start_time > timeout_duration:
                    return {
                        "status": "running",
                        "message": "Process is still running. Call tail_process again to continue monitoring.",
                        "stdout_so_far": stdout_content,
                        "stderr_so_far": stderr_content,
                        "pid": pid
                    }

                ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
                for pipe in ready:
                    if pipe == process.stdout:
                        line = process.stdout.readline()
                        if line:
                            mux.write_stdout(line)
                            stdout_content += line
                    elif pipe == process.stderr:
                        line = process.stderr.readline()
                        if line:
                            mux.write_stderr(line)
                            stderr_content += line
        except Exception as e:
            return {"status": "error", "error": str(e)}
    else:
        return {"status": "error", "error": f"Process {pid} not found"}


def run_command(command, timeout=30):
    mux_name = None
    try:
        process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        _register_process(process.pid, process)

        mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)

        start_time = time.time()
        timeout_duration = timeout
        stdout_content = ""
        stderr_content = ""

        while True:
            if process.poll() is not None:
                remaining_stdout, remaining_stderr = process.communicate()
                if remaining_stdout:
                    mux.write_stdout(remaining_stdout)
                    stdout_content += remaining_stdout
                if remaining_stderr:
                    mux.write_stderr(remaining_stderr)
                    stderr_content += remaining_stderr

                if process.pid in _processes:
                    _processes.pop(process.pid)

                close_multiplexer(mux_name)

                return {
                    "status": "success",
                    "stdout": stdout_content,
                    "stderr": stderr_content,
                    "returncode": process.returncode
                }

            if time.time() - start_time > timeout_duration:
                return {
                    "status": "running",
                    "message": f"Process still running after {timeout}s timeout. Use tail_process({process.pid}) to monitor or kill_process({process.pid}) to terminate.",
                    "stdout_so_far": stdout_content,
                    "stderr_so_far": stderr_content,
                    "pid": process.pid,
                    "mux_name": mux_name
                }

            ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
            for pipe in ready:
                if pipe == process.stdout:
                    line = process.stdout.readline()
                    if line:
                        mux.write_stdout(line)
                        stdout_content += line
                elif pipe == process.stderr:
                    line = process.stderr.readline()
                    if line:
                        mux.write_stderr(line)
                        stderr_content += line
    except Exception as e:
        if mux_name:
            close_multiplexer(mux_name)
        return {"status": "error", "error": str(e)}


def run_command_interactive(command):
    try:
        return_code = os.system(command)
        return {"status": "success", "returncode": return_code}
    except Exception as e:
        return {"status": "error", "error": str(e)}
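Illustrative flow for the timeout protocol above (POSIX only, since the polling loop uses select on pipe objects; the command is a stand-in for any long-running job):

from pr.tools.command import run_command, tail_process, kill_process

result = run_command("sleep 60 && echo done", timeout=2)
if result["status"] == "running":
    pid = result["pid"]
    update = tail_process(pid, timeout=2)    # partial output if still running
    if update["status"] == "running":
        print(kill_process(pid))             # {'status': 'success', ...}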
47
pr/tools/database.py
Normal file
@@ -0,0 +1,47 @@
import time


def db_set(key, value, db_conn):
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}

    try:
        cursor = db_conn.cursor()
        cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
                          VALUES (?, ?, ?)""", (key, value, time.time()))
        db_conn.commit()
        return {"status": "success", "message": f"Set {key}"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def db_get(key, db_conn):
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}

    try:
        cursor = db_conn.cursor()
        cursor.execute("SELECT value FROM kv_store WHERE key = ?", (key,))
        result = cursor.fetchone()
        if result:
            return {"status": "success", "value": result[0]}
        else:
            return {"status": "error", "error": "Key not found"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def db_query(query, db_conn):
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}

    try:
        cursor = db_conn.cursor()
        cursor.execute(query)

        if query.strip().upper().startswith('SELECT'):
            results = cursor.fetchall()
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            return {"status": "success", "columns": columns, "rows": results}
        else:
            db_conn.commit()
            return {"status": "success", "rows_affected": cursor.rowcount}
    except Exception as e:
        return {"status": "error", "error": str(e)}
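A minimal sketch of these helpers in isolation. The kv_store schema is not defined in this file; the CREATE TABLE below is inferred from the queries above and may differ from the one the application creates elsewhere:

import sqlite3
from pr.tools.database import db_set, db_get, db_query

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE kv_store (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)")

print(db_set("greeting", "hello", conn))               # {'status': 'success', ...}
print(db_get("greeting", conn))                        # {'status': 'success', 'value': 'hello'}
print(db_query("SELECT COUNT(*) FROM kv_store", conn)) # columns + rows for SELECTs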
144
pr/tools/editor.py
Normal file
@@ -0,0 +1,144 @@
from pr.editor import RPEditor
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
import os
import os.path

_editors = {}


def get_editor(filepath):
    if filepath not in _editors:
        _editors[filepath] = RPEditor(filepath)
    return _editors[filepath]


def close_editor(filepath):
    try:
        path = os.path.expanduser(filepath)
        editor = get_editor(path)
        editor.close()

        mux_name = f"editor-{path}"
        mux = get_multiplexer(mux_name)
        if mux:
            mux.write_stdout(f"Closed editor for: {path}\n")
            close_multiplexer(mux_name)

        return {"status": "success", "message": f"Editor closed for {path}"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def open_editor(filepath):
    try:
        path = os.path.expanduser(filepath)
        editor = RPEditor(path)
        editor.start()

        mux_name = f"editor-{path}"
        mux_name, mux = create_multiplexer(mux_name, show_output=True)
        mux.write_stdout(f"Opened editor for: {path}\n")

        return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
    try:
        path = os.path.expanduser(filepath)

        old_content = ""
        if os.path.exists(path):
            with open(path, 'r') as f:
                old_content = f.read()

        position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
        operation = track_edit('INSERT', filepath, start_pos=position, content=text)
        tracker.mark_in_progress(operation)

        editor = get_editor(path)
        if line is not None and col is not None:
            editor.move_cursor_to(line, col)
        editor.insert_text(text)
        editor.save_file()

        mux_name = f"editor-{path}"
        mux = get_multiplexer(mux_name)
        if mux:
            location = f" at line {line}, col {col}" if line is not None and col is not None else ""
            preview = text[:50] + "..." if len(text) > 50 else text
            mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")

            if show_diff and old_content:
                with open(path, 'r') as f:
                    new_content = f.read()
                diff_result = display_content_diff(old_content, new_content, filepath)
                if diff_result["status"] == "success":
                    mux.write_stdout(diff_result["visual_diff"] + "\n")

        tracker.mark_completed(operation)
        result = {"status": "success", "message": f"Inserted text in {path}"}
        close_editor(filepath)
        return result
    except Exception as e:
        if 'operation' in locals():
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}


def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True):
    try:
        path = os.path.expanduser(filepath)

        old_content = ""
        if os.path.exists(path):
            with open(path, 'r') as f:
                old_content = f.read()

        start_pos = start_line * 1000 + start_col
        end_pos = end_line * 1000 + end_col
        operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
                               content=new_text, old_content=old_content)
        tracker.mark_in_progress(operation)

        editor = get_editor(path)
        editor.replace_text(start_line, start_col, end_line, end_col, new_text)
        editor.save_file()

        mux_name = f"editor-{path}"
        mux = get_multiplexer(mux_name)
        if mux:
            preview = new_text[:50] + "..." if len(new_text) > 50 else new_text
            mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n")

            if show_diff and old_content:
                with open(path, 'r') as f:
                    new_content = f.read()
                diff_result = display_content_diff(old_content, new_content, filepath)
                if diff_result["status"] == "success":
                    mux.write_stdout(diff_result["visual_diff"] + "\n")

        tracker.mark_completed(operation)
        result = {"status": "success", "message": f"Replaced text in {path}"}
        close_editor(filepath)
        return result
    except Exception as e:
        if 'operation' in locals():
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}


def editor_search(filepath, pattern, start_line=0):
    try:
        path = os.path.expanduser(filepath)
        editor = RPEditor(path)
        results = editor.search(pattern, start_line)

        mux_name = f"editor-{path}"
        mux = get_multiplexer(mux_name)
        if mux:
            mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n")

        result = {"status": "success", "results": results}
        close_editor(filepath)
        return result
    except Exception as e:
        return {"status": "error", "error": str(e)}
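Illustrative round trip through these editor tools (assumes RPEditor tolerates a non-interactive session; the file path is hypothetical). Note that editor_insert_text and editor_search close the editor themselves on success, so no explicit close_editor call is needed here:

from pr.tools.editor import open_editor, editor_insert_text, editor_search

open_editor("~/notes.txt")
print(editor_insert_text("~/notes.txt", "TODO: review\n", line=0, col=0))
print(editor_search("~/notes.txt", r"TODO"))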
287
pr/tools/filesystem.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import os
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Dict
|
||||
from pr.editor import RPEditor
|
||||
from ..ui.diff_display import display_diff, get_diff_stats
|
||||
from ..ui.edit_feedback import track_edit, tracker
|
||||
from ..tools.patch import display_content_diff
|
||||
|
||||
_id = 0
|
||||
|
||||
def get_uid():
|
||||
global _id
|
||||
_id += 3
|
||||
return _id
|
||||
|
||||
def read_file(filepath, db_conn=None):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
with open(path, 'r') as f:
|
||||
content = f.read()
|
||||
if db_conn:
|
||||
from pr.tools.database import db_set
|
||||
db_set("read:" + path, "true", db_conn)
|
||||
return {"status": "success", "content": content}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
old_content = ""
|
||||
is_new_file = not os.path.exists(path)
|
||||
|
||||
if not is_new_file and db_conn:
|
||||
from pr.tools.database import db_get
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
|
||||
if not is_new_file:
|
||||
with open(path, 'r') as f:
|
||||
old_content = f.read()
|
||||
|
||||
operation = track_edit('WRITE', filepath, content=content, old_content=old_content)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
if show_diff and not is_new_file:
|
||||
diff_result = display_content_diff(old_content, content, filepath)
|
||||
if diff_result["status"] == "success":
|
||||
print(diff_result["visual_diff"])
|
||||
|
||||
editor = RPEditor(path)
|
||||
editor.set_text(content)
|
||||
editor.save_file()
|
||||
|
||||
if os.path.exists(path) and db_conn:
|
||||
try:
|
||||
cursor = db_conn.cursor()
|
||||
file_hash = hashlib.md5(old_content.encode()).hexdigest()
|
||||
|
||||
cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,))
|
||||
result = cursor.fetchone()
|
||||
version = (result[0] + 1) if result[0] else 1
|
||||
|
||||
cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(filepath, old_content, file_hash, time.time(), version))
|
||||
db_conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tracker.mark_completed(operation)
|
||||
|
||||
message = f"File written to {path}"
|
||||
if show_diff and not is_new_file:
|
||||
stats = get_diff_stats(old_content, content)
|
||||
message += f" ({stats['insertions']}+ {stats['deletions']}-)"
|
||||
|
||||
return {"status": "success", "message": message}
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def list_directory(path=".", recursive=False):
|
||||
try:
|
||||
path = os.path.expanduser(path)
|
||||
items = []
|
||||
if recursive:
|
||||
for root, dirs, files in os.walk(path):
|
||||
for name in files:
|
||||
item_path = os.path.join(root, name)
|
||||
items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)})
|
||||
for name in dirs:
|
||||
items.append({"path": os.path.join(root, name), "type": "directory"})
|
||||
else:
|
||||
for item in os.listdir(path):
|
||||
item_path = os.path.join(path, item)
|
||||
items.append({
|
||||
"name": item,
|
||||
"type": "directory" if os.path.isdir(item_path) else "file",
|
||||
"size": os.path.getsize(item_path) if os.path.isfile(item_path) else None
|
||||
})
|
||||
return {"status": "success", "items": items}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def mkdir(path):
|
||||
try:
|
||||
os.makedirs(os.path.expanduser(path), exist_ok=True)
|
||||
return {"status": "success", "message": f"Directory created at {path}"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def chdir(path):
|
||||
try:
|
||||
os.chdir(os.path.expanduser(path))
|
||||
return {"status": "success", "new_path": os.getcwd()}
|
||||
    except Exception as e:
        return {"status": "error", "error": str(e)}


def getpwd():
    try:
        return {"status": "success", "path": os.getcwd()}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def index_source_directory(path):
    extensions = [
        ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp",
        ".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go"
    ]
    source_files = []
    try:
        for root, _, files in os.walk(os.path.expanduser(path)):
            for file in files:
                if any(file.endswith(ext) for ext in extensions):
                    filepath = os.path.join(root, file)
                    try:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            content = f.read()
                        source_files.append({
                            "path": filepath,
                            "content": content
                        })
                    except Exception:
                        continue
        return {"status": "success", "indexed_files": source_files}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def search_replace(filepath, old_string, new_string, db_conn=None):
    try:
        path = os.path.expanduser(filepath)
        if not os.path.exists(path):
            return {"status": "error", "error": "File does not exist"}
        if db_conn:
            from pr.tools.database import db_get
            read_status = db_get("read:" + path, db_conn)
            if read_status.get("status") != "success" or read_status.get("value") != "true":
                return {"status": "error", "error": "File must be read before writing. Please read the file first."}
        with open(path, 'r') as f:
            content = f.read()
        content = content.replace(old_string, new_string)
        with open(path, 'w') as f:
            f.write(content)
        return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


_editors = {}


def get_editor(filepath):
    if filepath not in _editors:
        _editors[filepath] = RPEditor(filepath)
    return _editors[filepath]


def close_editor(filepath):
    try:
        path = os.path.expanduser(filepath)
        editor = get_editor(path)
        editor.close()
        # Drop the cached instance so a later get_editor() call creates a
        # fresh editor instead of returning this closed one.
        _editors.pop(path, None)
        return {"status": "success", "message": f"Editor closed for {path}"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def open_editor(filepath):
    try:
        path = os.path.expanduser(filepath)
        editor = RPEditor(path)
        editor.start()
        return {"status": "success", "message": f"Editor opened for {path}"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
    try:
        path = os.path.expanduser(filepath)
        if db_conn:
            from pr.tools.database import db_get
            read_status = db_get("read:" + path, db_conn)
            if read_status.get("status") != "success" or read_status.get("value") != "true":
                return {"status": "error", "error": "File must be read before writing. Please read the file first."}

        old_content = ""
        if os.path.exists(path):
            with open(path, 'r') as f:
                old_content = f.read()

        position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
        operation = track_edit('INSERT', filepath, start_pos=position, content=text)
        tracker.mark_in_progress(operation)

        editor = get_editor(path)
        if line is not None and col is not None:
            editor.move_cursor_to(line, col)
        editor.insert_text(text)
        editor.save_file()

        if show_diff and old_content:
            with open(path, 'r') as f:
                new_content = f.read()
            diff_result = display_content_diff(old_content, new_content, filepath)
            if diff_result["status"] == "success":
                print(diff_result["visual_diff"])

        tracker.mark_completed(operation)
        return {"status": "success", "message": f"Inserted text in {path}"}
    except Exception as e:
        if 'operation' in locals():
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}


def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None):
    try:
        path = os.path.expanduser(filepath)
        if db_conn:
            from pr.tools.database import db_get
            read_status = db_get("read:" + path, db_conn)
            if read_status.get("status") != "success" or read_status.get("value") != "true":
                return {"status": "error", "error": "File must be read before writing. Please read the file first."}

        old_content = ""
        if os.path.exists(path):
            with open(path, 'r') as f:
                old_content = f.read()

        start_pos = start_line * 1000 + start_col
        end_pos = end_line * 1000 + end_col
        operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
                               content=new_text, old_content=old_content)
        tracker.mark_in_progress(operation)

        editor = get_editor(path)
        editor.replace_text(start_line, start_col, end_line, end_col, new_text)
        editor.save_file()

        if show_diff and old_content:
            with open(path, 'r') as f:
                new_content = f.read()
            diff_result = display_content_diff(old_content, new_content, filepath)
            if diff_result["status"] == "success":
                print(diff_result["visual_diff"])

        tracker.mark_completed(operation)
        return {"status": "success", "message": f"Replaced text in {path}"}
    except Exception as e:
        if 'operation' in locals():
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}


def display_edit_summary():
    from ..ui.edit_feedback import display_edit_summary
    return display_edit_summary()


def display_edit_timeline(show_content=False):
    from ..ui.edit_feedback import display_edit_timeline
    return display_edit_timeline(show_content)


def clear_edit_tracker():
    from ..ui.edit_feedback import clear_tracker
    clear_tracker()
    return {"status": "success", "message": "Edit tracker cleared"}
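A minimal usage sketch of the editor helpers above. The module path is an assumption (this file's name is not visible in the diff view); passing db_conn=None skips the read-before-write gate, so this only suits standalone use.

from pr.tools.editing import editor_insert_text, close_editor  # assumed module path

result = editor_insert_text("~/notes.txt", "TODO: review\n", line=0, col=0, show_diff=False)
print(result)  # {"status": "success", ...} on success
close_editor("~/notes.txt")  # releases the cached RPEditor instance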
91
pr/tools/patch.py
Normal file
@ -0,0 +1,91 @@
import os
import tempfile
import subprocess
import difflib
from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay


def apply_patch(filepath, patch_content, db_conn=None):
    try:
        path = os.path.expanduser(filepath)
        if db_conn:
            from pr.tools.database import db_get
            read_status = db_get("read:" + path, db_conn)
            if read_status.get("status") != "success" or read_status.get("value") != "true":
                return {"status": "error", "error": "File must be read before writing. Please read the file first."}
        # Write patch to temp file
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f:
            f.write(patch_content)
            patch_file = f.name
        # Run patch command
        result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path))
        os.unlink(patch_file)
        if result.returncode == 0:
            return {"status": "success", "output": result.stdout.strip()}
        else:
            return {"status": "error", "error": result.stderr.strip()}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'):
    try:
        path1 = os.path.expanduser(file1)
        path2 = os.path.expanduser(file2)
        with open(path1, 'r') as f1, open(path2, 'r') as f2:
            content1 = f1.read()
            content2 = f2.read()

        if visual:
            visual_diff = display_diff(content1, content2, fromfile, format_type)
            stats = get_diff_stats(content1, content2)
            lines1 = content1.splitlines(keepends=True)
            lines2 = content2.splitlines(keepends=True)
            plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
            return {
                "status": "success",
                "diff": ''.join(plain_diff),
                "visual_diff": visual_diff,
                "stats": stats
            }
        else:
            lines1 = content1.splitlines(keepends=True)
            lines2 = content2.splitlines(keepends=True)
            diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
            return {"status": "success", "diff": ''.join(diff)}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3):
    try:
        path1 = os.path.expanduser(filepath1)
        path2 = os.path.expanduser(filepath2)

        with open(path1, 'r') as f1:
            old_content = f1.read()
        with open(path2, 'r') as f2:
            new_content = f2.read()

        # Pass context_lines through so the parameter is not silently ignored.
        visual_diff = display_diff(old_content, new_content, filepath1, format_type, context_lines)
        stats = get_diff_stats(old_content, new_content)

        return {
            "status": "success",
            "visual_diff": visual_diff,
            "stats": stats
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}


def display_content_diff(old_content, new_content, filename='file', format_type='unified'):
    try:
        visual_diff = display_diff(old_content, new_content, filename, format_type)
        stats = get_diff_stats(old_content, new_content)

        return {
            "status": "success",
            "visual_diff": visual_diff,
            "stats": stats
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}
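A round-trip sketch, assuming old.txt and new.txt exist and the system patch binary is installed: create_diff builds a plain unified diff, and apply_patch shells out to patch to apply it.

from pr.tools.patch import create_diff, apply_patch

diff = create_diff("old.txt", "new.txt", fromfile="old.txt", tofile="new.txt")
if diff["status"] == "success":
    # db_conn=None bypasses the read-before-write gate.
    print(apply_patch("old.txt", diff["diff"]))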
13
pr/tools/python_exec.py
Normal file
@ -0,0 +1,13 @@
import traceback
from io import StringIO
import contextlib


def python_exec(code, python_globals):
    try:
        output = StringIO()
        with contextlib.redirect_stdout(output):
            exec(code, python_globals)

        return {"status": "success", "output": output.getvalue()}
    except Exception as e:
        return {"status": "error", "error": str(e), "traceback": traceback.format_exc()}
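A short sketch of the execution contract: reusing one globals dict lets later snippets see names defined by earlier ones, and anything printed is captured into the returned output field.

from pr.tools.python_exec import python_exec

shared = {}  # reused across calls, so definitions persist
python_exec("x = 21", shared)
result = python_exec("print(x * 2)", shared)
print(result["output"])  # "42\n"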
36
pr/tools/web.py
Normal file
@ -0,0 +1,36 @@
import urllib.request
import urllib.parse
import urllib.error
import json


def http_fetch(url, headers=None):
    try:
        req = urllib.request.Request(url)
        if headers:
            for key, value in headers.items():
                req.add_header(key, value)

        with urllib.request.urlopen(req) as response:
            content = response.read().decode('utf-8')
            return {"status": "success", "content": content[:10000]}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def _perform_search(base_url, query, params=None):
    # base_url and params are currently unused: every search goes through the
    # static search CGI below.
    try:
        # Percent-encode the query so spaces and special characters survive.
        full_url = f"https://static.molodetz.nl/search.cgi?query={urllib.parse.quote(query)}"

        with urllib.request.urlopen(full_url) as response:
            content = response.read().decode('utf-8')
            return {"status": "success", "content": json.loads(content)}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def web_search(query):
    base_url = "https://search.molodetz.nl/search"
    return _perform_search(base_url, query)


def web_search_news(query):
    base_url = "https://search.molodetz.nl/search"
    return _perform_search(base_url, query)

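A quick sketch of http_fetch with a custom header; the URL is just a placeholder.

from pr.tools.web import http_fetch

# Responses are truncated to the first 10,000 characters.
result = http_fetch("https://example.com", headers={"User-Agent": "pr-assistant"})
if result["status"] == "success":
    print(result["content"][:200])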
5
pr/ui/__init__.py
Normal file
@ -0,0 +1,5 @@
from pr.ui.colors import Colors
from pr.ui.rendering import highlight_code, render_markdown
from pr.ui.display import display_tool_call, print_autonomous_header

__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header']
14
pr/ui/colors.py
Normal file
@ -0,0 +1,14 @@
class Colors:
    RESET = '\033[0m'
    BOLD = '\033[1m'
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    MAGENTA = '\033[95m'
    CYAN = '\033[96m'
    GRAY = '\033[90m'
    WHITE = '\033[97m'
    BG_BLUE = '\033[44m'
    BG_GREEN = '\033[42m'
    BG_RED = '\033[41m'
199
pr/ui/diff_display.py
Normal file
@ -0,0 +1,199 @@
import difflib
from typing import List, Tuple, Dict, Optional
from .colors import Colors


class DiffStats:
    def __init__(self):
        self.insertions = 0
        self.deletions = 0
        self.modifications = 0
        self.files_changed = 0

    @property
    def total_changes(self):
        return self.insertions + self.deletions

    def __str__(self):
        return f"{self.files_changed} file(s) changed, {self.insertions} insertions(+), {self.deletions} deletions(-)"


class DiffLine:
    def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None,
                 new_line_num: Optional[int] = None):
        self.line_type = line_type
        self.content = content
        self.old_line_num = old_line_num
        self.new_line_num = new_line_num

    def format(self, show_line_nums: bool = True) -> str:
        color = {
            'add': Colors.GREEN,
            'delete': Colors.RED,
            'context': Colors.GRAY,
            'header': Colors.CYAN,
            'stats': Colors.BLUE
        }.get(self.line_type, Colors.RESET)

        prefix = {
            'add': '+ ',
            'delete': '- ',
            'context': '  ',
            'header': '',
            'stats': ''
        }.get(self.line_type, '  ')

        if show_line_nums and self.line_type in ('add', 'delete', 'context'):
            old_num = str(self.old_line_num) if self.old_line_num else ' '
            new_num = str(self.new_line_num) if self.new_line_num else ' '
            line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} "
        else:
            line_num_str = ''

        return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}"


class DiffDisplay:
    def __init__(self, context_lines: int = 3):
        self.context_lines = context_lines

    def create_diff(self, old_content: str, new_content: str,
                    filename: str = "file") -> Tuple[List[DiffLine], DiffStats]:
        old_lines = old_content.splitlines(keepends=True)
        new_lines = new_content.splitlines(keepends=True)

        diff_lines = []
        stats = DiffStats()
        stats.files_changed = 1

        diff = difflib.unified_diff(
            old_lines, new_lines,
            fromfile=f"a/{filename}",
            tofile=f"b/{filename}",
            n=self.context_lines
        )

        old_line_num = 0
        new_line_num = 0

        for line in diff:
            if line.startswith('---') or line.startswith('+++'):
                diff_lines.append(DiffLine('header', line.rstrip()))
            elif line.startswith('@@'):
                diff_lines.append(DiffLine('header', line.rstrip()))
                old_line_num, new_line_num = self._parse_hunk_header(line)
            elif line.startswith('+'):
                stats.insertions += 1
                diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num))
                new_line_num += 1
            elif line.startswith('-'):
                stats.deletions += 1
                diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None))
                old_line_num += 1
            elif line.startswith(' '):
                diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num))
                old_line_num += 1
                new_line_num += 1

        stats.modifications = min(stats.insertions, stats.deletions)

        return diff_lines, stats

    def _parse_hunk_header(self, header: str) -> Tuple[int, int]:
        try:
            parts = header.split('@@')[1].strip().split()
            old_start = int(parts[0].split(',')[0].replace('-', ''))
            new_start = int(parts[1].split(',')[0].replace('+', ''))
            return old_start, new_start
        except (IndexError, ValueError):
            return 0, 0

    def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats,
                    show_line_nums: bool = True, show_stats: bool = True) -> str:
        output = []

        if show_stats:
            output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
            output.append(f"{Colors.BOLD}{Colors.BLUE}DIFF SUMMARY{Colors.RESET}")
            output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
            output.append(f"{Colors.BLUE}{stats}{Colors.RESET}\n")

        for line in diff_lines:
            output.append(line.format(show_line_nums))

        if show_stats:
            output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")

        return '\n'.join(output)

    def display_file_diff(self, old_content: str, new_content: str,
                          filename: str = "file", show_line_nums: bool = True) -> str:
        diff_lines, stats = self.create_diff(old_content, new_content, filename)

        if not diff_lines:
            return f"{Colors.GRAY}No changes detected{Colors.RESET}"

        return self.render_diff(diff_lines, stats, show_line_nums)

    def display_side_by_side(self, old_content: str, new_content: str,
                             filename: str = "file", width: int = 80) -> str:
        old_lines = old_content.splitlines()
        new_lines = new_content.splitlines()

        matcher = difflib.SequenceMatcher(None, old_lines, new_lines)
        output = []

        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")

        half_width = (width - 5) // 2

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'equal':
                for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
                    old_display = old_line[:half_width].ljust(half_width)
                    new_display = new_line[:half_width].ljust(half_width)
                    output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
            elif tag == 'replace':
                max_lines = max(i2 - i1, j2 - j1)
                for i in range(max_lines):
                    old_line = old_lines[i1 + i] if i1 + i < i2 else ""
                    new_line = new_lines[j1 + i] if j1 + i < j2 else ""
                    old_display = old_line[:half_width].ljust(half_width)
                    new_display = new_line[:half_width].ljust(half_width)
                    output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}")
            elif tag == 'delete':
                for old_line in old_lines[i1:i2]:
                    old_display = old_line[:half_width].ljust(half_width)
                    output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
            elif tag == 'insert':
                for new_line in new_lines[j1:j2]:
                    new_display = new_line[:half_width].ljust(half_width)
                    output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")

        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
        return '\n'.join(output)


def display_diff(old_content: str, new_content: str, filename: str = "file",
                 format_type: str = "unified", context_lines: int = 3) -> str:
    displayer = DiffDisplay(context_lines)

    if format_type == "side-by-side":
        return displayer.display_side_by_side(old_content, new_content, filename)
    else:
        return displayer.display_file_diff(old_content, new_content, filename)


def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]:
    displayer = DiffDisplay()
    _, stats = displayer.create_diff(old_content, new_content)

    return {
        'insertions': stats.insertions,
        'deletions': stats.deletions,
        'modifications': stats.modifications,
        'total_changes': stats.total_changes,
        'files_changed': stats.files_changed
    }
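A small sketch contrasting the two formats display_diff exposes:

from pr.ui.diff_display import display_diff, get_diff_stats

old = "alpha\nbeta\ngamma\n"
new = "alpha\nBETA\ngamma\ndelta\n"

print(display_diff(old, new, "demo.txt"))                              # unified view
print(display_diff(old, new, "demo.txt", format_type="side-by-side"))  # two-column view
print(get_diff_stats(old, new))  # {'insertions': 2, 'deletions': 1, ...}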
46
pr/ui/display.py
Normal file
@ -0,0 +1,46 @@
import json
from typing import Dict, Any
from pr.ui.colors import Colors


def display_tool_call(tool_name, arguments, status="running", result=None):
    status_icons = {
        "running": ("⚙", Colors.YELLOW),
        "success": ("✓", Colors.GREEN),
        "error": ("✗", Colors.RED)
    }

    icon, color = status_icons.get(status, ("•", Colors.WHITE))

    print(f"\n{Colors.BOLD}{'═' * 80}{Colors.RESET}")
    print(f"{color}{icon} {Colors.BOLD}{Colors.CYAN}TOOL: {tool_name}{Colors.RESET}")
    print(f"{Colors.BOLD}{'─' * 80}{Colors.RESET}")

    if arguments:
        print(f"{Colors.YELLOW}Parameters:{Colors.RESET}")
        for key, value in arguments.items():
            value_str = str(value)
            if len(value_str) > 100:
                value_str = value_str[:100] + "..."
            print(f"  {Colors.CYAN}• {key}:{Colors.RESET} {value_str}")

    if result is not None and status != "running":
        print(f"\n{Colors.YELLOW}Result:{Colors.RESET}")
        result_str = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)

        if len(result_str) > 500:
            result_str = result_str[:500] + f"\n{Colors.GRAY}... (truncated){Colors.RESET}"

        if status == "success":
            print(f"{Colors.GREEN}{result_str}{Colors.RESET}")
        elif status == "error":
            print(f"{Colors.RED}{result_str}{Colors.RESET}")
        else:
            print(result_str)

    print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n")


def print_autonomous_header(task):
    print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
    print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
    print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
    print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n")
198
pr/ui/edit_feedback.py
Normal file
@ -0,0 +1,198 @@
from typing import List, Dict, Optional
from datetime import datetime
from .colors import Colors
from .progress import ProgressBar


class EditOperation:
    def __init__(self, op_type: str, filepath: str, start_pos: int = 0,
                 end_pos: int = 0, content: str = "", old_content: str = ""):
        self.op_type = op_type
        self.filepath = filepath
        self.start_pos = start_pos
        self.end_pos = end_pos
        self.content = content
        self.old_content = old_content
        self.timestamp = datetime.now()
        self.status = "pending"

    def format_operation(self) -> str:
        op_colors = {
            'INSERT': Colors.GREEN,
            'REPLACE': Colors.YELLOW,
            'DELETE': Colors.RED,
            'WRITE': Colors.BLUE
        }

        color = op_colors.get(self.op_type, Colors.RESET)
        status_icon = {
            'pending': '○',
            'in_progress': '◐',
            'completed': '●',
            'failed': '✗'
        }.get(self.status, '○')

        return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}"

    def format_details(self, show_content: bool = True) -> str:
        output = [self.format_operation()]

        if self.op_type in ('INSERT', 'REPLACE'):
            output.append(f"   {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")

        if show_content:
            if self.old_content:
                lines = self.old_content.split('\n')
                preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
                output.append(f"   {Colors.RED}- {preview}{Colors.RESET}")

            if self.content:
                lines = self.content.split('\n')
                preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
                output.append(f"   {Colors.GREEN}+ {preview}{Colors.RESET}")

        return '\n'.join(output)


class EditTracker:
    def __init__(self):
        self.operations: List[EditOperation] = []
        self.current_file: Optional[str] = None

    def add_operation(self, op_type: str, filepath: str, **kwargs) -> EditOperation:
        op = EditOperation(op_type, filepath, **kwargs)
        self.operations.append(op)
        self.current_file = filepath
        return op

    def mark_in_progress(self, operation: EditOperation):
        operation.status = "in_progress"

    def mark_completed(self, operation: EditOperation):
        operation.status = "completed"

    def mark_failed(self, operation: EditOperation):
        operation.status = "failed"

    def get_stats(self) -> Dict[str, int]:
        stats = {
            'total': len(self.operations),
            'completed': sum(1 for op in self.operations if op.status == 'completed'),
            'pending': sum(1 for op in self.operations if op.status == 'pending'),
            'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'),
            'failed': sum(1 for op in self.operations if op.status == 'failed')
        }
        return stats

    def get_completion_percentage(self) -> float:
        if not self.operations:
            return 0.0
        stats = self.get_stats()
        return (stats['completed'] / stats['total']) * 100

    def display_progress(self) -> str:
        if not self.operations:
            return f"{Colors.GRAY}No edit operations tracked{Colors.RESET}"

        output = []
        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")

        stats = self.get_stats()
        completion = self.get_completion_percentage()

        progress_bar = ProgressBar(total=stats['total'], width=40)
        progress_bar.current = stats['completed']
        bar_display = progress_bar._get_bar_display()

        output.append(f"Progress: {bar_display}")
        output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
                      f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n")

        output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}")
        for i, op in enumerate(self.operations[-5:], 1):
            output.append(f"{i}. {op.format_operation()}")

        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
        return '\n'.join(output)

    def display_timeline(self, show_content: bool = False) -> str:
        if not self.operations:
            return f"{Colors.GRAY}No edit operations tracked{Colors.RESET}"

        output = []
        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT TIMELINE{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")

        for i, op in enumerate(self.operations, 1):
            timestamp = op.timestamp.strftime("%H:%M:%S")
            output.append(f"{Colors.GRAY}[{timestamp}]{Colors.RESET} {i}.")
            output.append(op.format_details(show_content))
            output.append("")

        stats = self.get_stats()
        output.append(f"{Colors.BOLD}Summary:{Colors.RESET}")
        output.append(f"{Colors.BLUE}Total operations: {stats['total']}, "
                      f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}")

        output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
        return '\n'.join(output)

    def display_summary(self) -> str:
        if not self.operations:
            return f"{Colors.GRAY}No edits to summarize{Colors.RESET}"

        stats = self.get_stats()
        files_modified = len(set(op.filepath for op in self.operations))

        output = []
        output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.GREEN}EDIT SUMMARY{Colors.RESET}")
        output.append(f"{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")

        output.append(f"{Colors.GREEN}Files Modified: {files_modified}{Colors.RESET}")
        output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}")
        output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}")

        if stats['failed'] > 0:
            output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}")

        output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}")
        op_types = {}
        for op in self.operations:
            op_types[op.op_type] = op_types.get(op.op_type, 0) + 1

        for op_type, count in sorted(op_types.items()):
            output.append(f"  {op_type}: {count}")

        output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
        return '\n'.join(output)

    def clear(self):
        self.operations.clear()
        self.current_file = None


tracker = EditTracker()


def track_edit(op_type: str, filepath: str, **kwargs) -> EditOperation:
    return tracker.add_operation(op_type, filepath, **kwargs)


def display_edit_progress() -> str:
    return tracker.display_progress()


def display_edit_timeline(show_content: bool = False) -> str:
    return tracker.display_timeline(show_content)


def display_edit_summary() -> str:
    return tracker.display_summary()


def clear_tracker():
    tracker.clear()
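A sketch of the module-level tracker lifecycle that the editor tools drive:

from pr.ui.edit_feedback import track_edit, tracker, display_edit_summary

op = track_edit('WRITE', 'demo.txt', content="hello\n")
tracker.mark_in_progress(op)
# ... perform the actual write here ...
tracker.mark_completed(op)
print(display_edit_summary())  # 1 file modified, 1 successful WRITE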
69
pr/ui/output.py
Normal file
@ -0,0 +1,69 @@
import json
import sys
from typing import Any, Dict, List
from datetime import datetime


class OutputFormatter:

    def __init__(self, format_type: str = 'text', quiet: bool = False):
        self.format_type = format_type
        self.quiet = quiet

    def output(self, data: Any, message_type: str = 'response'):
        if self.quiet and message_type not in ['error', 'result']:
            return

        if self.format_type == 'json':
            self._output_json(data, message_type)
        elif self.format_type == 'structured':
            self._output_structured(data, message_type)
        else:
            self._output_text(data, message_type)

    def _output_json(self, data: Any, message_type: str):
        output = {
            'type': message_type,
            'timestamp': datetime.now().isoformat(),
            'data': data
        }
        print(json.dumps(output, indent=2))

    def _output_structured(self, data: Any, message_type: str):
        if isinstance(data, dict):
            for key, value in data.items():
                print(f"{key}: {value}")
        elif isinstance(data, list):
            for item in data:
                print(f"- {item}")
        else:
            print(data)

    def _output_text(self, data: Any, message_type: str):
        if isinstance(data, (dict, list)):
            print(json.dumps(data, indent=2))
        else:
            print(data)

    def error(self, message: str):
        if self.format_type == 'json':
            self._output_json({'error': message}, 'error')
        else:
            print(f"Error: {message}", file=sys.stderr)

    def success(self, message: str):
        if not self.quiet:
            if self.format_type == 'json':
                self._output_json({'success': message}, 'success')
            else:
                print(message)

    def info(self, message: str):
        if not self.quiet:
            if self.format_type == 'json':
                self._output_json({'info': message}, 'info')
            else:
                print(message)

    def result(self, data: Any):
        self.output(data, 'result')
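A usage sketch showing how format_type and quiet change what reaches stdout:

from pr.ui.output import OutputFormatter

fmt = OutputFormatter(format_type='json')
fmt.result({"answer": 42})  # wrapped as {"type": "result", "timestamp": ..., "data": ...}

quiet = OutputFormatter(quiet=True)
quiet.info("hidden")   # suppressed: quiet mode only lets errors and results through
quiet.result("shown")  # printed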
76
pr/ui/progress.py
Normal file
@ -0,0 +1,76 @@
import sys
import time
import threading


class ProgressIndicator:

    def __init__(self, message: str = "Working", show: bool = True):
        self.message = message
        self.show = show
        self.running = False
        self.thread = None

    def __enter__(self):
        if self.show:
            self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.show:
            self.stop()

    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._animate, daemon=True)
        self.thread.start()

    def stop(self):
        if self.running:
            self.running = False
            if self.thread:
                self.thread.join(timeout=1.0)
            sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r')
            sys.stdout.flush()

    def _animate(self):
        spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
        idx = 0

        while self.running:
            sys.stdout.write(f'\r{spinner[idx]} {self.message}...')
            sys.stdout.flush()
            idx = (idx + 1) % len(spinner)
            time.sleep(0.1)


class ProgressBar:

    def __init__(self, total: int, description: str = "Progress", width: int = 40):
        self.total = total
        self.description = description
        self.width = width
        self.current = 0

    def update(self, amount: int = 1):
        self.current += amount
        self._display()

    def _get_bar_display(self) -> str:
        # Render the bar as a string without writing to stdout; this method was
        # missing although EditTracker.display_progress() in edit_feedback.py
        # calls it, which would have raised AttributeError.
        if self.total == 0:
            percent = 100
        else:
            percent = int((self.current / self.total) * 100)

        filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
        bar = '█' * filled + '░' * (self.width - filled)
        return f'|{bar}| {percent}% ({self.current}/{self.total})'

    def _display(self):
        sys.stdout.write(f'\r{self.description}: {self._get_bar_display()}')
        sys.stdout.flush()

        if self.current >= self.total:
            sys.stdout.write('\n')

    def finish(self):
        self.current = self.total
        self._display()
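A tiny sketch of both indicators:

import time
from pr.ui.progress import ProgressIndicator, ProgressBar

with ProgressIndicator("Thinking"):
    time.sleep(0.5)  # spinner animates on stdout while the block runs

bar = ProgressBar(total=3, description="Steps")
for _ in range(3):
    time.sleep(0.2)
    bar.update()  # redraws in place; prints a newline on completion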
90
pr/ui/rendering.py
Normal file
@ -0,0 +1,90 @@
import re
from pr.ui.colors import Colors
from pr.config import LANGUAGE_KEYWORDS


def highlight_code(code, language=None, syntax_highlighting=True):
    if not syntax_highlighting:
        return code

    if not language:
        if 'def ' in code or 'import ' in code:
            language = 'python'
        elif 'function ' in code or 'const ' in code:
            language = 'javascript'
        elif 'public ' in code or 'class ' in code:
            language = 'java'

    if language and language in LANGUAGE_KEYWORDS:
        keywords = LANGUAGE_KEYWORDS[language]
        for keyword in keywords:
            pattern = r'\b' + re.escape(keyword) + r'\b'
            code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code)

    code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
    code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)

    code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE)
    code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE)

    return code


def render_markdown(text, syntax_highlighting=True):
    if not syntax_highlighting:
        return text

    code_blocks = []

    def extract_code_block(match):
        lang = match.group(1) or ''
        code = match.group(2)
        highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting)
        placeholder = f"%%CODEBLOCK{len(code_blocks)}%%"
        full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}'
        code_blocks.append(full_block)
        return placeholder

    text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL)

    inline_codes = []

    def extract_inline_code(match):
        code = match.group(1)
        placeholder = f"%%INLINECODE{len(inline_codes)}%%"
        inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}')
        return placeholder

    text = re.sub(r'`([^`]+)`', extract_inline_code, text)

    lines = text.split('\n')
    processed_lines = []
    for line in lines:
        if line.startswith('### '):
            line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}'
        elif line.startswith('## '):
            line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}'
        elif line.startswith('# '):
            line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}'
        elif line.startswith('> '):
            line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}'
        elif re.match(r'^\s*[\*\-\+]\s', line):
            match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line)
            if match:
                line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
        elif re.match(r'^\s*\d+\.\s', line):
            match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line)
            if match:
                line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
        processed_lines.append(line)
    text = '\n'.join(processed_lines)

    text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text)
    text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text)
    text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text)
    text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text)
    text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text)
    text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text)

    for i, code in enumerate(inline_codes):
        text = text.replace(f'%%INLINECODE{i}%%', code)
    for i, block in enumerate(code_blocks):
        text = text.replace(f'%%CODEBLOCK{i}%%', block)

    return text
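A quick sketch of render_markdown; the return value is the same text with ANSI color codes spliced in, so printing it shows colored headings, inline code, and the highlighted code block.

from pr.ui.rendering import render_markdown

sample = "# Title\n\nUse `pip install` first.\n\n```python\ndef hello():\n    return 1\n```\n"
print(render_markdown(sample))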
5
pr/workflows/__init__.py
Normal file
@ -0,0 +1,5 @@
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
from .workflow_engine import WorkflowEngine
from .workflow_storage import WorkflowStorage

__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage']
91
pr/workflows/workflow_definition.py
Normal file
@ -0,0 +1,91 @@
from enum import Enum
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field


class ExecutionMode(Enum):
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    CONDITIONAL = "conditional"


@dataclass
class WorkflowStep:
    tool_name: str
    arguments: Dict[str, Any]
    step_id: str
    condition: Optional[str] = None
    on_success: Optional[List[str]] = None
    on_failure: Optional[List[str]] = None
    retry_count: int = 0
    timeout_seconds: int = 300

    def to_dict(self) -> Dict[str, Any]:
        return {
            'tool_name': self.tool_name,
            'arguments': self.arguments,
            'step_id': self.step_id,
            'condition': self.condition,
            'on_success': self.on_success,
            'on_failure': self.on_failure,
            'retry_count': self.retry_count,
            'timeout_seconds': self.timeout_seconds
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> 'WorkflowStep':
        return WorkflowStep(
            tool_name=data['tool_name'],
            arguments=data['arguments'],
            step_id=data['step_id'],
            condition=data.get('condition'),
            on_success=data.get('on_success'),
            on_failure=data.get('on_failure'),
            retry_count=data.get('retry_count', 0),
            timeout_seconds=data.get('timeout_seconds', 300)
        )


@dataclass
class Workflow:
    name: str
    description: str
    steps: List[WorkflowStep]
    execution_mode: ExecutionMode = ExecutionMode.SEQUENTIAL
    variables: Dict[str, Any] = field(default_factory=dict)
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return {
            'name': self.name,
            'description': self.description,
            'steps': [step.to_dict() for step in self.steps],
            'execution_mode': self.execution_mode.value,
            'variables': self.variables,
            'tags': self.tags
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> 'Workflow':
        return Workflow(
            name=data['name'],
            description=data['description'],
            steps=[WorkflowStep.from_dict(step) for step in data['steps']],
            execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')),
            variables=data.get('variables', {}),
            tags=data.get('tags', [])
        )

    def add_step(self, step: WorkflowStep):
        self.steps.append(step)

    def get_step(self, step_id: str) -> Optional[WorkflowStep]:
        for step in self.steps:
            if step.step_id == step_id:
                return step
        return None

    def get_initial_steps(self) -> List[WorkflowStep]:
        if self.execution_mode == ExecutionMode.SEQUENTIAL:
            return [self.steps[0]] if self.steps else []
        elif self.execution_mode == ExecutionMode.PARALLEL:
            return self.steps
        else:
            return [step for step in self.steps if not step.condition]
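A minimal sketch of declaring a two-step sequential workflow and round-tripping it through to_dict/from_dict; the write_file tool name is illustrative, not a tool confirmed by this commit.

from pr.workflows.workflow_definition import Workflow, WorkflowStep, ExecutionMode

wf = Workflow(
    name="fetch-and-save",
    description="Fetch a URL, then write the result",
    steps=[
        WorkflowStep(tool_name="http_fetch", arguments={"url": "https://example.com"}, step_id="fetch"),
        WorkflowStep(tool_name="write_file", arguments={"content": "${step.fetch}"}, step_id="save"),  # hypothetical tool
    ],
    execution_mode=ExecutionMode.SEQUENTIAL,
)
assert Workflow.from_dict(wf.to_dict()).name == wf.name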
192
pr/workflows/workflow_engine.py
Normal file
@ -0,0 +1,192 @@
import time
import re
from typing import Dict, Any, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode


class WorkflowExecutionContext:
    def __init__(self):
        self.variables: Dict[str, Any] = {}
        self.step_results: Dict[str, Any] = {}
        self.execution_log: List[Dict[str, Any]] = []

    def set_variable(self, name: str, value: Any):
        self.variables[name] = value

    def get_variable(self, name: str, default: Any = None) -> Any:
        return self.variables.get(name, default)

    def set_step_result(self, step_id: str, result: Any):
        self.step_results[step_id] = result

    def get_step_result(self, step_id: str) -> Any:
        return self.step_results.get(step_id)

    def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]):
        self.execution_log.append({
            'timestamp': time.time(),
            'event_type': event_type,
            'step_id': step_id,
            'details': details
        })


class WorkflowEngine:
    def __init__(self, tool_executor: Callable, max_workers: int = 5):
        self.tool_executor = tool_executor
        self.max_workers = max_workers

    def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
        if not condition:
            return True

        try:
            # Conditions are evaluated with builtins stripped and only the
            # workflow variables and prior step results in scope.
            safe_locals = {
                'variables': context.variables,
                'results': context.step_results
            }
            return eval(condition, {"__builtins__": {}}, safe_locals)
        except Exception:
            return False

    def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]:
        # Replaces ${step.<id>} and ${var.<name>} placeholders in string
        # arguments with prior step results and context variables.
        substituted = {}
        for key, value in arguments.items():
            if isinstance(value, str):
                pattern = r'\$\{([^}]+)\}'
                matches = re.findall(pattern, value)
                for match in matches:
                    if match.startswith('step.'):
                        step_id = match.split('.', 1)[1]
                        replacement = context.get_step_result(step_id)
                        if replacement is not None:
                            value = value.replace(f'${{{match}}}', str(replacement))
                    elif match.startswith('var.'):
                        var_name = match.split('.', 1)[1]
                        replacement = context.get_variable(var_name)
                        if replacement is not None:
                            value = value.replace(f'${{{match}}}', str(replacement))
                substituted[key] = value
            else:
                substituted[key] = value
        return substituted

    def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]:
        if not self._evaluate_condition(step.condition, context):
            context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'})
            return {'status': 'skipped', 'step_id': step.step_id}

        arguments = self._substitute_variables(step.arguments, context)

        start_time = time.time()
        retry_attempts = 0
        last_error = None

        while retry_attempts <= step.retry_count:
            try:
                context.log_event('executing', step.step_id, {
                    'tool': step.tool_name,
                    'arguments': arguments,
                    'attempt': retry_attempts + 1
                })

                result = self.tool_executor(step.tool_name, arguments)

                execution_time = time.time() - start_time
                context.set_step_result(step.step_id, result)
                context.log_event('completed', step.step_id, {
                    'execution_time': execution_time,
                    'result_size': len(str(result)) if result else 0
                })

                return {
                    'status': 'success',
                    'step_id': step.step_id,
                    'result': result,
                    'execution_time': execution_time
                }

            except Exception as e:
                last_error = str(e)
                retry_attempts += 1
                if retry_attempts <= step.retry_count:
                    time.sleep(1 * retry_attempts)  # linear backoff between retries

        context.log_event('failed', step.step_id, {'error': last_error})
        return {
            'status': 'failed',
            'step_id': step.step_id,
            'error': last_error,
            'execution_time': time.time() - start_time
        }

    def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any],
                        workflow: Workflow) -> List[WorkflowStep]:
        next_steps = []

        if result['status'] == 'success' and completed_step.on_success:
            for step_id in completed_step.on_success:
                step = workflow.get_step(step_id)
                if step:
                    next_steps.append(step)

        elif result['status'] == 'failed' and completed_step.on_failure:
            for step_id in completed_step.on_failure:
                step = workflow.get_step(step_id)
                if step:
                    next_steps.append(step)

        elif workflow.execution_mode == ExecutionMode.SEQUENTIAL:
            current_index = workflow.steps.index(completed_step)
            if current_index + 1 < len(workflow.steps):
                next_steps.append(workflow.steps[current_index + 1])

        return next_steps

    def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext:
        context = WorkflowExecutionContext()

        if initial_variables:
            context.variables.update(initial_variables)

        if workflow.variables:
            context.variables.update(workflow.variables)

        context.log_event('workflow_started', 'workflow', {'name': workflow.name})

        if workflow.execution_mode == ExecutionMode.PARALLEL:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {
                    executor.submit(self._execute_step, step, context): step
                    for step in workflow.steps
                }

                for future in as_completed(futures):
                    step = futures[future]
                    try:
                        result = future.result()
                        context.log_event('step_completed', step.step_id, result)
                    except Exception as e:
                        context.log_event('step_failed', step.step_id, {'error': str(e)})

        else:
            pending_steps = workflow.get_initial_steps()
            executed_step_ids = set()

            while pending_steps:
                step = pending_steps.pop(0)

                if step.step_id in executed_step_ids:
                    continue

                result = self._execute_step(step, context)
                executed_step_ids.add(step.step_id)

                next_steps = self._get_next_steps(step, result, workflow)
                pending_steps.extend(next_steps)

        context.log_event('workflow_completed', 'workflow', {
            'total_steps': len(context.step_results),
            'executed_steps': list(context.step_results.keys())
        })

        return context
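An end-to-end sketch with a toy tool_executor, showing how ${step.id} placeholders carry one step's result into the next in sequential mode:

from pr.workflows.workflow_definition import Workflow, WorkflowStep
from pr.workflows.workflow_engine import WorkflowEngine


def toy_executor(tool_name, arguments):
    # Stand-in for the real tool dispatcher.
    if tool_name == "make_greeting":
        return "hello"
    return f"{tool_name}({arguments})"


wf = Workflow(
    name="demo",
    description="two chained steps",
    steps=[
        WorkflowStep(tool_name="make_greeting", arguments={}, step_id="a"),
        WorkflowStep(tool_name="echo", arguments={"text": "${step.a} world"}, step_id="b"),
    ],
)
ctx = WorkflowEngine(toy_executor).execute_workflow(wf)
print(ctx.get_step_result("b"))  # echo({'text': 'hello world'})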
214
pr/workflows/workflow_storage.py
Normal file
@ -0,0 +1,214 @@
import json
import sqlite3
import time
from typing import List, Optional
from .workflow_definition import Workflow


class WorkflowStorage:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self._initialize_storage()

    def _initialize_storage(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS workflows (
                workflow_id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                description TEXT,
                workflow_data TEXT NOT NULL,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                execution_count INTEGER DEFAULT 0,
                last_execution_at INTEGER,
                tags TEXT
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS workflow_executions (
                execution_id TEXT PRIMARY KEY,
                workflow_id TEXT NOT NULL,
                started_at INTEGER NOT NULL,
                completed_at INTEGER,
                status TEXT NOT NULL,
                execution_log TEXT,
                variables TEXT,
                step_results TEXT,
                FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
            )
        ''')

        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)
        ''')

        conn.commit()
        conn.close()

    def save_workflow(self, workflow: Workflow) -> str:
        import hashlib

        workflow_data = json.dumps(workflow.to_dict())
        workflow_id = hashlib.sha256(workflow.name.encode()).hexdigest()[:16]

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        current_time = int(time.time())
        tags_json = json.dumps(workflow.tags)

        cursor.execute('''
            INSERT OR REPLACE INTO workflows
            (workflow_id, name, description, workflow_data, created_at, updated_at, tags)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        ''', (workflow_id, workflow.name, workflow.description, workflow_data,
              current_time, current_time, tags_json))

        conn.commit()
        conn.close()

        return workflow_id

    def load_workflow(self, workflow_id: str) -> Optional[Workflow]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,))
        row = cursor.fetchone()
        conn.close()

        if row:
            workflow_dict = json.loads(row[0])
            return Workflow.from_dict(workflow_dict)
        return None

    def load_workflow_by_name(self, name: str) -> Optional[Workflow]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,))
        row = cursor.fetchone()
        conn.close()

        if row:
            workflow_dict = json.loads(row[0])
            return Workflow.from_dict(workflow_dict)
        return None

    def list_workflows(self, tag: Optional[str] = None) -> List[dict]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        if tag:
            cursor.execute('''
                SELECT workflow_id, name, description, execution_count, last_execution_at, tags
                FROM workflows
                WHERE tags LIKE ?
                ORDER BY name
            ''', (f'%"{tag}"%',))
        else:
            cursor.execute('''
                SELECT workflow_id, name, description, execution_count, last_execution_at, tags
                FROM workflows
                ORDER BY name
            ''')

        workflows = []
        for row in cursor.fetchall():
            workflows.append({
                'workflow_id': row[0],
                'name': row[1],
                'description': row[2],
                'execution_count': row[3],
                'last_execution_at': row[4],
                'tags': json.loads(row[5]) if row[5] else []
            })

        conn.close()
        return workflows

    def delete_workflow(self, workflow_id: str) -> bool:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,))
        deleted = cursor.rowcount > 0

        cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,))

        conn.commit()
        conn.close()

        return deleted

    def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str:
        import uuid

        execution_id = str(uuid.uuid4())[:16]

        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time())
        completed_at = int(time.time())

        cursor.execute('''
            INSERT INTO workflow_executions
            (execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            execution_id,
            workflow_id,
            started_at,
            completed_at,
            'completed',
            json.dumps(execution_context.execution_log),
            json.dumps(execution_context.variables),
            json.dumps(execution_context.step_results)
        ))

        cursor.execute('''
            UPDATE workflows
            SET execution_count = execution_count + 1,
                last_execution_at = ?
            WHERE workflow_id = ?
        ''', (completed_at, workflow_id))

        conn.commit()
        conn.close()

        return execution_id

    def get_execution_history(self, workflow_id: str, limit: int = 10) -> List[dict]:
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = conn.cursor()

        cursor.execute('''
            SELECT execution_id, started_at, completed_at, status
            FROM workflow_executions
            WHERE workflow_id = ?
            ORDER BY started_at DESC
            LIMIT ?
        ''', (workflow_id, limit))

        executions = []
        for row in cursor.fetchall():
            executions.append({
                'execution_id': row[0],
                'started_at': row[1],
                'completed_at': row[2],
                'status': row[3]
            })

        conn.close()
        return executions
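A save/load round-trip sketch against a throwaway SQLite file:

import os
import tempfile
from pr.workflows.workflow_definition import Workflow, WorkflowStep
from pr.workflows.workflow_storage import WorkflowStorage

db = os.path.join(tempfile.mkdtemp(), "workflows.db")
storage = WorkflowStorage(db)

wf = Workflow(name="demo", description="", steps=[WorkflowStep("noop", {}, "s1")])
wf_id = storage.save_workflow(wf)  # the id is a hash of the name, so re-saving upserts
loaded = storage.load_workflow(wf_id)
print(loaded.name, [s.step_id for s in loaded.steps])  # demo ['s1']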
115
pyproject.toml
Normal file
@ -0,0 +1,115 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "pr-assistant"
version = "1.0.0"
description = "Professional CLI AI assistant with autonomous execution capabilities"
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
keywords = ["ai", "assistant", "cli", "automation", "openrouter", "autonomous"]
authors = [
    {name = "retoor", email = "retoor@example.com"}
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "flake8>=6.0.0",
    "mypy>=1.0.0",
    "pre-commit>=3.0.0",
]

[project.scripts]
pr = "pr.__main__:main"
rp = "pr.__main__:main"
rpe = "pr.editor:main"

[project.urls]
Homepage = "https://github.com/retoor/pr-assistant"
Documentation = "https://github.com/retoor/pr-assistant#readme"
Repository = "https://github.com/retoor/pr-assistant"
"Bug Tracker" = "https://github.com/retoor/pr-assistant/issues"

[tool.setuptools.packages.find]
where = ["."]
include = ["pr*"]
exclude = ["tests*"]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --cov=pr --cov-report=term-missing --cov-report=html"

[tool.black]
line-length = 100
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
extend-exclude = '''
/(
    __pycache__
  | \.git
  | \.mypy_cache
  | \.pytest_cache
  | \.venv
  | build
  | dist
)/
'''

[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true

[tool.coverage.run]
source = ["pr"]
omit = ["*/tests/*", "*/__pycache__/*"]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
    "if TYPE_CHECKING:",
]

[tool.isort]
profile = "black"
line_length = 100
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true

[tool.bandit]
exclude_dirs = ["tests", "venv", ".venv"]
skips = ["B101"]
7
rp.py
Executable file
@ -0,0 +1,7 @@
#!/usr/bin/env python3

import sys
from pr.__main__ import main

if __name__ == '__main__':
    main()
0
tests/__init__.py
Normal file
53
tests/conftest.py
Normal file
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield tmpdir
|
||||
|
||||
|
||||
@pytest.fixture
def mock_api_response():
    return {
        'choices': [
            {
                'message': {
                    'role': 'assistant',
                    'content': 'Test response'
                }
            }
        ],
        'usage': {
            'prompt_tokens': 10,
            'completion_tokens': 5,
            'total_tokens': 15
        }
    }
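

# Stands in for the parsed argparse namespace; every option is defaulted
# explicitly so individual tests can override only what they need.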
@pytest.fixture
def mock_args():
    args = MagicMock()
    args.message = None
    args.model = None
    args.api_url = None
    args.model_list_url = None
    args.interactive = False
    args.verbose = False
    args.no_syntax = False
    args.include_env = False
    args.context = None
    args.api_mode = False
    return args
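

# Presumably the default context file the assistant loads on startup;
# the fixture drops a sample one into an isolated temp directory.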
@pytest.fixture
def sample_context_file(temp_dir):
    context_path = os.path.join(temp_dir, '.rcontext.txt')
    with open(context_path, 'w') as f:
        f.write('Sample context content\n')
    return context_path
127
tests/test_agents.py
Normal file
@ -0,0 +1,127 @@
from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles
from pr.agents.agent_manager import AgentManager, AgentInstance
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType


def test_get_agent_role():
    role = get_agent_role('coding')
    assert isinstance(role, AgentRole)
    assert role.name == 'coding'


def test_list_agent_roles():
    roles = list_agent_roles()
    assert isinstance(roles, dict)
    assert len(roles) > 0
    assert 'coding' in roles


def test_agent_role():
    role = AgentRole(
        name='test', description='test', system_prompt='test',
        allowed_tools=set(), specialization_areas=[],
    )
    assert role.name == 'test'


def test_agent_instance():
    role = get_agent_role('coding')
    instance = AgentInstance(agent_id='test', role=role)
    assert instance.agent_id == 'test'
    assert instance.role == role
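

# ':memory:' is SQLite's in-memory database, so no state leaks between tests;
# the manager's second constructor argument is stubbed out with None here.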
def test_agent_manager_init():
    mgr = AgentManager(':memory:', None)
    assert mgr is not None


def test_agent_manager_create_agent():
    mgr = AgentManager(':memory:', None)
    agent = mgr.create_agent('coding', 'test_agent')
    assert agent is not None


def test_agent_manager_get_agent():
    mgr = AgentManager(':memory:', None)
    mgr.create_agent('coding', 'test_agent')
    agent = mgr.get_agent('test_agent')
    assert isinstance(agent, AgentInstance)


def test_agent_manager_remove_agent():
    mgr = AgentManager(':memory:', None)
    mgr.create_agent('coding', 'test_agent')
    mgr.remove_agent('test_agent')
    agent = mgr.get_agent('test_agent')
    assert agent is None


def test_agent_manager_send_agent_message():
    mgr = AgentManager(':memory:', None)
    mgr.create_agent('coding', 'a')
    mgr.create_agent('coding', 'b')
    mgr.send_agent_message('a', 'b', 'test')
    assert True


def test_agent_manager_get_agent_messages():
    mgr = AgentManager(':memory:', None)
    mgr.create_agent('coding', 'test')
    messages = mgr.get_agent_messages('test')
    assert isinstance(messages, list)


def test_agent_manager_get_session_summary():
    mgr = AgentManager(':memory:', None)
    summary = mgr.get_session_summary()
    assert isinstance(summary, str)


def test_agent_manager_collaborate_agents():
    mgr = AgentManager(':memory:', None)
    result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research'])
    assert result is not None


def test_agent_manager_execute_agent_task():
    mgr = AgentManager(':memory:', None)
    mgr.create_agent('coding', 'test')
    result = mgr.execute_agent_task('test', 'task')
    assert result is not None


def test_agent_manager_clear_session():
    mgr = AgentManager(':memory:', None)
    mgr.clear_session()
    assert True
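

# AgentMessage plus its to_dict/from_dict round-trip, which the bus
# presumably uses to persist messages.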
def test_agent_message():
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    assert msg.from_agent == 'a'


def test_agent_message_to_dict():
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    d = msg.to_dict()
    assert isinstance(d, dict)


def test_agent_message_from_dict():
    d = {
        'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request',
        'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id',
    }
    msg = AgentMessage.from_dict(d)
    assert isinstance(msg, AgentMessage)
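

# Bus round-trip: messages stored by send_message come back via
# receive_messages and get_conversation_history.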
def test_agent_communication_bus_init():
    bus = AgentCommunicationBus(':memory:')
    assert bus is not None


def test_agent_communication_bus_send_message():
    bus = AgentCommunicationBus(':memory:')
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    bus.send_message(msg)
    assert True


def test_agent_communication_bus_receive_messages():
    bus = AgentCommunicationBus(':memory:')
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    bus.send_message(msg)
    messages = bus.receive_messages('b')
    assert len(messages) == 1


def test_agent_communication_bus_get_conversation_history():
    bus = AgentCommunicationBus(':memory:')
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    bus.send_message(msg)
    history = bus.get_conversation_history('a', 'b')
    assert len(history) == 1


def test_agent_communication_bus_mark_as_read():
    bus = AgentCommunicationBus(':memory:')
    msg = AgentMessage(
        from_agent='a', to_agent='b', message_type=MessageType.REQUEST,
        content='test', metadata={}, timestamp=1.0, message_id='id',
    )
    bus.send_message(msg)
    bus.mark_as_read(msg.message_id)
    assert True
31
tests/test_config.py
Normal file
@ -0,0 +1,31 @@
from pr import config


class TestConfig:

    def test_default_model_exists(self):
        assert hasattr(config, 'DEFAULT_MODEL')
        assert isinstance(config.DEFAULT_MODEL, str)
        assert len(config.DEFAULT_MODEL) > 0

    def test_api_url_exists(self):
        assert hasattr(config, 'DEFAULT_API_URL')
        assert config.DEFAULT_API_URL.startswith('http')

    def test_file_paths_exist(self):
        assert hasattr(config, 'DB_PATH')
        assert hasattr(config, 'LOG_FILE')
        assert hasattr(config, 'HISTORY_FILE')

    def test_autonomous_config(self):
        assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS')
        assert config.MAX_AUTONOMOUS_ITERATIONS > 0

        assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD')
        assert config.CONTEXT_COMPRESSION_THRESHOLD > 0

    def test_language_keywords(self):
        assert hasattr(config, 'LANGUAGE_KEYWORDS')
        assert 'python' in config.LANGUAGE_KEYWORDS
        assert isinstance(config.LANGUAGE_KEYWORDS['python'], list)
35
tests/test_context.py
Normal file
@ -0,0 +1,35 @@
from pr.core.context import should_compress_context, compress_context
from pr.config import RECENT_MESSAGES_TO_KEEP
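

# The two threshold tests below implicitly pin CONTEXT_COMPRESSION_THRESHOLD
# to somewhere above 10 and no higher than 35 messages.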
class TestContextManagement:

    def test_should_compress_context_below_threshold(self):
        messages = [{'role': 'user', 'content': 'test'}] * 10
        assert should_compress_context(messages) is False

    def test_should_compress_context_above_threshold(self):
        messages = [{'role': 'user', 'content': 'test'}] * 35
        assert should_compress_context(messages) is True

    def test_compress_context_preserves_system_message(self):
        messages = [
            {'role': 'system', 'content': 'System prompt'},
            {'role': 'user', 'content': 'Hello'},
            {'role': 'assistant', 'content': 'Hi'},
        ] * 40  # Ensure compression
        compressed = compress_context(messages)
        assert compressed[0]['role'] == 'system'
        assert 'System prompt' in compressed[0]['content']

    def test_compress_context_keeps_recent_messages(self):
        messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)]
        compressed = compress_context(messages)
        # Should keep recent messages
        recent = compressed[-RECENT_MESSAGES_TO_KEEP:]
        assert len(recent) == RECENT_MESSAGES_TO_KEEP
        # Check that the messages are the most recent ones
        for i, msg in enumerate(recent):
            expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i
            assert msg['content'] == f'msg{expected_index}'
118
tests/test_tools.py
Normal file
@ -0,0 +1,118 @@
import os

from pr.tools.filesystem import read_file, write_file, list_directory, search_replace
from pr.tools.patch import apply_patch, create_diff
from pr.tools.base import get_tools_definition
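

# Tool functions report outcomes as dicts with a 'status' key ('success' or
# 'error') rather than raising, so the tests assert on that field.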
class TestFilesystemTools:

    def test_write_and_read_file(self, temp_dir):
        filepath = os.path.join(temp_dir, 'test.txt')
        content = 'Hello, World!'

        write_result = write_file(filepath, content)
        assert write_result['status'] == 'success'

        read_result = read_file(filepath)
        assert read_result['status'] == 'success'
        assert content in read_result['content']

    def test_read_nonexistent_file(self):
        result = read_file('/nonexistent/path/file.txt')
        assert result['status'] == 'error'

    def test_list_directory(self, temp_dir):
        test_file = os.path.join(temp_dir, 'testfile.txt')
        with open(test_file, 'w') as f:
            f.write('test')

        result = list_directory(temp_dir)
        assert result['status'] == 'success'
        assert any(item['name'] == 'testfile.txt' for item in result['items'])

    def test_search_replace(self, temp_dir):
        filepath = os.path.join(temp_dir, 'test.txt')
        content = 'Hello, World!'
        with open(filepath, 'w') as f:
            f.write(content)

        result = search_replace(filepath, 'World', 'Universe')
        assert result['status'] == 'success'

        read_result = read_file(filepath)
        assert 'Hello, Universe!' in read_result['content']
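

# create_diff and apply_patch exchange unified-diff text; the hand-written
# patch in test_apply_patch should match the format create_diff emits.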
class TestPatchTools:

    def test_create_diff(self, temp_dir):
        file1 = os.path.join(temp_dir, 'file1.txt')
        file2 = os.path.join(temp_dir, 'file2.txt')
        with open(file1, 'w') as f:
            f.write('line1\nline2\nline3\n')
        with open(file2, 'w') as f:
            f.write('line1\nline2 modified\nline3\n')

        result = create_diff(file1, file2)
        assert result['status'] == 'success'
        assert 'line2' in result['diff']
        assert 'line2 modified' in result['diff']

    def test_apply_patch(self, temp_dir):
        filepath = os.path.join(temp_dir, 'file.txt')
        with open(filepath, 'w') as f:
            f.write('line1\nline2\nline3\n')

        # Create a simple patch (context lines keep their leading space)
        patch_content = """--- a/file.txt
+++ b/file.txt
@@ -1,3 +1,3 @@
 line1
-line2
+line2 modified
 line3
"""
        result = apply_patch(filepath, patch_content)
        assert result['status'] == 'success'

        read_result = read_file(filepath)
        assert 'line2 modified' in read_result['content']
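

# The schemas below follow the OpenAI function-calling layout:
# {'type': 'function', 'function': {'name', 'description', 'parameters'}}.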
class TestToolDefinitions:

    def test_get_tools_definition_returns_list(self):
        tools = get_tools_definition()
        assert isinstance(tools, list)
        assert len(tools) > 0

    def test_all_tools_have_required_fields(self):
        tools = get_tools_definition()

        for tool in tools:
            assert 'type' in tool
            assert tool['type'] == 'function'
            assert 'function' in tool

            func = tool['function']
            assert 'name' in func
            assert 'description' in func
            assert 'parameters' in func

    def test_filesystem_tools_present(self):
        tools = get_tools_definition()
        tool_names = [t['function']['name'] for t in tools]

        assert 'read_file' in tool_names
        assert 'write_file' in tool_names
        assert 'list_directory' in tool_names
        assert 'search_replace' in tool_names

    def test_patch_tools_present(self):
        tools = get_tools_definition()
        tool_names = [t['function']['name'] for t in tools]

        assert 'apply_patch' in tool_names
        assert 'create_diff' in tool_names